Model Improvement in Physics-Informed Neural Networks for Solving Inverse Problems in ODEs

Consider an inverse problem setting for the Lotka-Volterra system. Here we want to estimate the parameters $\alpha$, $\beta$, $\gamma$ and $\delta$ while also solving the parametric Lotka-Volterra system itself. PINNs are especially useful in these problems and are preferred over conventional solvers because they can learn, from observations, the underlying physics governing the distribution of those observations.
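For reference, the system whose parameters we want to recover is

$$\frac{du_1}{dt} = \alpha u_1 - \beta u_1 u_2, \qquad \frac{du_2}{dt} = \delta u_1 u_2 - \gamma u_2,$$

where $u_1$ and $u_2$ are the prey and predator populations, and $\alpha$, $\beta$, $\gamma$, $\delta$ are the unknown parameters.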

We start by defining the problem, with a random, non-informative initialization of the parameters:

using NeuralPDE, OrdinaryDiffEq, Lux, Random, OptimizationOptimJL, LineSearches,
      Distributions, Plots
using FastGaussQuadrature

function lv(u, p, t)
    u₁, u₂ = u
    α, β, γ, δ = p
    du₁ = α * u₁ - β * u₁ * u₂
    du₂ = δ * u₁ * u₂ - γ * u₂
    [du₁, du₂]
end

tspan = (0.0, 5.0)
u0 = [5.0, 5.0]
initialization = [-5.0, 8.0, 5.0, -7.0]
prob = ODEProblem(lv, u0, tspan, initialization)
ODEProblem with uType Vector{Float64} and tType Float64. In-place: false
Non-trivial mass matrix: false
timespan: (0.0, 5.0)
u0: 2-element Vector{Float64}:
 5.0
 5.0

We require a set of observations before we train the PINN. Since we want robust results even when measurements are sparse and limited in number, we simulate the system with the true parameter values true_p and record the solution (u) at only N = 20 pre-decided timepoints within the system's time domain.

The value of N can be increased with the degree of nonlinearity (roughly, the polynomial degree) of the measured phenomenon; this tutorial's setting shows that even with a minimal but systematically chosen set of data points we can extract excellent results.

true_p = [1.5, 1.0, 3.0, 1.0]
prob_data = remake(prob, p = true_p)

N = 20
x, w = gausslobatto(N)
a = tspan[1]
b = tspan[2]
5.0

Now scale the weights and the Gauss-Lobatto/Clenshaw-Curtis/Gauss-Legendre quadrature points from $[-1, 1]$ to fit tspan.

t = map((x) -> (x * (b - a) + (b + a)) / 2, x)
W = map((x) -> x * (b - a) / 2, w)
20-element Vector{Float64}:
 0.013157894736842105
 0.08059280797122265
 0.1429545053189173
 0.20157940999029922
 0.2549787492486272
 0.3017730690716866
 0.34075120589681046
 0.3709038851772921
 0.3914502566186888
 0.4018582159696143
 0.4018582159696143
 0.3914502566186888
 0.3709038851772921
 0.34075120589681046
 0.3017730690716866
 0.2549787492486272
 0.20157940999029922
 0.1429545053189173
 0.08059280797122265
 0.013157894736842105

We now have our dataset of 20 measurements across tspan together with the corresponding quadrature weights. With this we can use the Data Quadrature loss function by passing estim_collocate = true to NNODE.

sol_data = solve(prob_data, Tsit5(); saveat = t)
t_ = sol_data.t
u_ = sol_data.u
u1_ = [u_[i][1] for i in eachindex(t_)]
u2_ = [u_[i][2] for i in eachindex(t_)]
dataset = [u1_, u2_, t_, W]
4-element Vector{Vector{Float64}}:
 [5.0, 4.181603611078826, 2.645397810891132, 1.3912498015703647, 0.8210526163681184, 0.6722098478776918, 0.7611053094812956, 1.0772907174319215, 1.7299869127510388, 2.9470253265117505, 5.049007710901019, 7.796206693042137, 5.5936246049244325, 1.4464064682616247, 0.7465251501572414, 0.6800898950063358, 0.7600175010829118, 0.8812968395300225, 0.9887672013131465, 1.0424843918922717]
 [5.0, 5.396689381837423, 5.620089682747766, 4.668051098875351, 2.9860691512818804, 1.5808704215183293, 0.7539189127225489, 0.3564866998395664, 0.19132741119848423, 0.14380368289069814, 0.20915289174224483, 0.8200396835925662, 4.616390757346223, 4.732955149243002, 2.4829239251518236, 1.3047202433563063, 0.7730246691232083, 0.5302342006306647, 0.42063858049750924, 0.38230716879691223]
 [0.0, 0.04814073776521477, 0.1601637529683364, 0.33280505477512445, 0.5615793476198603, 0.8405589942742218, 1.1625178399202842, 1.5191170407152268, 1.9011207351925337, 2.2986351569029453, 2.7013648430970547, 3.0988792648074663, 3.480882959284773, 3.8374821600797158, 4.159441005725778, 4.43842065238014, 4.6671949452248755, 4.839836247031664, 4.951859262234786, 5.0]
 [0.013157894736842105, 0.08059280797122265, 0.1429545053189173, 0.20157940999029922, 0.2549787492486272, 0.3017730690716866, 0.34075120589681046, 0.3709038851772921, 0.3914502566186888, 0.4018582159696143, 0.4018582159696143, 0.3914502566186888, 0.3709038851772921, 0.34075120589681046, 0.3017730690716866, 0.2549787492486272, 0.20157940999029922, 0.1429545053189173, 0.08059280797122265, 0.013157894736842105]
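To make the role of the quadrature weights concrete, below is a minimal sketch of a quadrature-weighted data misfit assembled from this dataset. It is only an illustration of the idea behind the Data Quadrature loss; the actual loss used when estim_collocate = true is implemented inside NeuralPDE. Here phi(t, θ) denotes the PINN trial solution (as in the NNODE interface), and quadrature_data_loss is a hypothetical helper.

# Illustration only: a quadrature-weighted data misfit over the measurement times.
function quadrature_data_loss(phi, θ, dataset)
    u1_, u2_, t_, W = dataset
    # Evaluate the trial solution at each measurement time (one 2-vector per time).
    û = [phi(t, θ) for t in t_]
    # Weight each squared residual by its quadrature weight, so the sum
    # approximates an integral of the misfit over tspan.
    sum(W[i] * ((û[i][1] - u1_[i])^2 + (û[i][2] - u2_[i])^2) for i in eachindex(t_))
end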

Now, let's define a neural network for the PINN using Lux.jl.

rng = Random.default_rng()
Random.seed!(rng, 0)
n = 7
chain = Chain(Dense(1, n, tanh), Dense(n, n, tanh), Dense(n, 2))
ps, st = Lux.setup(rng, chain) |> f64
((layer_1 = (weight = [-0.08216114342212677; -0.5444445610046387; … ; 2.0908210277557373; -1.323716640472412;;], bias = [-0.38509929180145264, 0.32322537899017334, -0.32623517513275146, -0.7673453092575073, 0.7302734851837158, -0.7812288999557495, -0.16844713687896729]), layer_2 = (weight = [-0.857661783695221 0.07281388342380524 … -0.22586384415626526 0.7762117981910706; -0.1741035431623459 0.4331861138343811 … 0.36346158385276794 1.0857038497924805; … ; -0.1376570612192154 0.8151324391365051 … 0.21547959744930267 -0.17118430137634277; 0.8868374824523926 -0.2730417251586914 … 1.0372037887573242 0.9000675082206726], bias = [-0.3254864513874054, 0.2545594573020935, 0.3401867151260376, 0.3638579249382019, 0.27597615122795105, -0.07990346103906631, -0.2069253772497177]), layer_3 = (weight = [-0.43004998564720154 0.1675717830657959 … 0.43563809990882874 0.16299796104431152; 0.17348821461200714 -0.5971973538398743 … -0.4388984143733978 0.24340596795082092], bias = [-0.12740696966648102, 0.10087139904499054])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()))
Note

While solving inverse problems, when we specify param_estim = true in NNODE or BNNODE, an L2 loss measuring how well the neural network's predictions fit the provided dataset is used internally during maximum likelihood estimation. Therefore, the additional_loss mentioned in the ODE parameter estimation tutorial is not limited to an L2 loss against data.
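For example, since the data misfit is already handled internally, an additional_loss can encode other prior knowledge instead. A minimal sketch, assuming the additional_loss(phi, θ) signature from the ODE parameter estimation tutorial and that θ.p holds the ODE parameters when param_estim = true (consistent with the sol.k.u.p access used later); the penalty itself is purely illustrative:

# Illustrative: softly penalize negative parameter estimates, since all
# Lotka-Volterra parameters are physically non-negative.
function positivity_loss(phi, θ)
    return 1e-2 * sum(abs2, min.(θ.p, 0.0))
end

Such a function could be passed via the additional_loss keyword of NNODE alongside param_estim = true.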

We now define the optimizer and NNODE (the ODE-solving PINN algorithm) for the old PINN model and for the proposed new formulation that uses the Data Quadrature loss. The optimizer and the respective algorithms are then plugged into the solve calls so we can compare the new and old PINN models.

opt = LBFGS(linesearch = BackTracking())

alg_old = NNODE(
    chain, opt; strategy = GridTraining(0.01), dataset = dataset, param_estim = true)

alg_new = NNODE(chain, opt; strategy = GridTraining(0.01), param_estim = true,
    dataset = dataset, estim_collocate = true)
NNODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, Optim.LBFGS{Nothing, LineSearches.InitialStatic{Float64}, LineSearches.BackTracking{Float64, Int64}, Returns{Nothing}}, Nothing, Bool, GridTraining{Float64}, Bool, Nothing, Vector{Vector{Float64}}, Base.Pairs{Symbol, Union{}, Tuple{}, @NamedTuple{}}}(Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(1 => 7, tanh), layer_2 = Dense(7 => 7, tanh), layer_3 = Dense(7 => 2)), nothing), Optim.LBFGS{Nothing, LineSearches.InitialStatic{Float64}, LineSearches.BackTracking{Float64, Int64}, Returns{Nothing}}(10, LineSearches.InitialStatic{Float64}
  alpha: Float64 1.0
  scaled: Bool false
, LineSearches.BackTracking{Float64, Int64}
  c_1: Float64 0.0001
  ρ_hi: Float64 0.5
  ρ_lo: Float64 0.1
  iterations: Int64 1000
  order: Int64 3
  maxstep: Float64 Inf
  cache: Nothing nothing
, nothing, Returns{Nothing}(nothing), Optim.Flat(), true), nothing, false, true, GridTraining{Float64}(0.01), true, nothing, [[5.0, 4.181603611078826, 2.645397810891132, 1.3912498015703647, 0.8210526163681184, 0.6722098478776918, 0.7611053094812956, 1.0772907174319215, 1.7299869127510388, 2.9470253265117505, 5.049007710901019, 7.796206693042137, 5.5936246049244325, 1.4464064682616247, 0.7465251501572414, 0.6800898950063358, 0.7600175010829118, 0.8812968395300225, 0.9887672013131465, 1.0424843918922717], [5.0, 5.396689381837423, 5.620089682747766, 4.668051098875351, 2.9860691512818804, 1.5808704215183293, 0.7539189127225489, 0.3564866998395664, 0.19132741119848423, 0.14380368289069814, 0.20915289174224483, 0.8200396835925662, 4.616390757346223, 4.732955149243002, 2.4829239251518236, 1.3047202433563063, 0.7730246691232083, 0.5302342006306647, 0.42063858049750924, 0.38230716879691223], [0.0, 0.04814073776521477, 0.1601637529683364, 0.33280505477512445, 0.5615793476198603, 0.8405589942742218, 1.1625178399202842, 1.5191170407152268, 1.9011207351925337, 2.2986351569029453, 2.7013648430970547, 3.0988792648074663, 3.480882959284773, 3.8374821600797158, 4.159441005725778, 4.43842065238014, 4.6671949452248755, 4.839836247031664, 4.951859262234786, 5.0], [0.013157894736842105, 0.08059280797122265, 0.1429545053189173, 0.20157940999029922, 0.2549787492486272, 0.3017730690716866, 0.34075120589681046, 0.3709038851772921, 0.3914502566186888, 0.4018582159696143, 0.4018582159696143, 0.3914502566186888, 0.3709038851772921, 0.34075120589681046, 0.3017730690716866, 0.2549787492486272, 0.20157940999029922, 0.1429545053189173, 0.08059280797122265, 0.013157894736842105]], true, Base.Pairs{Symbol, Union{}, Tuple{}, @NamedTuple{}}())

Now we have all the pieces to solve the optimization problem.

sol_old = solve(
    prob, alg_old; verbose = true, abstol = 1e-12, maxiters = 5000, saveat = 0.01)

sol_new = solve(
    prob, alg_new; verbose = true, abstol = 1e-12, maxiters = 5000, saveat = 0.01)

sol = solve(prob_data, Tsit5(); saveat = 0.01)
sol_points = hcat(sol.u...)
sol_old_points = hcat(sol_old.u...)
sol_new_points = hcat(sol_new.u...)
2×501 Matrix{Float64}:
 5.0  4.81778  4.64051  4.46815  4.30067  …  0.99151   0.995239  0.99885
 5.0  5.10393  5.1993   5.28594  5.36376     0.352015  0.339185  0.326444

Let's plot the predictions from the PINN models together with the data used, and compare them to the ideal system solution. First, the old model.

plot(sol, labels = ["u1" "u2"])
plot!(sol_old, labels = ["u1_pinn_old" "u2_pinn_old"])
scatter!(sol_data, labels = ["u1_data" "u2_data"])
[Plot: reference solution, old-model PINN predictions (u1_pinn_old, u2_pinn_old) and the measurement data]

Clearly, the old model cannot recover the dynamics given a realistic, tougher initialization of the parameters, especially with such limited data. It only seems to work when the initial values are close to true_p and around 500 points are used across tspan, as seen in the ODE parameter estimation tutorial.

Let's move on to the proposed new model.

plot(sol, labels = ["u1" "u2"])
plot!(sol_new, labels = ["u1_pinn_new" "u2_pinn_new"])
scatter!(sol_data, labels = ["u1_data" "u2_data"])
[Plot: reference solution, new-model PINN predictions (u1_pinn_new, u2_pinn_new) and the measurement data]

We can see that it is a good fit! Now let's examine what the estimated parameters tell us in each case. We also test the following: the old model's estimates have at least one parameter deviating from its true value by more than 50%, while all of the new model's estimates are within about 2% of the true_p values.

sol_old.k.u.p
4-element view(::Vector{Float64}, 87:90) with eltype Float64:
  0.7520099906689697
  0.5005725738602825
 -0.2649352840592679
 -0.32117749700843545
Test Passed

This is nowhere near the true [1.5, 1.0, 3.0, 1.0]. But the new model gives:

sol_new.k.u.p
4-element view(::Vector{Float64}, 87:90) with eltype Float64:
 1.4964284939955685
 0.9964274269956848
 3.0519881715421984
 1.020470905167828
Test Passed

This is indeed very close to the true ODE parameter values [1.5, 1.0, 3.0, 1.0].
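The "Test Passed" lines above come from checks along the lines described earlier. A sketch of such checks using Test.jl (the tolerances here follow the criteria stated above and may differ slightly from the package's own tests):

using Test

rel_err(p) = abs.((p .- true_p) ./ true_p)   # element-wise relative error vs. true_p

# Old model: at least one parameter deviates from its true value by more than 50%.
@test any(rel_err(sol_old.k.u.p) .> 0.5)

# New model: every parameter is within about 2% of its true value.
@test all(rel_err(sol_new.k.u.p) .< 0.025)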

Note

This Data Quadrature/collocation loss is also available for BPINNs solving inverse problems in ODEs. Build a dataset of the form described in this tutorial, set estim_collocate = true, and you are good to go.
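A minimal sketch of what this could look like with BNNODE, assuming the keyword arguments from the Bayesian PINN ODE tutorials; the priors, noise levels and sampling settings below are placeholders rather than recommended values:

# Illustration only: a Bayesian counterpart using BNNODE with the same dataset.
alg_bayes = BNNODE(chain;
    dataset = dataset,                         # [u1_, u2_, t_, W] as built above
    draw_samples = 1000,                       # number of posterior samples
    l2std = [0.05, 0.05],                      # assumed observation noise for u1, u2
    phystd = [0.05, 0.05],                     # assumed physics-residual noise
    priorsNNw = (0.0, 3.0),                    # Normal prior on the network weights
    param = [Normal(2.0, 2.0) for _ in 1:4],   # priors on α, β, γ, δ
    estim_collocate = true)                    # enable the Data Quadrature loss

sol_bayes = solve(prob, alg_bayes)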