Fitting Ensembles of ODE Models to Data

In this tutorial, we show how to fit multiple models simultaneously, each to its own data source. Let's dive right in!

Formulating the Ensemble Model

First, you want to create a problem which solves multiple problems at the same time. This is the EnsembleProblem. When the parameter estimation tools say they will take any DEProblem, they really mean ANY DEProblem, which includes EnsembleProblem.

So, let's set up an EnsembleProblem that solves with 10 different initial conditions. This looks as follows:

using DifferentialEquations, DiffEqParamEstim, Plots, Optimization, ForwardDiff,
      OptimizationOptimJL

# EnsembleProblem setup for solving a set of ODEs with different initial conditions

# Set up Lotka-Volterra system
function pf_func(du, u, p, t)
    du[1] = p[1] * u[1] - p[2] * u[1] * u[2]
    du[2] = -3 * u[2] + u[1] * u[2]
end
p = [1.5, 1.0]
prob = ODEProblem(pf_func, [1.0, 1.0], (0.0, 10.0), p)
ODEProblem with uType Vector{Float64} and tType Float64. In-place: true
timespan: (0.0, 10.0)
u0: 2-element Vector{Float64}:
 1.0
 1.0

Now for an EnsembleProblem we have to take this problem and tell it what to do N times via the prob_func. So let's generate N=10 different initial conditions, and tell it to run the same problem but with these 10 different initial conditions each time:

# Setting up to solve the problem N times (for the N different initial conditions)
N = 10;
initial_conditions = [
    [1.0, 1.0],
    [1.0, 1.5],
    [1.5, 1.0],
    [1.5, 1.5],
    [0.5, 1.0],
    [1.0, 0.5],
    [0.5, 0.5],
    [2.0, 1.0],
    [1.0, 2.0],
    [2.0, 2.0],
]
function prob_func(prob, i, repeat)
    ODEProblem(prob.f, initial_conditions[i], prob.tspan, prob.p)
end
enprob = EnsembleProblem(prob, prob_func = prob_func)
EnsembleProblem with problem ODEProblem

We can check this does what we want by solving it:

# Check above does what we want
sim = solve(enprob, Tsit5(), trajectories = N)
plot(sim)

trajectories=N means “run N times”, and each time it runs the problem returned by the prob_func, which is always the same problem but with the ith initial condition.
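As a side note, the same prob_func can be sketched with remake, which copies a problem and replaces only the fields passed as keywords. This is a minimal sketch under stated assumptions, not part of the original tutorial: the name prob_func_remake and the shortened list of three initial conditions are chosen here purely for illustration.

```julia
using DifferentialEquations

# Same Lotka-Volterra right-hand side as in the tutorial
function pf_func(du, u, p, t)
    du[1] = p[1] * u[1] - p[2] * u[1] * u[2]
    du[2] = -3 * u[2] + u[1] * u[2]
end

# Shortened list of initial conditions, for illustration only
initial_conditions = [[1.0, 1.0], [1.0, 1.5], [1.5, 1.0]]
prob = ODEProblem(pf_func, [1.0, 1.0], (0.0, 10.0), [1.5, 1.0])

# remake copies `prob` and swaps only u0; f, tspan, and p are reused as-is
prob_func_remake(prob, i, repeat) = remake(prob, u0 = initial_conditions[i])

enprob = EnsembleProblem(prob, prob_func = prob_func_remake)
sim = solve(enprob, Tsit5(), trajectories = length(initial_conditions))

# Each trajectory should start from its own initial condition
starts_match = all(sim.u[i].u[1] == initial_conditions[i]
                   for i in eachindex(initial_conditions))
```

Here sim.u is the stored vector of individual solutions, so sim.u[i].u[1] is the first saved state of the ith trajectory.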

Now let's generate a dataset from that. Let's get data points at every t=0.1 using saveat, and then convert the solution into an array.

# Generate a dataset from these runs
data_times = 0.0:0.1:10.0
sim = solve(enprob, Tsit5(), trajectories = N, saveat = data_times)
data = Array(sim)
2×101×10 Array{Float64, 3}:
[:, :, 1] =
 1.0  1.06108   1.14403   1.24917   1.37764   …  0.956979  0.983561  1.03376
 1.0  0.821084  0.679053  0.566893  0.478813     1.35559   1.10629   0.90637

[:, :, 2] =
 1.0  1.01413  1.05394  1.11711   …  1.05324  1.01309  1.00811  1.03162
 1.5  1.22868  1.00919  0.833191     2.08023  1.70818  1.39972  1.14802

[:, :, 3] =
 1.5  1.58801   1.70188   1.84193   2.00901   …  2.0153    2.21084   2.43589
 1.0  0.864317  0.754624  0.667265  0.599149     0.600943  0.549793  0.51368

[:, :, 4] =
 1.5  1.51612  1.5621   1.63555   1.73531   …  1.83822   1.98545   2.15958
 1.5  1.29176  1.11592  0.969809  0.850159     0.771088  0.691421  0.630025

[:, :, 5] =
 0.5  0.531705  0.576474  0.634384  0.706139  …  9.05366   9.4006   8.8391
 1.0  0.77995   0.610654  0.480565  0.380645     0.809383  1.51708  2.82619

[:, :, 6] =
 1.0  1.11027   1.24238   1.39866   1.58195   …  0.753107  0.748814  0.768284
 0.5  0.411557  0.342883  0.289812  0.249142     1.73879   1.38829   1.10932

[:, :, 7] =
 0.5  0.555757  0.623692  0.705084  0.80158   …  8.11213   9.10669   9.92169
 0.5  0.390449  0.30679   0.24286   0.193966     0.261294  0.455928  0.878792

[:, :, 8] =
 2.0  2.11239   2.24921   2.41003   2.59433   …  3.22293   3.47356   3.73011
 1.0  0.909749  0.838025  0.783532  0.745339     0.739406  0.765525  0.813005

[:, :, 9] =
 1.0  0.969326  0.971358  1.00017  …  1.25065  1.1012   1.01733  0.979304
 2.0  1.63445   1.33389   1.09031     3.02672  2.52063  2.07503  1.69808

[:, :, 10] =
 2.0  1.92148  1.88215  1.87711  1.90264  …  2.15079  2.27937   2.43105
 2.0  1.80195  1.61405  1.4426   1.2907      0.95722  0.884825  0.829478

Here, data[i,j,k] is the same as sim[i,j,k] which is the same as sim[k][i,j] (where sim[k] is the kth solution). So data[i,j,k] is the jth timepoint of the ith variable in the kth trajectory.
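The index layout described above can be illustrated with plain arrays, without solving any ODE. The following is a small sketch with made-up numbers; the names trajectories, stacked, and slice are hypothetical and only stand in for the individual solutions, the 3-dimensional data array, and one per-trajectory piece of it.

```julia
# Three fake "trajectories", each a 2×4 matrix (2 variables, 4 time points).
# The values are made up purely to show the layout.
offsets = [0.0 0.1 0.2 0.3; 0.5 0.6 0.7 0.8]
trajectories = [fill(Float64(k), 2, 4) .+ offsets for k in 1:3]

# Stack the matrices along a third dimension:
# variables × time points × trajectories
stacked = cat(trajectories...; dims = 3)
size(stacked)  # (2, 4, 3)

# Variable i at time point j in trajectory k
i, j, k = 2, 3, 1
same_entry = stacked[i, j, k] == trajectories[k][i, j]  # true

# The slice for one trajectory is an ordinary 2×4 matrix
slice = stacked[:, :, 2]
```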

Now let's build a loss function. A loss function is some loss(sol) that spits out a scalar for how far from optimal we are. The documentation normally uses loss = L2Loss(t,data), but we can build on this. Instead of a single loss, let's build an array of N loss functions, each one with the correct piece of data.

# Building a loss function
losses = [L2Loss(data_times, data[:, :, i]) for i in 1:N]
10-element Vector{L2Loss{StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, Matrix{Float64}, Nothing, Nothing, Nothing}}:
 L2Loss{StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, Matrix{Float64}, Nothing, Nothing, Nothing}(0.0:0.1:10.0, [1.0 1.0610780673356452 … 0.9835609063200819 1.0337581256020802; 1.0 0.8210842775886171 … 1.1062868199419744 0.9063703842885995], nothing, nothing, nothing, nothing)
 L2Loss{StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, Matrix{Float64}, Nothing, Nothing, Nothing}(0.0:0.1:10.0, [1.0 1.0141312263417834 … 1.008106024273487 1.0316172494491975; 1.5 1.2286831520665753 … 1.3997241937770144 1.148024473932838], nothing, nothing, nothing, nothing)
 L2Loss{StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, Matrix{Float64}, Nothing, Nothing, Nothing}(0.0:0.1:10.0, [1.5 1.5880106683980333 … 2.2108390798809867 2.4358900077551477; 1.0 0.8643172923598124 … 0.549793448237966 0.5136795156288518], nothing, nothing, nothing, nothing)
 L2Loss{StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, Matrix{Float64}, Nothing, Nothing, Nothing}(0.0:0.1:10.0, [1.5 1.516120535341813 … 1.9854481733983935 2.1595824002587736; 1.5 1.2917636828588928 … 0.69142104528754 0.6300249117929883], nothing, nothing, nothing, nothing)
 L2Loss{StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, Matrix{Float64}, Nothing, Nothing, Nothing}(0.0:0.1:10.0, [0.5 0.5317050732862075 … 9.400596786582572 8.839104021426882; 1.0 0.7799498910330318 … 1.5170828943237389 2.8261901751675946], nothing, nothing, nothing, nothing)
 L2Loss{StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, Matrix{Float64}, Nothing, Nothing, Nothing}(0.0:0.1:10.0, [1.0 1.1102743524476706 … 0.7488136202548061 0.7682836073860905; 0.5 0.4115572142804906 … 1.388294474469438 1.109323824809898], nothing, nothing, nothing, nothing)
 L2Loss{StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, Matrix{Float64}, Nothing, Nothing, Nothing}(0.0:0.1:10.0, [0.5 0.5557572700553828 … 9.1066899353092 9.921688111088129; 0.5 0.390449424650402 … 0.45592804826244704 0.878791838771602], nothing, nothing, nothing, nothing)
 L2Loss{StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, Matrix{Float64}, Nothing, Nothing, Nothing}(0.0:0.1:10.0, [2.0 2.112390154025954 … 3.4735620076799343 3.730105807744159; 1.0 0.9097494017873065 … 0.7655248264107148 0.8130047563338593], nothing, nothing, nothing, nothing)
 L2Loss{StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, Matrix{Float64}, Nothing, Nothing, Nothing}(0.0:0.1:10.0, [1.0 0.9693256296130461 … 1.0173287625901066 0.9793042282176283; 2.0 1.634450182452438 … 2.0750297435191234 1.6980752778375598], nothing, nothing, nothing, nothing)
 L2Loss{StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, Matrix{Float64}, Nothing, Nothing, Nothing}(0.0:0.1:10.0, [2.0 1.9214830168796073 … 2.2793712619034037 2.431047146519131; 2.0 1.8019540594630477 … 0.8848252572380416 0.8294783193137137], nothing, nothing, nothing, nothing)

So losses[i] is a function which computes the loss of a solution against the data of the ith trajectory. So to build our true loss function, we sum the losses:

loss(sim) = sum(losses[i](sim[i]) for i in 1:N)
loss (generic function with 1 method)
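To make the structure of this summed loss concrete, here is a self-contained sketch that does not use DiffEqParamEstim. It assumes the basic form of an L2 loss is a plain sum of squared differences (weights and other L2Loss options are ignored), and the names sse, obs, per_traj, and total_loss, as well as the data values, are hypothetical.

```julia
# Hypothetical stand-in for a single-trajectory loss: sum of squared errors
sse(pred, obs) = sum(abs2, pred .- obs)

# Made-up "observed" data: N = 3 trajectories, each 2 variables × 5 time points
N = 3
obs = [Float64[k + i + 0.5j for i in 1:2, j in 1:5] for k in 1:N]

# One closure per trajectory, each capturing its own piece of data
per_traj = [pred -> sse(pred, obs[k]) for k in 1:N]

# Total loss: sum the per-trajectory losses over matching predictions
total_loss(preds) = sum(per_traj[k](preds[k]) for k in 1:N)

total_loss(obs)                    # 0.0, predictions equal the data
shifted = [o .+ 0.5 for o in obs]  # every entry off by 0.5
total_loss(shifted)                # 7.5 = 3 trajectories × 10 entries × 0.5^2
```

The design point is that each trajectory is only ever compared against its own data, so the total is zero exactly when every trajectory matches.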

As a double check, make sure that loss(sim) outputs zero (since we generated the data from sim). Now we solve the ensemble again with different parameters, [1.2, 0.8], and evaluate the loss on that simulation:

prob = ODEProblem(pf_func, [1.0, 1.0], (0.0, 10.0), [1.2, 0.8])
function prob_func(prob, i, repeat)
    ODEProblem(prob.f, initial_conditions[i], prob.tspan, prob.p)
end
enprob = EnsembleProblem(prob, prob_func = prob_func)
sim = solve(enprob, Tsit5(), trajectories = N, saveat = data_times)
loss(sim)
10108.69414420129

and get a non-zero loss. So, we now have our problem, our data, and our loss function… we have what we need.

Put this into build_loss_objective.

obj = build_loss_objective(enprob, Tsit5(), loss, Optimization.AutoForwardDiff(),
                           trajectories = N,
                           saveat = data_times)
(::SciMLBase.OptimizationFunction{true, ADTypes.AutoForwardDiff{nothing, Nothing}, DiffEqParamEstim.var"#29#30"{Nothing, typeof(DiffEqParamEstim.STANDARD_PROB_GENERATOR), Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol}, @NamedTuple{trajectories::Int64, saveat::StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}}}, SciMLBase.EnsembleProblem{SciMLBase.ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Vector{Float64}, SciMLBase.ODEFunction{true, SciMLBase.AutoSpecialize, typeof(Main.pf_func), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, @NamedTuple{}}, SciMLBase.StandardODEProblem}, typeof(Main.prob_func), typeof(SciMLBase.DEFAULT_OUTPUT_FUNC), typeof(SciMLBase.DEFAULT_REDUCTION), Nothing}, OrdinaryDiffEq.Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, typeof(Main.loss), Nothing, Tuple{}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}) (generic function with 1 method)

Notice that we added the kwargs for solve of the EnsembleProblem into this. They get passed to the internal solve command, so then the loss is computed on N trajectories at data_times.
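Conceptually, the keyword forwarding works like a closure that captures the solver options once and reuses them on every evaluation. The sketch below is not DiffEqParamEstim's actual implementation; fake_solve, make_objective, and the exponential-decay model are hypothetical stand-ins chosen so the example runs without any packages.

```julia
# Hypothetical stand-in for `solve`: returns model values sampled at `saveat`.
# The model p[1] * exp(-p[2] * t) is chosen only for illustration.
fake_solve(p; saveat = 0.0:1.0:2.0) = [p[1] * exp(-p[2] * t) for t in saveat]

# Build p -> loss(solver(p; kwargs...)), capturing the solver kwargs once
function make_objective(solver, loss; kwargs...)
    return p -> loss(solver(p; kwargs...))
end

times = 0.0:0.5:2.0
truth = fake_solve([2.0, 1.0]; saveat = times)
loss(sol) = sum(abs2, sol .- truth)

objective = make_objective(fake_solve, loss; saveat = times)

objective([2.0, 1.0])  # 0.0, same parameters and times that generated `truth`
objective([1.5, 1.0])  # positive, the parameters differ from the truth
```

If saveat were not forwarded, fake_solve would fall back to its default of three time points and the comparison against the five-point truth would fail with a dimension mismatch, which is why the solver options have to travel with the objective.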

Thus, we can take this objective function over to any optimization package. Here, since the Lotka-Volterra equation requires positive parameters, we use Fminbox to make sure the parameters stay within the passed bounds. Let's start the optimization with [1.3,0.9] and see what Optim reports for the parameters:

lower = zeros(2)
upper = fill(2.0, 2)
optprob = OptimizationProblem(obj, [1.3, 0.9], lb = lower, ub = upper)
result = solve(optprob, Fminbox(BFGS()))
retcode: Success
u: 2-element Vector{Float64}:
 1.500464620644685
 1.0010017675395355
result
retcode: Success
u: 2-element Vector{Float64}:
 1.500464620644685
 1.0010017675395355

Optim recovers both parameters: the result is within about 1e-3 of the true values [1.5, 1.0].

It is advised to run a test on synthetic data for your problem before using it on real data. Maybe play around with different optimization packages, or add regularization. You may also want to decrease the tolerance of the ODE solvers via

obj = build_loss_objective(enprob, Tsit5(), loss, Optimization.AutoForwardDiff(),
                           trajectories = N,
                           abstol = 1e-8, reltol = 1e-8,
                           saveat = data_times)
optprob = OptimizationProblem(obj, [1.3, 0.9], lb = lower, ub = upper)
result = solve(optprob, BFGS()) # OptimizationOptimJL detects that this is a box-constrained problem and uses the Fminbox wrapper over BFGS
retcode: Success
u: 2-element Vector{Float64}:
 1.5007432843504753
 1.0012380201889803
result
retcode: Success
u: 2-element Vector{Float64}:
 1.5007432843504753
 1.0012380201889803

if you suspect ODE solver error is the problem. However, if you're having problems, the cause is most likely not the ODE solver tolerance but the fact that parameter inference is a very hard optimization problem.
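The regularization suggestion above could look roughly like the following. This is a generic sketch of an L2 (Tikhonov-style) penalty wrapped around an objective, not an API of DiffEqParamEstim or Optimization; base_objective, regularized, the prior guess p0, and the strength λ are all hypothetical.

```julia
# Hypothetical data-fit objective with its minimum at p = [1.5, 1.0]
base_objective(p) = sum(abs2, p .- [1.5, 1.0])

# Wrap any objective with an L2 penalty pulling p toward a prior guess p0
regularized(obj, p0, λ) = p -> obj(p) + λ * sum(abs2, p .- p0)

p0 = [1.0, 1.0]  # assumed prior guess
reg_obj = regularized(base_objective, p0, 0.5)

reg_obj([1.5, 1.0])  # 0.125: zero data misfit plus 0.5 * 0.5^2 penalty
reg_obj([1.0, 1.0])  # 0.25: zero penalty plus 0.5^2 data misfit
```

A larger λ pulls the optimum further toward p0, so it trades fit quality for stability; how much is appropriate depends on the problem and is best checked on synthetic data first.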