Optimizing a Parameterized ODE

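Let us fit the parameters of an ODE to some data, using the Lotka-Volterra model as an example. We use single shooting: for every candidate parameter set, the ODE is integrated once over the whole time span, and the mismatch between the simulated trajectory and the data is minimized.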

using OrdinaryDiffEqTsit5, NonlinearSolve, Plots

Let us generate synthetic data by simulating the Lotka-Volterra model with known parameters.

"""
    lotka_volterra!(du, u, p, t)

In-place right-hand side of the Lotka-Volterra equations.

With `u = [x, y]` and `p = [α, β, δ, γ]`, writes

    du[1] = α x - β x y
    du[2] = -δ y + γ x y

into `du`. `t` is unused (the system is autonomous). Returns `nothing`.
"""
function lotka_volterra!(du, u, p, t)
    x, y = u
    α, β, δ, γ = p
    du[1] = α * x - β * x * y
    du[2] = -δ * y + γ * x * y
    return nothing
end

# Initial condition [x(0), y(0)]
u0 = [1.0, 1.0]

# Integration interval, and the time points at which the data is sampled.
# The sample points are derived from `tspan` so they cover exactly the integrated
# window (0.0, 0.1, …, 2.0); points beyond tspan[2] would never be produced.
tspan = (0.0, 2.0)
tsteps = tspan[1]:0.1:tspan[2]

# LV equation parameters, p = [α, β, δ, γ]
p = [1.5, 1.0, 3.0, 1.0]

# Set up the ODE problem, then solve it to produce the synthetic data
prob = ODEProblem(lotka_volterra!, u0, tspan, p)
sol = solve(prob, Tsit5(); saveat = tsteps)

# Plot the solution
using Plots
plot(sol; linewidth = 3)
Example block output

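Next, we formulate the parameter estimation as a nonlinear least-squares (NLLS) problem. The residual is the difference between the trajectory simulated with the candidate parameters and the data.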

"""
    loss_function(ode_param, data)

Residual vector for the NLLS fit.

Re-solves the global `prob` with parameters `ode_param`, saving at the global
`tsteps`. The trajectory is flattened column by column into
`[x(t₁), y(t₁), x(t₂), y(t₂), …]`, and `data` (flattened the same way) is subtracted.
"""
function loss_function(ode_param, data)
    trajectory = solve(prob, Tsit5(); p = ode_param, saveat = tsteps)
    # hcat puts one time point per column; vec then reads the matrix column-major.
    prediction = vec(reduce(hcat, trajectory.u))
    return prediction .- data
end

# Initial guess: all four parameters set to zero.
# NOTE: at p = 0 the right-hand side vanishes, so the state stays at u0 = [1, 1]. The
# sensitivities of x w.r.t. α and β are then exact negatives of each other (likewise
# for y w.r.t. δ and γ), so the Jacobian is rank-deficient at the starting point.
# This is consistent with the very large cond(J) reported in the solver traces.
p_init = fill(0.0, 4)

# The target is the simulated trajectory, flattened as [x(t₁), y(t₁), x(t₂), …].
nlls_prob = NonlinearLeastSquaresProblem(
    loss_function, p_init, vec(reduce(hcat, sol.u)))
NonlinearLeastSquaresProblem with uType Vector{Float64}. In-place: false
u0: 4-element Vector{Float64}:
 0.0
 0.0
 0.0
 0.0

Now, we can use any NLLS solver to solve this problem.

# Levenberg-Marquardt, printing a trace row (including cond(J)) every 25 iterations.
res = solve(nlls_prob, LevenbergMarquardt();
    show_trace = Val(true),
    trace_level = TraceWithJacobianConditionNumber(25),
    maxiters = 1000)

Algorithm: LevenbergMarquardt(
    trustregion = LevenbergMarquardtTrustRegion(
        β_uphill = 1.0
    ),
    descent = GeodesicAcceleration(
        descent = DampedNewtonDescent(
            initial_damping = 1.0,
            damping_fn = LevenbergMarquardtDampingFunction(
                increase_factor = 2.0,
                decrease_factor = 3.0,
                min_damping = 1.0e-8
            )
        ),
        finite_diff_step_geodesic = 0.1,
        α = 0.75
    ),
    autodiff = AutoForwardDiff(),
    vjp_autodiff = AutoFiniteDiff(
        fdtype = Val{:forward}(),
        fdjtype = Val{:forward}(),
        fdhtype = Val{:hcentral}(),
        dir = true
    ),
    jvp_autodiff = AutoForwardDiff(),
    concrete_jac = Val{true}()
)

----    	-------------       	-----------         	-------
Iter    	f(u) 2-norm         	Step 2-norm         	cond(J)
----    	-------------       	-----------         	-------
0       	1.46131516e+01      	0.00000000e+00      	2.99618417e+16
1       	1.46131516e+01      	0.00000000e+00      	2.99618417e+16
26      	2.72610330e-11      	4.35850194e-10      	2.99618417e+16
51      	7.54288070e-13      	0.00000000e+00      	2.99618417e+16
76      	7.17990228e-13      	2.43201875e-16      	2.99618417e+16
101     	6.97021694e-13      	0.00000000e+00      	2.99618417e+16
126     	6.57107897e-13      	0.00000000e+00      	2.99618417e+16
Final   	6.43121349e-13
----------------------
res
retcode: Stalled
u: 4-element Vector{Float64}:
 1.5000000000000473
 0.999999999999845
 2.9999999999991998
 0.9999999999997532

We can also use a trust-region method:

# Trust-region solve with the same trace settings as the Levenberg-Marquardt run.
res = solve(nlls_prob, TrustRegion();
    show_trace = Val(true),
    trace_level = TraceWithJacobianConditionNumber(25),
    maxiters = 1000)

Algorithm: TrustRegion(
    trustregion = GenericTrustRegionScheme(
        method = __Simple(),
        step_threshold = 1//10000,
        shrink_threshold = 1//4,
        shrink_factor = 1//4,
        expand_factor = 2//1,
        expand_threshold = 3//4,
        max_trust_radius = 0//1,
        initial_trust_radius = 0//1
    ),
    descent = Dogleg(
        newton_descent = NewtonDescent(),
        steepest_descent = SteepestDescent()
    ),
    max_shrink_times = 32,
    autodiff = AutoForwardDiff(),
    vjp_autodiff = AutoFiniteDiff(
        fdtype = Val{:forward}(),
        fdjtype = Val{:forward}(),
        fdhtype = Val{:hcentral}(),
        dir = true
    ),
    jvp_autodiff = AutoForwardDiff(),
    concrete_jac = Val{false}()
)

----    	-------------       	-----------         	-------
Iter    	f(u) 2-norm         	Step 2-norm         	cond(J)
----    	-------------       	-----------         	-------
0       	1.46131516e+01      	0.00000000e+00      	2.99618417e+16
1       	1.46131516e+01      	0.00000000e+00      	2.99618417e+16
Final   	1.90451010e-13
----------------------
res
retcode: Success
u: 4-element Vector{Float64}:
 1.4999999999999987
 1.0000000000000313
 3.0000000000001714
 1.0000000000000542

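Let's plot the fitted solution against the true one. We use the interval (0, 10), which is longer than the (0, 2) window used for fitting, so the plot also shows how well the fit extrapolates.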

# Re-simulate over a longer interval than the one used for fitting (tspan = (0, 2)),
# so the plot also shows how the fitted parameters extrapolate.
prob2 = remake(prob; tspan = (0.0, 10.0))
sol_fit = solve(prob2, Tsit5(); p = res.u)
sol_true = solve(prob2, Tsit5(); p = p)
# Explicit labels keep the true (solid) and fitted (dashed) series apart in the legend;
# the default labels would be identical for both solutions.
plot(sol_true; linewidth = 3, label = ["x (true)" "y (true)"])
plot!(sol_fit; linewidth = 3, linestyle = :dash, label = ["x (fit)" "y (fit)"])
Example block output