Model Improvement in Physics-Informed Neural Networks for solving Inverse problems in ODEs.
Consider an inverse problem setting for the Lotka-Volterra system. Here we want to optimize the parameters $\alpha$, $\beta$, $\gamma$ and $\delta$ and also solve a parametric Lotka-Volterra system. PINNs are especially useful in these types of problems and are preferred over conventional solvers, due to their ability to learn, from observations, the underlying physics governing the distribution of those observations.
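For reference, the system we are identifying is the standard Lotka-Volterra predator-prey model,
$$
\frac{du_1}{dt} = \alpha u_1 - \beta u_1 u_2, \qquad
\frac{du_2}{dt} = \delta u_1 u_2 - \gamma u_2,
$$
where $u_1$ is the prey population and $u_2$ is the predator population.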
We start by defining the problem, with a random, non-informative initialization of the parameters:
using NeuralPDE, OrdinaryDiffEq, Lux, Random, OptimizationOptimJL, LineSearches,
Distributions, Plots
using FastGaussQuadrature
function lv(u, p, t)
    # Lotka-Volterra predator-prey model (out-of-place form)
    u₁, u₂ = u
    α, β, γ, δ = p
    du₁ = α * u₁ - β * u₁ * u₂
    du₂ = δ * u₁ * u₂ - γ * u₂
    [du₁, du₂]
end
tspan = (0.0, 5.0)
u0 = [5.0, 5.0]
initialization = [-5.0, 8.0, 5.0, -7.0]
prob = ODEProblem(lv, u0, tspan, initialization)
ODEProblem with uType Vector{Float64} and tType Float64. In-place: false
Non-trivial mass matrix: false
timespan: (0.0, 5.0)
u0: 2-element Vector{Float64}:
5.0
5.0
We require a set of observations before we train the PINN. Since we want robust results even in cases where measurements are sparse and limited in number, we simulate a system that uses the true parameter values true_p and record the solution (u) values at only N = 20 pre-decided timepoints in the system's time domain.
The value of N can be increased with the degree of nonlinearity (roughly, an N-degree polynomial) in the measured phenomenon; this tutorial's setting shows that even with minimal but systematically chosen data points we can extract excellent results.
true_p = [1.5, 1.0, 3.0, 1.0]
prob_data = remake(prob, p = true_p)
N = 20
x, w = gausslobatto(N)
a = tspan[1]
b = tspan[2]
5.0
Now scale the weights and the Gauss-Lobatto/Clenshaw-Curtis/Gauss-Legendre quadrature points to fit in tspan.
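Here gausslobatto(N) from FastGaussQuadrature returns N nodes $x_i$ and weights $w_i$ on the reference interval $[-1, 1]$; the code below applies the standard affine map to tspan $= [a, b]$, so that integrals over the time domain can be approximated directly from values at the data timepoints:
$$
t_i = \frac{x_i (b - a) + (b + a)}{2}, \qquad
W_i = w_i \, \frac{b - a}{2}, \qquad
\int_a^b g(t)\, dt \approx \sum_{i=1}^{N} W_i \, g(t_i).
$$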
t = map((x) -> (x * (b - a) + (b + a)) / 2, x)
W = map((x) -> x * (b - a) / 2, w)
20-element Vector{Float64}:
0.013157894736842105
0.08059280797122265
0.1429545053189173
0.20157940999029922
0.2549787492486272
0.3017730690716866
0.34075120589681046
0.3709038851772921
0.3914502566186888
0.4018582159696143
0.4018582159696143
0.3914502566186888
0.3709038851772921
0.34075120589681046
0.3017730690716866
0.2549787492486272
0.20157940999029922
0.1429545053189173
0.08059280797122265
0.013157894736842105
We now have our dataset of 20 measurements over our tspan and the corresponding quadrature weights. Using this, we can now use the Data Quadrature loss function by passing estim_collocate = true in NNODE.
sol_data = solve(prob_data, Tsit5(); saveat = t)  # simulate measurements at the quadrature timepoints
t_ = sol_data.t
u_ = sol_data.u
u1_ = [u_[i][1] for i in eachindex(t_)]  # prey observations
u2_ = [u_[i][2] for i in eachindex(t_)]  # predator observations
dataset = [u1_, u2_, t_, W]  # state observations, their timepoints, and quadrature weights
4-element Vector{Vector{Float64}}:
[5.0, 4.181603611078826, 2.645397810891132, 1.3912498015703647, 0.8210526163681184, 0.6722098478776918, 0.7611053094812956, 1.0772907174319215, 1.7299869127510388, 2.9470253265117505, 5.049007710901019, 7.796206693042137, 5.5936246049244325, 1.4464064682616247, 0.7465251501572414, 0.6800898950063358, 0.7600175010829118, 0.8812968395300225, 0.9887672013131465, 1.0424843918922717]
[5.0, 5.396689381837423, 5.620089682747766, 4.668051098875351, 2.9860691512818804, 1.5808704215183293, 0.7539189127225489, 0.3564866998395664, 0.19132741119848423, 0.14380368289069814, 0.20915289174224483, 0.8200396835925662, 4.616390757346223, 4.732955149243002, 2.4829239251518236, 1.3047202433563063, 0.7730246691232083, 0.5302342006306647, 0.42063858049750924, 0.38230716879691223]
[0.0, 0.04814073776521477, 0.1601637529683364, 0.33280505477512445, 0.5615793476198603, 0.8405589942742218, 1.1625178399202842, 1.5191170407152268, 1.9011207351925337, 2.2986351569029453, 2.7013648430970547, 3.0988792648074663, 3.480882959284773, 3.8374821600797158, 4.159441005725778, 4.43842065238014, 4.6671949452248755, 4.839836247031664, 4.951859262234786, 5.0]
[0.013157894736842105, 0.08059280797122265, 0.1429545053189173, 0.20157940999029922, 0.2549787492486272, 0.3017730690716866, 0.34075120589681046, 0.3709038851772921, 0.3914502566186888, 0.4018582159696143, 0.4018582159696143, 0.3914502566186888, 0.3709038851772921, 0.34075120589681046, 0.3017730690716866, 0.2549787492486272, 0.20157940999029922, 0.1429545053189173, 0.08059280797122265, 0.013157894736842105]
Now, let's define a neural network for the PINN using Lux.jl.
rng = Random.default_rng()
Random.seed!(rng, 0)
n = 7
chain = Chain(Dense(1, n, tanh), Dense(n, n, tanh), Dense(n, 2))
ps, st = Lux.setup(rng, chain) |> f64
((layer_1 = (weight = [-0.08216114342212677; -0.5444445610046387; … ; 2.0908210277557373; -1.323716640472412;;], bias = [-0.38509929180145264, 0.32322537899017334, -0.32623517513275146, -0.7673453092575073, 0.7302734851837158, -0.7812288999557495, -0.16844713687896729]), layer_2 = (weight = [-0.857661783695221 0.07281388342380524 … -0.22586384415626526 0.7762117981910706; -0.1741035431623459 0.4331861138343811 … 0.36346158385276794 1.0857038497924805; … ; -0.1376570612192154 0.8151324391365051 … 0.21547959744930267 -0.17118430137634277; 0.8868374824523926 -0.2730417251586914 … 1.0372037887573242 0.9000675082206726], bias = [-0.3254864513874054, 0.2545594573020935, 0.3401867151260376, 0.3638579249382019, 0.27597615122795105, -0.07990346103906631, -0.2069253772497177]), layer_3 = (weight = [-0.43004998564720154 0.1675717830657959 … 0.43563809990882874 0.16299796104431152; 0.17348821461200714 -0.5971973538398743 … -0.4388984143733978 0.24340596795082092], bias = [-0.12740696966648102, 0.10087139904499054])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()))
While solving inverse problems, when we specify param_estim = true in NNODE or BNNODE, an L2 loss function measuring how well the neural network's predictions fit the provided dataset is used internally during Maximum Likelihood Estimation. Therefore, the additional_loss mentioned in the ODE parameter estimation tutorial is not limited to being an L2 loss function against the data.
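Schematically, that internal data-fitting term is a sum-of-squared-errors loss over the observed timepoints; the form below is only a sketch of the idea, not NeuralPDE's exact internal implementation:
$$
\mathcal{L}_{\text{data}}(\theta, p) = \sum_{i=1}^{N} \left\lVert u_\theta(t_i) - u_i \right\rVert^2,
$$
where $u_\theta(t_i)$ is the network's prediction at timepoint $t_i$ and $u_i$ is the corresponding measurement from dataset.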
We now define the optimizer and NNODE (the ODE-solving PINN algorithm) for both the old PINN model and the proposed new PINN formulation, which uses the Data Quadrature loss. This optimizer and the respective algorithms are plugged into the solve calls for comparing the results of the new and old PINN models.
opt = LBFGS(linesearch = BackTracking())
alg_old = NNODE(
    chain, opt; strategy = GridTraining(0.01), dataset = dataset, param_estim = true)
alg_new = NNODE(chain, opt; strategy = GridTraining(0.01), param_estim = true,
    dataset = dataset, estim_collocate = true)
NNODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, Optim.LBFGS{Nothing, LineSearches.InitialStatic{Float64}, LineSearches.BackTracking{Float64, Int64}, Returns{Nothing}}, Nothing, Bool, GridTraining{Float64}, Bool, Nothing, Vector{Vector{Float64}}, Base.Pairs{Symbol, Union{}, Tuple{}, @NamedTuple{}}}(Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(1 => 7, tanh), layer_2 = Dense(7 => 7, tanh), layer_3 = Dense(7 => 2)), nothing), Optim.LBFGS{Nothing, LineSearches.InitialStatic{Float64}, LineSearches.BackTracking{Float64, Int64}, Returns{Nothing}}(10, LineSearches.InitialStatic{Float64}
alpha: Float64 1.0
scaled: Bool false
, LineSearches.BackTracking{Float64, Int64}
c_1: Float64 0.0001
ρ_hi: Float64 0.5
ρ_lo: Float64 0.1
iterations: Int64 1000
order: Int64 3
maxstep: Float64 Inf
cache: Nothing nothing
, nothing, Returns{Nothing}(nothing), Optim.Flat(), true), nothing, false, true, GridTraining{Float64}(0.01), true, nothing, [[5.0, 4.181603611078826, 2.645397810891132, 1.3912498015703647, 0.8210526163681184, 0.6722098478776918, 0.7611053094812956, 1.0772907174319215, 1.7299869127510388, 2.9470253265117505, 5.049007710901019, 7.796206693042137, 5.5936246049244325, 1.4464064682616247, 0.7465251501572414, 0.6800898950063358, 0.7600175010829118, 0.8812968395300225, 0.9887672013131465, 1.0424843918922717], [5.0, 5.396689381837423, 5.620089682747766, 4.668051098875351, 2.9860691512818804, 1.5808704215183293, 0.7539189127225489, 0.3564866998395664, 0.19132741119848423, 0.14380368289069814, 0.20915289174224483, 0.8200396835925662, 4.616390757346223, 4.732955149243002, 2.4829239251518236, 1.3047202433563063, 0.7730246691232083, 0.5302342006306647, 0.42063858049750924, 0.38230716879691223], [0.0, 0.04814073776521477, 0.1601637529683364, 0.33280505477512445, 0.5615793476198603, 0.8405589942742218, 1.1625178399202842, 1.5191170407152268, 1.9011207351925337, 2.2986351569029453, 2.7013648430970547, 3.0988792648074663, 3.480882959284773, 3.8374821600797158, 4.159441005725778, 4.43842065238014, 4.6671949452248755, 4.839836247031664, 4.951859262234786, 5.0], [0.013157894736842105, 0.08059280797122265, 0.1429545053189173, 0.20157940999029922, 0.2549787492486272, 0.3017730690716866, 0.34075120589681046, 0.3709038851772921, 0.3914502566186888, 0.4018582159696143, 0.4018582159696143, 0.3914502566186888, 0.3709038851772921, 0.34075120589681046, 0.3017730690716866, 0.2549787492486272, 0.20157940999029922, 0.1429545053189173, 0.08059280797122265, 0.013157894736842105]], true, Base.Pairs{Symbol, Union{}, Tuple{}, @NamedTuple{}}())
Now we have all the pieces to solve the optimization problem.
sol_old = solve(
    prob, alg_old; verbose = true, abstol = 1e-12, maxiters = 5000, saveat = 0.01)
sol_new = solve(
    prob, alg_new; verbose = true, abstol = 1e-12, maxiters = 5000, saveat = 0.01)
sol = solve(prob_data, Tsit5(); saveat = 0.01)
sol_points = hcat(sol.u...)
sol_old_points = hcat(sol_old.u...)
sol_new_points = hcat(sol_new.u...)
2×501 Matrix{Float64}:
5.0 4.81778 4.64051 4.46815 4.30067 … 0.99151 0.995239 0.99885
5.0 5.10393 5.1993 5.28594 5.36376 0.352015 0.339185 0.326444
Let's plot the predictions from the PINN models together with the data used, and compare them to the ideal system solution. First, the old model.
plot(sol, labels = ["u1" "u2"])
plot!(sol_old, labels = ["u1_pinn_old" "u2_pinn_old"])
scatter!(sol_data, labels = ["u1_data" "u2_data"])
Clearly, the old model cannot optimize given a realistic, tougher initialization of the parameters, especially with such limited data. It only seems to work when the initial values are close to true_p and we have around 500 points across our tspan, as seen in the ODE parameter estimation tutorial.
Let's move on to the proposed new model.
plot(sol, labels = ["u1" "u2"])
plot!(sol_new, labels = ["u1_pinn_new" "u2_pinn_new"])
scatter!(sol_data, labels = ["u1_data" "u2_data"])
We can see that it is a good fit! Now let's examine what the estimated parameters of the equation tell us in both cases. We also test for the following: the old model's estimates have at least one parameter deviating from its true value by more than 50%, while all of the new model's estimates must be within 2% of the true_p values.
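The tutorial's actual test code is not shown on this page; a minimal sketch of such checks, assuming the parameter estimates are read from sol_old.k.u.p and sol_new.k.u.p as below, could look like:
using Test
p_old = sol_old.k.u.p   # parameter estimates from the old model
p_new = sol_new.k.u.p   # parameter estimates from the new model
# Old model: at least one parameter deviates from its true value by more than 50%.
@test any(abs.((p_old .- true_p) ./ true_p) .> 0.5)
# New model: every parameter estimate lies within 2% of its true value.
@test all(abs.((p_new .- true_p) ./ true_p) .< 0.02)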
sol_old.k.u.p
4-element view(::Vector{Float64}, 87:90) with eltype Float64:
0.7520099906689697
0.5005725738602825
-0.2649352840592679
-0.32117749700843545
Test Passed
This is nowhere near the true values [1.5, 1.0, 3.0, 1.0]. But the new model gives:
sol_new.k.u.p
4-element view(::Vector{Float64}, 87:90) with eltype Float64:
1.4964284939955685
0.9964274269956848
3.0519881715421984
1.020470905167828
Test Passed
This is indeed very close to the true ODE parameter values [1.5, 1.0, 3.0, 1.0].