Bayesian Inference of ODEs

In this tutorial, we show how to use Bayesian inference to estimate the parameters of the Lotka-Volterra equations with each of three backends:

  • Turing.jl
  • Stan.jl
  • DynamicHMC.jl
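Written out, the model being fit looks roughly like this. It is a sketch read off from the setup code below. The Gaussian likelihood with one noise scale per state is an assumption, based on each backend reporting two σ values. The priors the backends place on σ₁ and σ₂ are not shown.

```latex
\begin{aligned}
\frac{dx}{dt} &= a\,x - x\,y, \qquad \frac{dy}{dt} = -3\,y + x\,y, \qquad \big(x(0), y(0)\big) = (1, 1),\\
a &\sim \mathcal{N}\big(1.5,\, 1^2\big),\\
d_{j,i} \mid a, \sigma_j &\sim \mathcal{N}\big(u_j(t_i; a),\, \sigma_j^2\big), \qquad j = 1, 2, \quad t_i = 1, 2, \dots, 10.
\end{aligned}
```

Here u(t; a) = (x(t), y(t)) is the ODE solution and d is the 2×10 data matrix. The data themselves are generated with a = 1.5 and σ₁ = σ₂ = 0.01.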

Setup

First, let's set up the ODE and the data. To generate the data, we solve the ODE at a known parameter value (a = 1.5) and evaluate the solution at the observation times t = 1, 2, …, 10. We then add Gaussian noise with standard deviation σ = 0.01. This looks like the following:

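using DiffEqBayes, ParameterizedFunctions, OrdinaryDiffEq, RecursiveArrayTools,
      Distributions

# Lotka-Volterra system with a single unknown parameter `a`
f1 = @ode_def LotkaVolterra begin
    dx = a * x - x * y
    dy = -3 * y + x * y
end a

p = [1.5]                  # true value of `a` used to generate the data
u0 = [1.0, 1.0]            # initial populations
tspan = (0.0, 10.0)
prob1 = ODEProblem(f1, u0, tspan, p)

σ = 0.01                   # observation noise standard deviation, fixed for now
t = collect(1.0:10.0)      # observation times
sol = solve(prob1, Tsit5())
priors = [Normal(1.5, 1)]  # prior on `a`

# Evaluate the solution at each observation time and add Gaussian noise
randomized = VectorOfArray([(sol(t[i]) + σ * randn(2)) for i in 1:length(t)])
data = convert(Array, randomized)  # 2×10 matrix: rows are states, columns are times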
2×10 Matrix{Float64}:
 2.78332   6.76884  0.934643  1.86953   …  4.33951   3.2407   1.02192
 0.277205  2.02905  1.91783   0.310527     0.322532  4.54504  0.904309
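Before turning to the backends, it may help to see what they compute, in miniature. The sketch below is a hypothetical, dependency-free illustration that uses only the Julia standard library. It is not how DiffEqBayes works internally:

- It swaps Tsit5 for a fixed-step RK4 integrator (`rk4_solve`, step 0.01).
- It regenerates noisy data the same way as the setup code.
- It treats the noise level as known.
- It samples `a` with random-walk Metropolis rather than the HMC/NUTS samplers the backends use.

All function and variable names here are illustrative.

```julia
using Random, Statistics

# Lotka-Volterra right-hand side with unknown `a`: dx = a*x - x*y, dy = -3*y + x*y
lv(u, a) = (a * u[1] - u[1] * u[2], -3 * u[2] + u[1] * u[2])

# Fixed-step RK4 starting at t = 0; returns a 2×length(ts) matrix of states at times `ts`.
function rk4_solve(a, u0, ts; dt = 0.01)
    u, t = u0, 0.0
    out = Matrix{Float64}(undef, 2, length(ts))
    for (j, tj) in enumerate(ts)
        while t < tj - 1e-12
            h = min(dt, tj - t)
            k1 = lv(u, a)
            k2 = lv(u .+ (h / 2) .* k1, a)
            k3 = lv(u .+ (h / 2) .* k2, a)
            k4 = lv(u .+ h .* k3, a)
            u = u .+ (h / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
            t += h
        end
        out[1, j], out[2, j] = u
    end
    return out
end

# Unnormalized log posterior: Normal(1.5, 1) prior on `a` plus an independent
# Gaussian likelihood with known noise standard deviation `σ_noise`.
function logpost(a, y_obs, t_obs, σ_noise)
    pred = rk4_solve(a, (1.0, 1.0), t_obs)
    return -0.5 * (a - 1.5)^2 - sum(abs2, y_obs .- pred) / (2 * σ_noise^2)
end

# Random-walk Metropolis: propose a + step * N(0, 1), accept with prob min(1, ratio).
function metropolis(logp, a0, n; step = 5e-4, rng = Random.default_rng())
    draws = Vector{Float64}(undef, n)
    a, lp = a0, logp(a0)
    accepted = 0
    for i in 1:n
        prop = a + step * randn(rng)
        lp_prop = logp(prop)
        if log(rand(rng)) < lp_prop - lp
            a, lp = prop, lp_prop
            accepted += 1
        end
        draws[i] = a
    end
    return draws, accepted / n
end

rng = MersenneTwister(42)
t_obs = collect(1.0:10.0)
σ_noise = 0.01
y_obs = rk4_solve(1.5, (1.0, 1.0), t_obs) .+ σ_noise .* randn(rng, 2, length(t_obs))

draws, acc = metropolis(a -> logpost(a, y_obs, t_obs, σ_noise), 1.5, 5_000; rng = rng)
kept = draws[1001:end]   # discard the first 1000 draws as burn-in
println("mean(a) = ", mean(kept), ", sd(a) = ", std(kept), ", acceptance = ", acc)
```

Starting the chain at the prior mean and using a proposal step of 5e-4 are assumptions suited to this very narrow posterior. The NUTS-based samplers used by the backends avoid most of this hand tuning by following gradients of the log posterior and adapting their step size during warm-up.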

Inference Methods

Stan

using StanSample  # required for the Stan backend
bayesian_result_stan = stan_inference(prob1, :rk45, t, data, priors)
1000×3 DataFrame
  Row │ sigma1.1  sigma1.2  theta_1 
      │ Float64   Float64   Float64 
──────┼─────────────────────────────
    1 │ 0.393118  0.315827  1.50132
    2 │ 0.282397  0.305268  1.49512
    3 │ 0.150378  0.316076  1.5015
    4 │ 0.244083  0.206386  1.50834
    5 │ 0.205996  0.20982   1.50453
    6 │ 0.214056  0.265647  1.49495
    7 │ 0.277271  0.2535    1.50538
    8 │ 0.327741  0.177529  1.49507
  ⋮   │    ⋮         ⋮         ⋮
  994 │ 0.423166  0.19338   1.488
  995 │ 0.290926  0.194671  1.49034
  996 │ 0.309321  0.175777  1.50054
  997 │ 0.274246  0.31086   1.50191
  998 │ 0.311756  0.286311  1.49756
  999 │ 0.216036  0.223427  1.50319
 1000 │ 0.287365  0.241683  1.5071
                    985 rows omitted
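To turn draws like these into summary statistics, you can compute the posterior mean, the standard deviation and an equal-tailed 95% credible interval with the Statistics standard library. The sketch below uses only the 15 `theta_1` values printed in the table above, so its numbers are illustrative rather than the real posterior summary. With the full result, you would pass the column `bayesian_result_stan.theta_1` (DataFrames property access) instead.

```julia
using Statistics

# Only the 15 `theta_1` draws visible in the printed table (rows 1-8 and 994-1000).
theta_1 = [1.50132, 1.49512, 1.5015, 1.50834, 1.50453, 1.49495, 1.50538, 1.49507,
           1.488, 1.49034, 1.50054, 1.50191, 1.49756, 1.50319, 1.5071]

post_mean = mean(theta_1)
post_sd = std(theta_1)                       # sample standard deviation (n - 1 denominator)
lo, hi = quantile(theta_1, [0.025, 0.975])   # equal-tailed 95% credible interval

println("mean = ", round(post_mean; digits = 4),
        ", sd = ", round(post_sd; digits = 4),
        ", 95% interval = (", round(lo; digits = 4), ", ", round(hi; digits = 4), ")")
```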

Turing

bayesian_result_turing = turing_inference(prob1, Tsit5(), t, data, priors)
╭─FlexiChain (1000 iterations, 1 chain) ───────────────────────────────────────╮
 ↓ iter  = 501:1500                                                           
 → chain = 1:1                                                                
                                                                              
 Parameters (2) ── AbstractPPL.VarName                                        
  Float64          theta[1]                                                   
  Vector{Float64}  σ                                                          
                                                                              
 Extras (14)                                                                  
  Int64    n_steps, tree_depth                                                
  Bool     is_accept, numerical_error                                         
  Float64  acceptance_rate, log_density, hamiltonian_energy,                  
           hamiltonian_energy_error, max_hamiltonian_energy_error, step_size, 
           nom_step_size, logprior, loglikelihood, logjoint                   
╰──────────────────────────────────────────────────────────────────────────────╯

DynamicHMC

DynamicHMC.jl can be used as the sampling backend through the dynamichmc_inference function, which is called the same way:

bayesian_result_hmc = dynamichmc_inference(prob1, Tsit5(), t, data, priors)
(posterior = [(parameters = [1.5000213177037522], σ = [0.023253751383379234, 0.013997387652595767]), (parameters = [1.5000955830002216], σ = [0.01738258884846736, 0.013610722787935523]), (parameters = [1.5006999107506833], σ = [0.01649713206591352, 0.009288438732174746]), (parameters = [1.5004764215185882], σ = [0.016302986391185468, 0.012800654919947644]), (parameters = [1.50033078737305], σ = [0.012446615831101331, 0.01043013102234103]), (parameters = [1.5007243281974187], σ = [0.024212100195889895, 0.013592458184098889]), (parameters = [1.500242277873348], σ = [0.010840909610318807, 0.013551173337686432]), (parameters = [1.5004041617271529], σ = [0.008861753265797117, 0.013251647991876583]), (parameters = [1.5005528319156425], σ = [0.011873491350602054, 0.0129286511244259]), (parameters = [1.5005987501834916], σ = [0.01618706305402008, 0.014235478722677258])  …  (parameters = [1.5000744927438012], σ = [0.01436852759529074, 0.017122152739777137]), (parameters = [1.5005800050724218], σ = [0.017026563675849815, 0.012083428876292702]), (parameters = [1.5006878752538255], σ = [0.01580307414722355, 0.014706230143475454]), (parameters = [1.5006893166950168], σ = [0.03233669894510641, 0.01345736826202947]), (parameters = [1.5001428696211723], σ = [0.017140684783862006, 0.009465518672288994]), (parameters = [1.5006203307692596], σ = [0.01815544959877971, 0.018154500525228448]), (parameters = [1.5002087285947125], σ = [0.01638939877226835, 0.009559551944444277]), (parameters = [1.5006710273906985], σ = [0.01515888080217509, 0.014300168691643942]), (parameters = [1.4998132795289638], σ = [0.01745551414418913, 0.012334380933703119]), (parameters = [1.500048525239224], σ = [0.026164243194566412, 0.010165383999595542])],
 posterior_matrix = [0.4054793198096791 0.40552882807815166 … 0.405340620045823 0.40549745774439183; -3.7612888101414916 -4.052286214629658 … -4.0480996834485055 -3.6433615639192443; -4.268884563021283 -4.296897356606521 … -4.395364718042745 -4.588767055970433],
 tree_statistics = DynamicHMC.TreeStatisticsNUTS[DynamicHMC.TreeStatisticsNUTS(41.06793543511067, 2, turning at positions 4:7, 0.9745665064252346, 7, DynamicHMC.Directions(0x48d3d6af)), DynamicHMC.TreeStatisticsNUTS(41.78079621883924, 2, turning at positions -2:1, 0.9858309819429635, 3, DynamicHMC.Directions(0x66d5f0d9)), DynamicHMC.TreeStatisticsNUTS(40.29413934664708, 2, turning at positions -1:-4, 0.7654774601277642, 7, DynamicHMC.Directions(0xd7adb643)), DynamicHMC.TreeStatisticsNUTS(42.84236886859498, 3, turning at positions -3:4, 0.9856923484368415, 7, DynamicHMC.Directions(0x1aba21dc)), DynamicHMC.TreeStatisticsNUTS(42.09262425572443, 2, turning at positions -1:2, 0.837852608032969, 3, DynamicHMC.Directions(0x8bfb3812)), DynamicHMC.TreeStatisticsNUTS(42.01408276211967, 3, turning at positions -5:2, 0.9478865267732374, 7, DynamicHMC.Directions(0xacc86a62)), DynamicHMC.TreeStatisticsNUTS(41.63256893938838, 2, turning at positions 0:3, 0.9186055699560702, 3, DynamicHMC.Directions(0xbbe7d643)), DynamicHMC.TreeStatisticsNUTS(38.443301679986234, 1, turning at positions -1:0, 0.3268793018370503, 1, DynamicHMC.Directions(0xc4932a8e)), DynamicHMC.TreeStatisticsNUTS(40.56243936106751, 3, turning at positions 0:7, 0.9999999999999999, 7, DynamicHMC.Directions(0x6d881c37)), DynamicHMC.TreeStatisticsNUTS(42.49200560233446, 2, turning at positions -1:2, 0.9717473526942718, 3, DynamicHMC.Directions(0x9440bade))  …  DynamicHMC.TreeStatisticsNUTS(41.523584968538756, 1, turning at positions -1:0, 1.0, 1, DynamicHMC.Directions(0xd09c594e)), DynamicHMC.TreeStatisticsNUTS(42.294523364484014, 2, turning at positions -1:2, 0.9999999999999999, 3, DynamicHMC.Directions(0xc4e45bce)), DynamicHMC.TreeStatisticsNUTS(43.44401369811263, 3, turning at positions -1:6, 0.9694938368794516, 7, DynamicHMC.Directions(0x3c63c46e)), DynamicHMC.TreeStatisticsNUTS(40.52798239896049, 2, turning at positions 0:3, 0.7070469510452084, 3, DynamicHMC.Directions(0x0670932b)), DynamicHMC.TreeStatisticsNUTS(39.935437912000936, 2, turning at positions 4:7, 0.992376351908459, 7, DynamicHMC.Directions(0x46d110b7)), DynamicHMC.TreeStatisticsNUTS(39.19291899292402, 3, turning at positions -2:5, 0.7505715124342273, 7, DynamicHMC.Directions(0x2ab767a5)), DynamicHMC.TreeStatisticsNUTS(42.09207444581346, 3, turning at positions 0:7, 0.990251179631354, 7, DynamicHMC.Directions(0x634b2447)), DynamicHMC.TreeStatisticsNUTS(42.924492689144216, 3, turning at positions -4:3, 0.9999999999999999, 7, DynamicHMC.Directions(0xe1f5429b)), DynamicHMC.TreeStatisticsNUTS(41.158264749209636, 3, turning at positions 0:7, 0.8085464271901267, 7, DynamicHMC.Directions(0xc120279f)), DynamicHMC.TreeStatisticsNUTS(40.249351323611094, 2, turning at positions 4:7, 0.9505342679527207, 7, DynamicHMC.Directions(0xf61f7f17))],
 logdensities = [42.27246089440389, 43.282792404779975, 43.17864267980074, 44.06707148006278, 43.07846192356003, 42.642177051537054, 41.79843091693122, 39.265162066659215, 43.081525676864544, 43.820141934673664  …  42.30725431523882, 44.03543952380514, 43.60036380054761, 40.979877059766885, 42.73916128176054, 42.71837594204497, 43.05114130391705, 43.67401777334369, 42.052188859173796, 41.431326897022316],
 κ = Gaussian kinetic energy (Diagonal), √diag(M⁻¹): [0.00023816733174924365, 0.2811056681686239, 0.21815742051855916],
 ϵ = 0.6947482357803328)
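Judging from the printed values, `posterior_matrix` appears to store the draws on an unconstrained (log) scale, with rows ordered as `a`, σ₁, σ₂. Exponentiating a column reproduces the matching entry of `posterior`. This is inferred from the output above, not from documented API behavior. The check below copies the first and last visible columns and their matching `posterior` entries. If the inference holds, `exp.(bayesian_result_hmc.posterior_matrix)` would give every draw on the original scale.

```julia
# First and last visible columns of the printed `posterior_matrix` (log scale).
log_first = [0.4054793198096791, -3.7612888101414916, -4.268884563021283]
log_last  = [0.40549745774439183, -3.6433615639192443, -4.588767055970433]

# Matching draws copied from the printed `posterior`, flattened to [a, σ₁, σ₂].
first_draw = [1.5000213177037522, 0.023253751383379234, 0.013997387652595767]
last_draw  = [1.500048525239224, 0.026164243194566412, 0.010165383999595542]

# Map from log space back to the positive reals, elementwise.
recovered_first = exp.(log_first)
recovered_last  = exp.(log_last)

# Elementwise relative comparison (a vector `isapprox` would let `a` dominate the norm).
println(all(isapprox.(recovered_first, first_draw; rtol = 1e-8)))
println(all(isapprox.(recovered_last, last_draw; rtol = 1e-8)))
```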

More Information

For a fuller picture of the summary statistics and how to plot the results, take a look at the benchmarks.