Optimizing a Parameterized ODE

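Let us fit the parameters of an ODE to data, using the Lotka–Volterra model as an example. We use single shooting: for each candidate parameter vector, the ODE is integrated once over the whole time interval from the known initial condition, and the resulting trajectory is compared with the data.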

import OrdinaryDiffEqTsit5 as ODE
import NonlinearSolve as NLS
import Plots

First, let us generate synthetic data by simulating the Lotka–Volterra model with known parameters.

"""
    lotka_volterra!(du, u, p, t)

In-place right-hand side of the Lotka–Volterra equations

    dx/dt =  α x - β x y
    dy/dt = -δ y + γ x y

with state `u = [x, y]` and parameters `p = [α, β, δ, γ]`. The time `t` is not
used (the system is autonomous) but is required by the ODE function signature.
Writes the derivatives into `du` and returns `nothing`.
"""
function lotka_volterra!(du, u, p, t)
    x, y = u
    α, β, δ, γ = p
    du[1] = α * x - β * x * y
    du[2] = -δ * y + γ * x * y
    return nothing
end

# Initial condition u0 = [x(0), y(0)]
u0 = [1.0, 1.0]

# Simulation interval and the time points at which the solution is saved.
# The save points are derived from `tspan` so that they all lie inside it;
# `saveat` points outside `tspan` are ignored by the solver.
tspan = (0.0, 2.0)
tsteps = tspan[1]:0.1:tspan[2]

# LV parameters, in the order unpacked by `lotka_volterra!`: p = [α, β, δ, γ]
p = [1.5, 1.0, 3.0, 1.0]

# Setup the ODE problem, then solve
prob = ODE.ODEProblem(lotka_volterra!, u0, tspan, p)
sol = ODE.solve(prob, ODE.Tsit5(); saveat = tsteps)

# Plot the solution
Plots.plot(sol; linewidth = 3)
Example block output

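Let us now formulate the parameter estimation as a nonlinear least squares (NLLS) problem. The residual is the difference between the simulated trajectory and the data at the save points, and the solver minimizes its squared 2-norm.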

"""
    loss_function(ode_param, data)

Residual for the NLLS fit. Re-solves the global `prob` with the candidate
parameters `ode_param`, saving at the global `tsteps`. Returns the simulated
states, stacked as `[x(t₁), y(t₁), x(t₂), y(t₂), …]`, minus `data`.
"""
function loss_function(ode_param, data)
    predicted = ODE.solve(prob, ODE.Tsit5(); p = ode_param, saveat = tsteps)
    # vcat of the per-time state vectors gives the same ordering as
    # flattening the column-major (state × time) matrix.
    stacked = reduce(vcat, predicted.u)
    return stacked .- data
end

# Initial guess: all four parameters set to zero (far from the true `p`).
p_init = fill(0.0, 4)

# The flattened reference trajectory is passed as the problem parameter, so it
# reaches `loss_function` as its `data` argument.
nlls_prob = NLS.NonlinearLeastSquaresProblem(loss_function, p_init,
    reduce(vcat, sol.u))
NonlinearLeastSquaresProblem with uType Vector{Float64}. In-place: false
u0: 4-element Vector{Float64}:
 0.0
 0.0
 0.0
 0.0

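Now we can solve this problem with any NLLS solver. Here we use Levenberg–Marquardt and print a trace row, including the Jacobian condition number, every 25 iterations.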

# Levenberg–Marquardt with a trace printed every 25 iterations, including the
# condition number of the Jacobian.
res = NLS.solve(nlls_prob, NLS.LevenbergMarquardt();
    maxiters = 1000,
    show_trace = Val(true),
    trace_level = NLS.TraceWithJacobianConditionNumber(25),
)

Algorithm: LevenbergMarquardt(
    trustregion = LevenbergMarquardtTrustRegion(
        β_uphill = 1.0
    ),
    descent = GeodesicAcceleration(
        descent = DampedNewtonDescent(
            initial_damping = 1.0,
            damping_fn = LevenbergMarquardtDampingFunction(
                increase_factor = 2.0,
                decrease_factor = 3.0,
                min_damping = 1.0e-8
            )
        ),
        finite_diff_step_geodesic = 0.1,
        α = 0.75
    ),
    autodiff = AutoForwardDiff(),
    vjp_autodiff = AutoFiniteDiff(
        fdtype = Val{:forward}(),
        fdjtype = Val{:forward}(),
        fdhtype = Val{:hcentral}(),
        dir = true
    ),
    jvp_autodiff = AutoForwardDiff(),
    concrete_jac = Val{true}()
)

----    	-------------       	-----------         	-------
Iter    	f(u) 2-norm         	Step 2-norm         	cond(J)
----    	-------------       	-----------         	-------
0       	1.46131516e+01      	0.00000000e+00      	2.99618417e+16
1       	1.46131516e+01      	0.00000000e+00      	2.99618417e+16
26      	2.72610330e-11      	4.35850194e-10      	2.99618417e+16
51      	7.54288070e-13      	0.00000000e+00      	2.99618417e+16
76      	7.17990228e-13      	2.43201875e-16      	2.99618417e+16
101     	6.97021694e-13      	0.00000000e+00      	2.99618417e+16
126     	6.57107897e-13      	0.00000000e+00      	2.99618417e+16
Final   	6.43121349e-13
----------------------
res
retcode: StalledSuccess
u: 4-element Vector{Float64}:
 1.5000000000000473
 0.999999999999845
 2.9999999999991998
 0.9999999999997532

We can also use a trust-region method.

# Trust-region method with the same iteration limit and trace settings.
res = NLS.solve(nlls_prob, NLS.TrustRegion();
    maxiters = 1000,
    show_trace = Val(true),
    trace_level = NLS.TraceWithJacobianConditionNumber(25),
)

Algorithm: TrustRegion(
    trustregion = GenericTrustRegionScheme(
        method = __Simple(),
        step_threshold = 1//10000,
        shrink_threshold = 1//4,
        shrink_factor = 1//4,
        expand_factor = 2//1,
        expand_threshold = 3//4,
        max_trust_radius = 0//1,
        initial_trust_radius = 0//1
    ),
    descent = Dogleg(
        newton_descent = NewtonDescent(),
        steepest_descent = SteepestDescent()
    ),
    max_shrink_times = 32,
    autodiff = AutoForwardDiff(),
    vjp_autodiff = AutoFiniteDiff(
        fdtype = Val{:forward}(),
        fdjtype = Val{:forward}(),
        fdhtype = Val{:hcentral}(),
        dir = true
    ),
    jvp_autodiff = AutoForwardDiff(),
    concrete_jac = Val{false}()
)

----    	-------------       	-----------         	-------
Iter    	f(u) 2-norm         	Step 2-norm         	cond(J)
----    	-------------       	-----------         	-------
0       	1.46131516e+01      	0.00000000e+00      	2.99618417e+16
1       	1.46131516e+01      	0.00000000e+00      	2.99618417e+16
Final   	1.90451010e-13
----------------------
res
retcode: Success
u: 4-element Vector{Float64}:
 1.4999999999999987
 1.0000000000000313
 3.0000000000001714
 1.0000000000000542

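Finally, let us compare the fitted and true models. We solve both over the longer interval (0, 10), which extends beyond the (0, 2) window used for fitting.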

# Re-solve over (0, 10), longer than the (0, 2) fitting window, so the plot also
# shows how well the fitted parameters extrapolate.
prob2 = ODE.remake(prob; tspan = (0.0, 10.0))
sol_fit = ODE.solve(prob2, ODE.Tsit5(); p = res.u)
sol_true = ODE.solve(prob2, ODE.Tsit5(); p = p)
# Explicit labels so the true and fitted curves can be told apart in the legend.
Plots.plot(sol_true; linewidth = 3, label = ["x (true)" "y (true)"])
Plots.plot!(sol_fit; linewidth = 3, linestyle = :dash, label = ["x (fit)" "y (fit)"])
Example block output