Optimizing a Parameterized ODE
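Let us fit a parameterized ODE to some data, using the Lotka-Volterra model as an example. We will fit the parameters with single shooting: for every candidate parameter vector the ODE is solved over the whole time span, and the result is compared with the data.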
import OrdinaryDiffEqTsit5 as ODE
import NonlinearSolve as NLS
import Plots
Let us generate the data to fit by simulating the Lotka-Volterra model with known parameters.
"""
    lotka_volterra!(du, u, p, t)

Write the right-hand side of the Lotka-Volterra equations into `du`.

# Arguments
- `du`: output buffer; `du[1]` and `du[2]` are overwritten.
- `u`: current state, destructured as `x, y`.
- `p`: parameters, destructured as `α, β, δ, γ`.
- `t`: time; unused, the system is autonomous.

# Returns
`nothing`; the result is delivered through `du`.
"""
function lotka_volterra!(du, u, p, t)
    x, y = u
    α, β, δ, γ = p
    du[1] = α * x - β * x * y
    du[2] = -δ * y + γ * x * y
    return nothing
end
# Initial condition
u0 = [1.0, 1.0]
# Simulation interval, and the save points (every 0.1 time units) inside it.
# The grid is derived from `tspan` so that it never extends past the interval
# that is actually simulated.
tspan = (0.0, 2.0)
tsteps = tspan[1]:0.1:tspan[2]
# LV equation parameter. p = [α, β, δ, γ]
p = [1.5, 1.0, 3.0, 1.0]
# Setup the ODE problem, then solve
prob = ODE.ODEProblem(lotka_volterra!, u0, tspan, p)
sol = ODE.solve(prob, ODE.Tsit5(); saveat = tsteps)
# Plot the solution
Plots.plot(sol; linewidth = 3)
Let us now formulate the parameter estimation as a Nonlinear Least Squares Problem.
"""
    loss_function(ode_param, data)

Residual vector for the least-squares fit.

Solves the global `prob` with `ode_param` substituted for its parameters, saving
at the global `tsteps`, stacks the saved states as the columns of a matrix,
flattens that matrix, and subtracts `data` element-wise.
"""
function loss_function(ode_param, data)
    trajectory = ODE.solve(prob, ODE.Tsit5(); p = ode_param, saveat = tsteps)
    prediction = vec(reduce(hcat, trajectory.u))
    return prediction .- data
end
# Initial guess for the four parameters: all zeros.
p_init = zeros(Float64, 4)
# The third argument is handed to `loss_function` as `data`; it is the reference
# solution, flattened the same way as the prediction inside `loss_function`.
nlls_prob = NLS.NonlinearLeastSquaresProblem(
    loss_function,
    p_init,
    vec(reduce(hcat, sol.u)),
)
NonlinearLeastSquaresProblem with uType Vector{Float64}. In-place: false
u0: 4-element Vector{Float64}:
0.0
0.0
0.0
0.0
Now, we can use any NLLS solver to solve this problem.
# Solve with Levenberg-Marquardt. The trace is switched on and includes the
# condition number of the Jacobian; the argument 25 sets how often a row is shown.
res = NLS.solve(
    nlls_prob,
    NLS.LevenbergMarquardt();
    maxiters = 1000,
    show_trace = Val(true),
    trace_level = NLS.TraceWithJacobianConditionNumber(25),
)
Algorithm: LevenbergMarquardt(
trustregion = LevenbergMarquardtTrustRegion(
β_uphill = 1.0
),
descent = GeodesicAcceleration(
descent = DampedNewtonDescent(
initial_damping = 1.0,
damping_fn = LevenbergMarquardtDampingFunction(
increase_factor = 2.0,
decrease_factor = 3.0,
min_damping = 1.0e-8
)
),
finite_diff_step_geodesic = 0.1,
α = 0.75
),
autodiff = AutoForwardDiff(),
vjp_autodiff = AutoFiniteDiff(
fdtype = Val{:forward}(),
fdjtype = Val{:forward}(),
fdhtype = Val{:hcentral}(),
dir = true
),
jvp_autodiff = AutoForwardDiff(),
concrete_jac = Val{true}()
)
---- ------------- ----------- -------
Iter f(u) 2-norm Step 2-norm cond(J)
---- ------------- ----------- -------
0 1.46131516e+01 0.00000000e+00 2.99618417e+16
1 1.46131516e+01 0.00000000e+00 2.99618417e+16
26 2.72610330e-11 4.35850194e-10 2.99618417e+16
51 7.54288070e-13 0.00000000e+00 2.99618417e+16
76 7.17990228e-13 2.43201875e-16 2.99618417e+16
101 6.97021694e-13 0.00000000e+00 2.99618417e+16
126 6.57107897e-13 0.00000000e+00 2.99618417e+16
Final 6.43121349e-13
----------------------
# Display the Levenberg-Marquardt result (return code and fitted parameters).
res
retcode: StalledSuccess
u: 4-element Vector{Float64}:
1.5000000000000473
0.999999999999845
2.9999999999991998
0.9999999999997532
We can also use Trust Region methods.
# Solve the same problem with a trust-region method, using the same iteration
# limit and trace settings. Note that this overwrites the previous `res`.
res = NLS.solve(
    nlls_prob,
    NLS.TrustRegion();
    maxiters = 1000,
    show_trace = Val(true),
    trace_level = NLS.TraceWithJacobianConditionNumber(25),
)
Algorithm: TrustRegion(
trustregion = GenericTrustRegionScheme(
method = __Simple(),
step_threshold = 1//10000,
shrink_threshold = 1//4,
shrink_factor = 1//4,
expand_factor = 2//1,
expand_threshold = 3//4,
max_trust_radius = 0//1,
initial_trust_radius = 0//1
),
descent = Dogleg(
newton_descent = NewtonDescent(),
steepest_descent = SteepestDescent()
),
max_shrink_times = 32,
autodiff = AutoForwardDiff(),
vjp_autodiff = AutoFiniteDiff(
fdtype = Val{:forward}(),
fdjtype = Val{:forward}(),
fdhtype = Val{:hcentral}(),
dir = true
),
jvp_autodiff = AutoForwardDiff(),
concrete_jac = Val{false}()
)
---- ------------- ----------- -------
Iter f(u) 2-norm Step 2-norm cond(J)
---- ------------- ----------- -------
0 1.46131516e+01 0.00000000e+00 2.99618417e+16
1 1.46131516e+01 0.00000000e+00 2.99618417e+16
Final 1.90451010e-13
----------------------
# Display the trust-region result (return code and fitted parameters).
res
retcode: Success
u: 4-element Vector{Float64}:
1.4999999999999987
1.0000000000000313
3.0000000000001714
1.0000000000000542
Let's plot the fitted solution against the true one, over a longer time span than the one used for fitting.
# Same problem, but on the longer interval (0, 10).
prob2 = ODE.remake(prob; tspan = (0.0, 10.0))

# Reference trajectory (true parameters `p`) and fitted trajectory (`res.u`).
sol_true = ODE.solve(prob2, ODE.Tsit5(); p)
sol_fit = ODE.solve(prob2, ODE.Tsit5(); p = res.u)

# True solution as solid lines, fitted solution overlaid as dashed lines.
Plots.plot(sol_true; linewidth = 3)
Plots.plot!(sol_fit; linestyle = :dash, linewidth = 3)