Bayesian Inference of ODEs

In this tutorial, we show how to use Bayesian inference to estimate the parameters of the Lotka-Volterra equations with each of the three available backends:

  • Turing.jl
  • Stan.jl
  • DynamicHMC.jl

Setup

First, let's set up our ODE and the data. To generate the data, we solve the ODE at known parameter values, sample the solution at the observation times, and add a small amount of Gaussian noise. This looks like the following:

using DiffEqBayes, ParameterizedFunctions, OrdinaryDiffEq, RecursiveArrayTools,
      Distributions
f1 = @ode_def LotkaVolterra begin
    dx = a * x - x * y
    dy = -3 * y + x * y
end a

p = [1.5]
u0 = [1.0, 1.0]
tspan = (0.0, 10.0)
prob1 = ODEProblem(f1, u0, tspan, p)

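σ = 0.01                  # standard deviation of the added noise, fixed for now
t = collect(1.0:10.0)     # observation times
sol = solve(prob1, Tsit5())
priors = [Normal(1.5, 1)] # prior on the single parameter a
# Sample the solution at each observation time and add Gaussian noise
randomized = VectorOfArray([sol(ti) + σ * randn(2) for ti in t])
data = convert(Array, randomized)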
2×10 Matrix{Float64}:
 2.78509   6.79573  0.983602  1.88672  …  4.35144  3.23602  1.03388
 0.258851  2.01545  1.90413   0.29993     0.30229  4.54199  0.922815
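
Written out, the setup above corresponds to the model sketched below. The symbols u, t_i, d_i and ε_i are notation introduced here for illustration; the equations, the initial condition, the time span, σ = 0.01, the observation times, and the prior are read off the code. Here u = (x, y) denotes the ODE solution, d_i is the i-th column of data, and the prior on a has mean 1.5 and standard deviation 1.

```latex
\begin{aligned}
\frac{dx}{dt} &= a\,x - x\,y, \\
\frac{dy}{dt} &= -3\,y + x\,y, \qquad (x(0),\, y(0)) = (1,\, 1), \quad t \in [0,\, 10], \\
d_i &= u(t_i;\, a) + \sigma\,\varepsilon_i, \qquad \varepsilon_i \sim \mathcal{N}(0,\, I_2), \quad t_i = 1, 2, \dots, 10, \\
a &\sim \mathcal{N}(1.5,\, 1^2).
\end{aligned}
```

The data were generated with a = 1.5, so a well-behaved posterior for a should concentrate near that value.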

Inference Methods

Stan

using StanSample  # required for the Stan backend
bayesian_result_stan = stan_inference(prob1, :rk45, t, data, priors)
Chains MCMC chain (1000×3×1 Array{Float64, 3}):

Iterations        = 1:1:1000
Number of chains  = 1
Samples per chain = 1000
parameters        = sigma1.1, sigma1.2, theta_1
internals         = 

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   e ⋯
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64     ⋯

    sigma1.1    0.2727    0.0886    0.0036   740.9726   518.4954    1.0005     ⋯
    sigma1.2    0.2524    0.0734    0.0025   906.2566   865.5821    1.0013     ⋯
     theta_1    1.5007    0.0056    0.0002   731.8775   528.9275    1.0002     ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

    sigma1.1    0.1522    0.2118    0.2532    0.3177    0.4932
    sigma1.2    0.1443    0.2007    0.2390    0.2927    0.4287
     theta_1    1.4902    1.4971    1.5006    1.5037    1.5129

Turing

bayesian_result_turing = turing_inference(prob1, Tsit5(), t, data, priors)
Chains MCMC chain (1000×16×1 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 1
Samples per chain = 1000
Wall duration     = 17.39 seconds
Compute duration  = 17.39 seconds
parameters        = theta[1], σ[1]
internals         = n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size, logprior, loglikelihood, logjoint

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   e ⋯
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64     ⋯

    theta[1]    1.5004    0.0033    0.0001   636.6010   655.4540    0.9998     ⋯
        σ[1]    0.1497    0.0332    0.0017   404.1485   467.8647    1.0003     ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

    theta[1]    1.4940    1.4983    1.5003    1.5025    1.5069
        σ[1]    0.1010    0.1246    0.1461    0.1689    0.2255
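
The result prints as an MCMCChains Chains object, so individual parameters can be pulled out for further analysis. The following is a minimal sketch, assuming MCMCChains.jl is installed. To stay runnable without sampling, it builds a synthetic stand-in chain named chn (a hypothetical name, filled with made-up draws) that uses the same parameter names as the output above. With the real result, bayesian_result_turing would take the place of chn.

```julia
using MCMCChains, Statistics

# Synthetic stand-in for `bayesian_result_turing`: 1000 draws, 2 parameters, 1 chain.
# These values are made up for illustration; they are NOT the tutorial's samples.
vals = hcat(1.5 .+ 0.003 .* randn(1000), 0.15 .+ 0.03 .* rand(1000))
names = [Symbol("theta[1]"), Symbol("σ[1]")]
chn = Chains(reshape(vals, 1000, 2, 1), names)

# Indexing by parameter name returns the draws (iterations × chains);
# flatten them into a plain vector.
θ = vec(collect(chn[Symbol("theta[1]")]))

θ_mean = mean(θ)                      # posterior mean of the ODE parameter
θ_ci = quantile(θ, [0.025, 0.975])    # central 95% interval
```

Working with a plain vector of draws keeps the rest of the analysis independent of the sampling backend.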

DynamicHMC

We can use DynamicHMC.jl as the sampling backend through the dynamichmc_inference function. It is called in the same way as the other backends:

bayesian_result_hmc = dynamichmc_inference(prob1, Tsit5(), t, data, priors)
(posterior = [(parameters = [1.4997543058535034], σ = [0.018624318374934237, 0.013149952233093079]), (parameters = [1.4997856831953489], σ = [0.023377631356251455, 0.010790427137469935]), (parameters = [1.5005949585418168], σ = [0.009769845829391895, 0.015622805596559382]), (parameters = [1.5002889204489223], σ = [0.014785780534754699, 0.014068801531503985]), (parameters = [1.5003500653473483], σ = [0.015927126131322272, 0.014880117930766894]), (parameters = [1.4999729941289626], σ = [0.014088756533123916, 0.012386846350452356]), (parameters = [1.4998115663537546], σ = [0.021096392339439968, 0.011869449360976907]), (parameters = [1.5000507300186552], σ = [0.018351028474752975, 0.012296577652957927]), (parameters = [1.5004989620743518], σ = [0.01847542926348415, 0.011806682259911655]), (parameters = [1.500680485150213], σ = [0.012117804687013364, 0.013900650029332136])  …  (parameters = [1.499791699713688], σ = [0.010953584152830419, 0.013716171252708511]), (parameters = [1.5001291195174815], σ = [0.02052832188839729, 0.007906051280128766]), (parameters = [1.5001291195174815], σ = [0.02052832188839729, 0.007906051280128766]), (parameters = [1.4996329819329157], σ = [0.016182089528427663, 0.011052533967780961]), (parameters = [1.4999152240610236], σ = [0.015479443945738745, 0.00959468465863116]), (parameters = [1.5000302204567229], σ = [0.018429185066267062, 0.010863329750606209]), (parameters = [1.5002189993492436], σ = [0.014565455748144569, 0.011394323989621535]), (parameters = [1.5004775252705636], σ = [0.010209880244058552, 0.012953886121757843]), (parameters = [1.5006151454559915], σ = [0.01960155822719957, 0.013004009485456417]), (parameters = [1.5003042553104955], σ = [0.008944216954150765, 0.01712437044765053])], posterior_matrix = [0.4053012985944542 0.4053222200303818 … 0.40587512101204565 0.40566792441321065; -3.983287112705179 -3.7559756387526333 … -3.932146214518 -4.716748105914711; -4.331337152829366 -4.529095914075788 … -4.342497547114689 
-4.0672526576905845], tree_statistics = DynamicHMC.TreeStatisticsNUTS[DynamicHMC.TreeStatisticsNUTS(43.93042917425763, 3, turning at positions 0:7, 0.9999999999999999, 7, DynamicHMC.Directions(0xce9cfc37)), DynamicHMC.TreeStatisticsNUTS(41.984888755374406, 2, turning at positions 0:3, 0.8787393709280025, 3, DynamicHMC.Directions(0x9bfa1923)), DynamicHMC.TreeStatisticsNUTS(43.18502175646019, 3, turning at positions 0:7, 0.9643327618929964, 7, DynamicHMC.Directions(0x3dec5527)), DynamicHMC.TreeStatisticsNUTS(44.756441710651295, 2, turning at positions 4:7, 0.9999999999999999, 7, DynamicHMC.Directions(0xd951c8a7)), DynamicHMC.TreeStatisticsNUTS(45.106380097724845, 2, turning at positions -2:1, 0.9813746799189604, 3, DynamicHMC.Directions(0x97518ae5)), DynamicHMC.TreeStatisticsNUTS(44.09504287532519, 2, turning at positions -3:0, 0.948480407974916, 3, DynamicHMC.Directions(0xd4986958)), DynamicHMC.TreeStatisticsNUTS(42.77381832065646, 2, turning at positions -3:0, 0.8225774443508773, 3, DynamicHMC.Directions(0x33fb13b4)), DynamicHMC.TreeStatisticsNUTS(43.52069631646491, 3, turning at positions -4:3, 0.9939190228311213, 7, DynamicHMC.Directions(0x684ce5bb)), DynamicHMC.TreeStatisticsNUTS(42.615034295235944, 1, turning at positions -1:-2, 0.8371166912709614, 3, DynamicHMC.Directions(0xdbec47c5)), DynamicHMC.TreeStatisticsNUTS(40.696219128235455, 2, turning at positions -3:0, 0.6077670966879958, 3, DynamicHMC.Directions(0x704e11e8))  …  DynamicHMC.TreeStatisticsNUTS(42.82904312267684, 2, turning at positions -3:0, 0.9145152108696429, 3, DynamicHMC.Directions(0xb8f535c4)), DynamicHMC.TreeStatisticsNUTS(39.92768561007369, 2, turning at positions -1:2, 0.9999999999999999, 3, DynamicHMC.Directions(0x1843613a)), DynamicHMC.TreeStatisticsNUTS(35.91480876581405, 2, turning at positions -2:1, 0.1321028557684636, 3, DynamicHMC.Directions(0x0ab76071)), DynamicHMC.TreeStatisticsNUTS(40.49069994629687, 2, turning at positions 0:3, 0.8869006854426211, 3, 
DynamicHMC.Directions(0x18b6bbcf)), DynamicHMC.TreeStatisticsNUTS(43.22558899204943, 2, turning at positions -2:1, 0.9975233783348295, 3, DynamicHMC.Directions(0x90340625)), DynamicHMC.TreeStatisticsNUTS(43.883485122624776, 2, turning at positions -1:2, 0.9893296638053138, 3, DynamicHMC.Directions(0xa4d9c0f6)), DynamicHMC.TreeStatisticsNUTS(43.906677965022375, 2, turning at positions -2:1, 0.9249692045452201, 3, DynamicHMC.Directions(0x1f7d8add)), DynamicHMC.TreeStatisticsNUTS(44.92230681147872, 2, turning at positions 1:4, 0.9659774916340809, 7, DynamicHMC.Directions(0x4ba20a24)), DynamicHMC.TreeStatisticsNUTS(42.939719123738044, 2, turning at positions -2:1, 0.5284422334510198, 3, DynamicHMC.Directions(0x4ff66335)), DynamicHMC.TreeStatisticsNUTS(42.560055114349495, 2, turning at positions 0:3, 0.9205130571022625, 3, DynamicHMC.Directions(0xcb964c33))], logdensities = [44.050868296474086, 43.15670534355634, 44.88219830365673, 45.513145095718926, 45.12809977834142, 45.26248307029741, 43.79555831914792, 44.69207274010756, 44.02808611789597, 44.960488573217376  …  43.22151145558013, 41.350752625876325, 41.350752625876325, 43.78065327772511, 44.37917947060558, 44.48255186917711, 45.462778469922426, 45.51239345417329, 43.548012305380816, 44.4619535634818], κ = Gaussian kinetic energy (Diagonal), √diag(M⁻¹): [0.000202228744871263, 0.28074854739102056, 0.2756410835521548], ϵ = 0.7582339735838659)
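
Unlike the other two backends, the result printed above is a named tuple rather than a chain, so summaries have to be computed by hand. The sketch below is one way to do that, using only field names visible in the output (posterior, parameters, σ, posterior_matrix). To keep it self-contained, result is a hypothetical stand-in that holds the first three draws and the first column of posterior_matrix, copied from the output; with the real object, bayesian_result_hmc would be used instead. Comparing the numbers also suggests that posterior_matrix stores one draw per column on a log scale, although that reading is inferred from the printed values rather than stated in the text.

```julia
using Statistics

# Hypothetical stand-in for `bayesian_result_hmc`, copied from the printed output.
result = (
    posterior = [
        (parameters = [1.4997543058535034], σ = [0.018624318374934237, 0.013149952233093079]),
        (parameters = [1.4997856831953489], σ = [0.023377631356251455, 0.010790427137469935]),
        (parameters = [1.5005949585418168], σ = [0.009769845829391895, 0.015622805596559382]),
    ],
    # Only the first column of the printed matrix is reproduced here.
    posterior_matrix = reshape([0.4053012985944542, -3.983287112705179, -4.331337152829366], 3, 1),
)

# Draws of the single ODE parameter `a`, with mean and standard deviation.
a_draws = [draw.parameters[1] for draw in result.posterior]
a_mean = mean(a_draws)
a_std = std(a_draws)

# Stack the σ vectors into a 2×N matrix and average each noise scale over draws.
σ_draws = reduce(hcat, [draw.σ for draw in result.posterior])
σ_mean = vec(mean(σ_draws; dims = 2))

# Exponentiating a column of `posterior_matrix` should recover the matching draw
# if the matrix is indeed stored on a log scale (an assumption checked here).
first_draw = vcat(result.posterior[1].parameters, result.posterior[1].σ)
matches_log_scale = isapprox(exp.(result.posterior_matrix[:, 1]), first_draw; rtol = 1e-4)
```

With the full result, the same comprehension over bayesian_result_hmc.posterior gives the draws for all samples.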

More Information

For a closer look at the summary statistics and at plotting the results, see the benchmarks.