Parameter Estimation with Physics-Informed Neural Networks for ODEs

Consider the Lotka-Volterra system

$$\begin{aligned}
\frac{du_1}{dt} &= \alpha u_1 - \beta u_1 u_2, \\
\frac{du_2}{dt} &= \delta u_1 u_2 - \gamma u_2,
\end{aligned}$$

which we will solve with Physics-Informed Neural Networks. Here we consider the case where we also want to estimate the parameters $\alpha$, $\beta$, $\gamma$, and $\delta$ from data.

We start by defining the problem:

using NeuralPDE, OrdinaryDiffEq, Lux, Random, OptimizationOptimJL, LineSearches, Plots

function lv(u, p, t)
    u₁, u₂ = u        # prey and predator populations
    α, β, γ, δ = p    # growth, predation, predator-death, and predation-efficiency rates
    du₁ = α * u₁ - β * u₁ * u₂
    du₂ = δ * u₁ * u₂ - γ * u₂
    [du₁, du₂]        # out-of-place form: return the derivatives as a new vector
end
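
As a quick sanity check (an extra snippet we add for illustration, not part of the tutorial's pipeline), we can evaluate the right-hand side at the initial state with the placeholder parameters:

lv([5.0, 5.0], [1.0, 1.0, 1.0, 1.0], 0.0)
2-element Vector{Float64}:
 -20.0
  20.0

With these values, the prey population declines while the predator population grows, as expected.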

tspan = (0.0, 5.0)
u0 = [5.0, 5.0]
prob = ODEProblem(lv, u0, tspan, [1.0, 1.0, 1.0, 1.0])
ODEProblem with uType Vector{Float64} and tType Float64. In-place: false
Non-trivial mass matrix: false
timespan: (0.0, 5.0)
u0: 2-element Vector{Float64}:
 5.0
 5.0

As we want to estimate the parameters as well, let's generate some data by solving the ODE with the true parameter values.

true_p = [1.5, 1.0, 3.0, 1.0]
prob_data = remake(prob, p = true_p)
sol_data = solve(prob_data, Tsit5(), saveat = 0.01)
t_ = sol_data.t
u_ = reduce(hcat, sol_data.u)
2×501 Matrix{Float64}:
 5.0  4.82567  4.65308  4.48283  4.31543  …  1.01959   1.03094   1.04248
 5.0  5.09656  5.18597  5.26791  5.34212     0.397663  0.389887  0.382307
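
Note that reduce(hcat, sol_data.u) lays the data out as a states-by-times matrix. Equivalently (a small aside), the same matrix can be obtained directly from the solution object:

u_alt = Array(sol_data)   # 2×501 matrix, identical to u_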

Now, let's define a neural network for the PINN using Lux.jl.

rng = Random.default_rng()
Random.seed!(rng, 0)
n = 15
chain = Chain(Dense(1, n, σ), Dense(n, n, σ), Dense(n, n, σ), Dense(n, 2))
ps, st = Lux.setup(rng, chain) |> f64
((layer_1 = (weight = [-0.04929668828845024; -0.3266667425632477; … ; -1.3531280755996704; -0.2917589843273163;;], bias = [0.28568029403686523, -0.4209803342819214, -0.24613642692565918, -0.9429000616073608, -0.3618292808532715, 0.077278733253479, 0.9969245195388794, 0.7939795255661011, 0.45440757274627686, -0.4830443859100342, -0.6861011981964111, -0.3221019506454468, -0.5597391128540039, -0.15051674842834473, 0.9440881013870239]), layer_2 = (weight = [-0.08606008440256119 -0.2168799191713333 … -0.3507671356201172 0.07374405115842819; 0.24009405076503754 -0.2372819483280182 … 0.34944412112236023 -0.21207459270954132; … ; 0.3976286053657532 0.28444960713386536 … -0.32817620038986206 0.396392285823822; -0.07926429808139801 0.35875916481018066 … -0.03593128174543381 -0.28511112928390503], bias = [-0.065037302672863, 0.18384626507759094, 0.17181798815727234, -0.17310386896133423, 0.06428726017475128, 0.09600061178207397, -0.08703552931547165, 0.06890828162431717, -0.16194558143615723, -0.14649711549282074, -0.14649459719657898, -0.04401325806975365, -0.015492657199501991, 0.1046019047498703, 0.15015578269958496]), layer_3 = (weight = [-0.2995997369289398 0.14921274781227112 … -0.011808237060904503 -0.3409591019153595; 0.4351722002029419 0.1286778748035431 … -0.20781198143959045 -0.030425485223531723; … ; -0.02206072397530079 0.14348538219928741 … -0.05763476341962814 -0.2672235071659088; 0.2975636124610901 -0.06781639903783798 … 0.4012162387371063 0.12123444676399231], bias = [0.04135546088218689, -0.2398381233215332, 0.1595604568719864, 0.08355490118265152, -0.06149742379784584, -0.06998120248317719, -0.008059235289692879, -0.10936713218688965, -0.18340998888015747, 0.06297893822193146, 0.04081515222787857, -0.04258332401514053, 0.11171907186508179, -0.21218737959861755, 0.07965957373380661]), layer_4 = (weight = [0.3909371793270111 -0.23473049700260162 … 0.07385867089033127 0.31727129220962524; -0.04396385699510574 0.1817844659090042 … -0.26729491353034973 0.24492913484573364], bias = [0.04966225475072861, -0.04299044609069824])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))
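
The network maps a scalar time to the two state values. As a quick illustration (not needed for the fit), a single forward pass through the untrained network looks like this:

y, _ = chain([0.5], ps, st)   # y is the untrained 2-element prediction for (u₁, u₂) at t = 0.5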

Next, we define an additional loss term for the total loss, which measures how well the neural network's prediction fits the data.

additional_loss(phi, θ) = sum(abs2, phi(t_, θ) .- u_) / size(u_, 2)
additional_loss (generic function with 1 method)
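
In other words, the additional loss is the mean squared error between the PINN prediction $\phi(t_i; \theta)$ at the saved time points and the data:

$$L_{\text{data}}(\theta) = \frac{1}{N} \sum_{i=1}^{N} \left\| \phi(t_i; \theta) - u_i \right\|_2^2,$$

where $N = 501$ is the number of data points.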

Next, we define the optimizer and the NNODE algorithm, which is then plugged into the solve call. Note param_estim = true, which turns on parameter estimation, and the additional_loss keyword carrying the data-fitting term defined above.

opt = LBFGS(linesearch = BackTracking())
alg = NNODE(chain, opt, ps; strategy = WeightedIntervalTraining([0.7, 0.2, 0.1], 500),
    param_estim = true, additional_loss)
NNODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.σ), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(NNlib.σ), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(NNlib.σ), Int64, Int64, Nothing, Nothing, Static.True}, layer_4::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, Optim.LBFGS{Nothing, LineSearches.InitialStatic{Float64}, LineSearches.BackTracking{Float64, Int64}, Returns{Nothing}}, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float64}, bias::Vector{Float64}}, layer_2::@NamedTuple{weight::Matrix{Float64}, bias::Vector{Float64}}, layer_3::@NamedTuple{weight::Matrix{Float64}, bias::Vector{Float64}}, layer_4::@NamedTuple{weight::Matrix{Float64}, bias::Vector{Float64}}}, Bool, WeightedIntervalTraining{Float64}, Bool, typeof(Main.additional_loss), Vector{Any}, Base.Pairs{Symbol, Union{}, Tuple{}, @NamedTuple{}}}(Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.σ), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(NNlib.σ), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(NNlib.σ), Int64, Int64, Nothing, Nothing, Static.True}, layer_4::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(1 => 15, σ), layer_2 = Dense(15 => 15, σ), layer_3 = Dense(15 => 15, σ), layer_4 = Dense(15 => 2)), nothing), Optim.LBFGS{Nothing, LineSearches.InitialStatic{Float64}, LineSearches.BackTracking{Float64, Int64}, Returns{Nothing}}(10, LineSearches.InitialStatic{Float64}
  alpha: Float64 1.0
  scaled: Bool false
, LineSearches.BackTracking{Float64, Int64}
  c_1: Float64 0.0001
  ρ_hi: Float64 0.5
  ρ_lo: Float64 0.1
  iterations: Int64 1000
  order: Int64 3
  maxstep: Float64 Inf
  cache: Nothing nothing
, nothing, Returns{Nothing}(nothing), Optim.Flat(), true), (layer_1 = (weight = [-0.04929668828845024; -0.3266667425632477; … ; -1.3531280755996704; -0.2917589843273163;;], bias = [0.28568029403686523, -0.4209803342819214, -0.24613642692565918, -0.9429000616073608, -0.3618292808532715, 0.077278733253479, 0.9969245195388794, 0.7939795255661011, 0.45440757274627686, -0.4830443859100342, -0.6861011981964111, -0.3221019506454468, -0.5597391128540039, -0.15051674842834473, 0.9440881013870239]), layer_2 = (weight = [-0.08606008440256119 -0.2168799191713333 … -0.3507671356201172 0.07374405115842819; 0.24009405076503754 -0.2372819483280182 … 0.34944412112236023 -0.21207459270954132; … ; 0.3976286053657532 0.28444960713386536 … -0.32817620038986206 0.396392285823822; -0.07926429808139801 0.35875916481018066 … -0.03593128174543381 -0.28511112928390503], bias = [-0.065037302672863, 0.18384626507759094, 0.17181798815727234, -0.17310386896133423, 0.06428726017475128, 0.09600061178207397, -0.08703552931547165, 0.06890828162431717, -0.16194558143615723, -0.14649711549282074, -0.14649459719657898, -0.04401325806975365, -0.015492657199501991, 0.1046019047498703, 0.15015578269958496]), layer_3 = (weight = [-0.2995997369289398 0.14921274781227112 … -0.011808237060904503 -0.3409591019153595; 0.4351722002029419 0.1286778748035431 … -0.20781198143959045 -0.030425485223531723; … ; -0.02206072397530079 0.14348538219928741 … -0.05763476341962814 -0.2672235071659088; 0.2975636124610901 -0.06781639903783798 … 0.4012162387371063 0.12123444676399231], bias = [0.04135546088218689, -0.2398381233215332, 0.1595604568719864, 0.08355490118265152, -0.06149742379784584, -0.06998120248317719, -0.008059235289692879, -0.10936713218688965, -0.18340998888015747, 0.06297893822193146, 0.04081515222787857, -0.04258332401514053, 0.11171907186508179, -0.21218737959861755, 0.07965957373380661]), layer_4 = (weight = [0.3909371793270111 -0.23473049700260162 … 0.07385867089033127 0.31727129220962524; -0.04396385699510574 0.1817844659090042 … -0.26729491353034973 0.24492913484573364], bias = [0.04966225475072861, -0.04299044609069824])), false, true, WeightedIntervalTraining{Float64}([0.7, 0.2, 0.1], 500), true, Main.additional_loss, Any[], false, Base.Pairs{Symbol, Union{}, Tuple{}, @NamedTuple{}}())
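
Here, WeightedIntervalTraining([0.7, 0.2, 0.1], 500) splits the time span into three equal sub-intervals and draws 500 training points allocated in proportion to the weights, so early times, where the transient is fastest, are sampled most densely. Roughly (a back-of-the-envelope sketch, not the library's internals):

weights = [0.7, 0.2, 0.1]
points = 500
counts = round.(Int, weights .* points)                           # ≈ [350, 100, 50] points per sub-interval
edges = range(tspan[1], tspan[2], length = length(weights) + 1)   # sub-interval boundaries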

Now we have all the pieces to solve the optimization problem.

sol = solve(prob, alg, verbose = true, abstol = 1e-8, maxiters = 5000, saveat = t_)
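
Since we passed saveat = t_, the solution is stored at the same time points as the data, so we can quantify the fit directly (an optional check added here for illustration):

pred = reduce(hcat, sol.u)                   # PINN predictions as a 2×501 matrix
mse = sum(abs2, pred .- u_) / size(u_, 2)    # mean squared error against the data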

Let's plot the predictions from the PINN and compare them to the data.

plot(sol, labels = ["u1_pinn" "u2_pinn"])
plot!(sol_data, labels = ["u1_data" "u2_data"])
[Plot: PINN predictions u1_pinn and u2_pinn overlaid on the data solution u1_data and u2_data]

We can see it is a good fit! Now let's check whether the parameters of the equation were also estimated correctly.

sol.k.u.p   # the estimated ODE parameters, stored after the flattened network parameters
4-element view(::Vector{Float64}, 543:546) with eltype Float64:
 1.5141435131861452
 1.007987608456337
 2.9797759879902177
 0.994423206907477

We can see these are indeed close to the true values [1.5, 1.0, 3.0, 1.0].
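
As a final check (an extra snippet for illustration), we can compute the relative error of each estimated parameter:

rel_err = abs.(sol.k.u.p .- true_p) ./ abs.(true_p)   # all below 1% for this run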