Parameter Estimation for Stochastic Differential Equations and Ensembles

We can use any DEProblem, which includes not only DAEProblems and DDEProblems but also stochastic problems. In this case, let's use the generalized maximum likelihood to fit the parameters of an SDE by evaluating an ensemble of its solutions.

Let's use the same Lotka-Volterra equation as before, but this time add noise:

using DifferentialEquations, DiffEqParamEstim, Plots, Optimization, ForwardDiff,
      OptimizationOptimJL

pf_func = function (du, u, p, t)
    du[1] = p[1] * u[1] - p[2] * u[1] * u[2]
    du[2] = -3 * u[2] + u[1] * u[2]
end

u0 = [1.0; 1.0]
tspan = (0.0, 10.0)
p = [1.5, 1.0]
pg_func = function (du, u, p, t)
    du[1] = 1e-6u[1]
    du[2] = 1e-6u[2]
end
prob = SDEProblem(pf_func, pg_func, u0, tspan, p)
sol = solve(prob, SRIW1())
retcode: Success
Interpolation: 1st order linear
t: 565-element Vector{Float64}:
  0.0
  0.011713181719699725
  0.01405581806363967
  0.01669128395057211
  0.0196561830733711
  0.022991694586519968
  0.026744145038812443
  0.030965651797641477
  0.03571484690132414
  0.04105769139296714
  ⋮
  9.844780037516513
  9.865428917160218
  9.886417277175154
  9.907761302111023
  9.9294777222365
  9.951583668974925
  9.974096528129675
  9.99703370149167
 10.0
u: 565-element Vector{Vector{Float64}}:
 [1.0, 1.0]
 [1.006011578758741, 0.9768817248737426]
 [1.0072505042505975, 0.9723305859734771]
 [1.0086589382677271, 0.9672392953886637]
 [1.0102619419014602, 0.9615477798590858]
 [1.0120887579695437, 0.9551902764215154]
 [1.0141735478517928, 0.9480952987097114]
 [1.0165562064088038, 0.9401852583079238]
 [1.0192838254850356, 0.9313766316723611]
 [1.0224121312751608, 0.9215798863017723]
 ⋮
 [0.9579087393472587, 1.2456157101291148]
 [0.9634751745691454, 1.194262018877662]
 [0.9701960919478906, 1.144381126648058]
 [0.9781044121583751, 1.095962998874721]
 [0.9872386411260075, 1.048996340058055]
 [0.9976436188223962, 1.0034709224792018]
 [1.0093696582895433, 0.9593767885330952]
 [1.0224757160983977, 0.9167052019443166]
 [1.0242545617742285, 0.9113460749430589]
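
For reference, the model defined by pf_func and pg_func above can be written as the following system. This is a sketch that assumes the default diagonal-noise interpretation of SDEProblem, where each state is driven by its own independent Wiener process; the symbols x, y, W_1 and W_2 are introduced here for illustration and stand for u[1], u[2] and the two noise processes.

```latex
\begin{aligned}
dx &= (p_1 x - p_2 x y)\,dt + 10^{-6}\, x \, dW_1, \\
dy &= (-3 y + x y)\,dt + 10^{-6}\, y \, dW_2,
\end{aligned}
\qquad x(0) = y(0) = 1, \quad t \in [0, 10], \quad p = (1.5,\ 1.0)
```

Only p_1 and p_2 appear in the drift, and the noise coefficients contain no parameters, so the fit below estimates drift parameters only.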

Now let's generate a dataset from 10,000 solutions of the SDE, adding Gaussian measurement noise with standard deviation 0.01 to each sampled point:

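using RecursiveArrayTools # for VectorOfArray
t = collect(range(0, stop = 10, length = 200))
# Solve the SDE once, sample it at the times in t, and add measurement noise
function generate_data(t)
    sol = solve(prob, SRIW1())
    randomized = VectorOfArray([(sol(t[i]) + 0.01randn(2)) for i in 1:length(t)])
    return convert(Array, randomized) # 2×length(t) matrix
end
# Stack 10,000 noisy trajectories into a states × times × trajectories array
aggregate_data = convert(Array, VectorOfArray([generate_data(t) for i in 1:10000]))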
2×200×10000 Array{Float64, 3}:
[:, :, 1] =
 0.992488  1.03225   1.04962   1.10174   …  0.971826  1.01839  1.03705
 0.994988  0.927154  0.829401  0.735556     1.1039    1.00659  0.905103

[:, :, 2] =
 0.992191  1.03247  1.07429   1.07992   …  0.965761  0.987526  1.02086
 0.984332  0.90034  0.826083  0.757433     1.12247   1.00458   0.908923

[:, :, 3] =
 1.00037  1.04072   1.04291  1.09766   …  0.958791  0.97207  1.01547  1.02046
 1.00284  0.888523  0.81083  0.737042     1.22641   1.10587  1.01511  0.91556

;;; … 

[:, :, 9998] =
 0.981002  1.01984   1.06666   1.09034   …  0.982981  0.98129  1.02491
 1.00192   0.910339  0.811732  0.749381     1.14271   1.01089  0.913956

[:, :, 9999] =
 0.980736  1.03031   1.06971   1.1134   …  0.975361  1.00525  1.01784
 0.989885  0.924049  0.824337  0.75345     1.12109   1.00769  0.923774

[:, :, 10000] =
 1.01043  1.02556   1.05665   1.10264   …  0.966769  0.989677  1.0137
 1.02104  0.902009  0.829709  0.724681     1.09997   1.03577   0.925264

Now let's estimate the parameters. Instead of using single runs of the SDE, we will use an EnsembleProblem. This means that it will solve the SDE N times (N is set by the trajectories keyword) to come up with an approximate probability distribution at each time point, and use that in the likelihood estimate.

monte_prob = EnsembleProblem(prob)
EnsembleProblem with problem SDEProblem
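
As a rough illustration of what "an approximate probability distribution at each time point" means, the sketch below summarizes an ensemble by its mean and variance at each time point. The array `ensemble` is a hypothetical, synthetic stand-in with the same states × times × trajectories layout as aggregate_data. It is not produced by the SDE solver, and this is not how DiffEqParamEstim computes its loss internally.

```julia
using Statistics, Random

Random.seed!(1)

# Hypothetical stand-in for an ensemble:
# 2 states × 5 time points × 1000 trajectories, scattered around 1.0
ensemble = 1.0 .+ 0.01 .* randn(2, 5, 1000)

# Reduce over the trajectory dimension (dims = 3) to get one mean and one
# variance per state and time point, then drop the singleton dimension
ens_mean = dropdims(mean(ensemble, dims = 3), dims = 3)  # 2×5 matrix
ens_var = dropdims(var(ensemble, dims = 3), dims = 3)    # 2×5 matrix

println(size(ens_mean))
```

With more trajectories, these per-time-point summaries become less noisy, which is why the number of trajectories trades run time against the smoothness of the objective.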

We use Optim.jl, through OptimizationOptimJL, for the optimization below:

obj = build_loss_objective(monte_prob, SOSRI(), L2Loss(t, aggregate_data),
                           Optimization.AutoForwardDiff(),
                           maxiters = 10000, verbose = false, trajectories = 1000)
optprob = Optimization.OptimizationProblem(obj, [1.0, 0.5])
result = solve(optprob, Optim.BFGS())
retcode: Success
u: 2-element Vector{Float64}:
 3.6903223696688268
 5.897076705633461

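Parameter estimation for SDEs with a regular L2Loss can have poor accuracy because it fits only the mean behavior, as mentioned in the First Differencing section. Here the optimizer stops after a single iteration at roughly [3.69, 5.90], far from the true values [1.5, 1.0]: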

result.original
 * Status: success

 * Candidate solution
    Final objective value:     2.531202e+03

 * Found with
    Algorithm:     BFGS

 * Convergence measures
    |x - x'|               = 5.40e+00 ≰ 0.0e+00
    |x - x'|/|x'|          = 9.15e-01 ≰ 0.0e+00
    |f(x) - f(x')|         = 9.70e+02 ≰ 0.0e+00
    |f(x) - f(x')|/|f(x')| = 3.83e-01 ≰ 0.0e+00
    |g(x)|                 = 0.00e+00 ≤ 1.0e-08

 * Work counters
    Seconds run:   152  (vs limit Inf)
    Iterations:    1
    f(x) calls:    5
    ∇f(x) calls:   5

Instead, when we use L2Loss with first differencing enabled (through the differ_weight and data_weight keywords), we get much more accurate estimates:

obj = build_loss_objective(monte_prob, SRIW1(),
                           L2Loss(t, aggregate_data, differ_weight = 1.0,
                                  data_weight = 0.5), Optimization.AutoForwardDiff(),
                           verbose = false, trajectories = 1000, maxiters = 1000)
optprob = Optimization.OptimizationProblem(obj, [1.0, 0.5])
result = solve(optprob, Optim.BFGS())
result.original
 * Status: failure (line search failed)

 * Candidate solution
    Final objective value:     1.069168e-01

 * Found with
    Algorithm:     BFGS

 * Convergence measures
    |x - x'|               = 0.00e+00 ≤ 0.0e+00
    |x - x'|/|x'|          = 0.00e+00 ≤ 0.0e+00
    |f(x) - f(x')|         = 6.19e-07 ≰ 0.0e+00
    |f(x) - f(x')|/|f(x')| = 5.79e-06 ≰ 0.0e+00
    |g(x)|                 = 1.99e-03 ≰ 1.0e-08

 * Work counters
    Seconds run:   623  (vs limit Inf)
    Iterations:    40
    f(x) calls:    308
    ∇f(x) calls:   308

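Here, we see that both drift parameters were recovered to within about 0.05% of their true values of 1.5 and 1.0, starting from the initial guess [1.0, 0.5]. Optim reports a line-search failure, but the final iterate is still accurate. The noise function pg_func has no parameters, so nothing about the noise is estimated here.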

println(result.u)
[1.499929514109152, 0.9995541776688742]
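
To give some intuition for the first-differencing term, here is a minimal sketch of such a loss for a single trajectory. The function `first_diff_l2` is a hypothetical helper written for this illustration, not the DiffEqParamEstim implementation. It assumes a squared-error form in which the differences between consecutive time points of the simulation are compared with those of the data.

```julia
# Hypothetical helper, not part of DiffEqParamEstim.
# sim and data are states × times matrices of equal size.
function first_diff_l2(sim::AbstractMatrix, data::AbstractMatrix;
                       data_weight = 0.5, differ_weight = 1.0)
    size(sim) == size(data) ||
        throw(DimensionMismatch("sim and data must have the same size"))
    # Ordinary L2 term: pointwise squared error
    loss = data_weight * sum(abs2, sim .- data)
    # First-differencing term: squared error between the changes from one
    # time point to the next (differences along the time dimension)
    loss += differ_weight * sum(abs2, diff(sim, dims = 2) .- diff(data, dims = 2))
    return loss
end

data = [1.0 2.0 4.0; 1.0 1.0 1.0]
sim_shifted = data .+ 0.5                # same trend, constant offset
sim_trend = [1.0 3.0 6.0; 1.0 1.0 1.0]   # different trend in the first state

println(first_diff_l2(data, data))
println(first_diff_l2(sim_shifted, data))
println(first_diff_l2(sim_trend, data))
```

In this example a constant offset is penalized only by the data term (0.75), while a trajectory with the wrong trend is penalized by both terms (4.5), so the differencing term adds sensitivity to how the solution changes over time.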