Bayesian Inference of ODEs
For this tutorial, we will show how to perform Bayesian inference of the parameters of the Lotka-Volterra equations with each of the three backends:
- Turing.jl
- Stan.jl
- DynamicHMC.jl
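For reference, here is the system defined in the Setup code. x is the prey and y is the predator, and only the coefficient a is treated as unknown (the other coefficients are fixed at 1, 3 and 1). The block also gives the prior placed on a and shows how the synthetic observations d_i are generated:

$$
\frac{dx}{dt} = a\,x - x\,y, \qquad \frac{dy}{dt} = -3\,y + x\,y, \qquad a \sim \mathcal{N}(1.5,\, 1)
$$

$$
d_i = u(t_i;\, a = 1.5) + \sigma\,\varepsilon_i, \qquad \varepsilon_i \sim \mathcal{N}(0, I_2), \qquad t_i = 1, 2, \dots, 10, \qquad \sigma = 0.01
$$

where $u = (x, y)$ and $u(0) = (1, 1)$.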
Setup
First, let's set up our ODE and the data. For the data, we solve the ODE at a known parameter value (a = 1.5), sample the solution at the observation times, and add a small amount of Gaussian noise; this noisy sample serves as the dataset. This looks like the following:
using DiffEqBayes, ParameterizedFunctions, OrdinaryDiffEq, RecursiveArrayTools,
      Distributions
f1 = @ode_def LotkaVolterra begin
    dx = a * x - x * y
    dy = -3 * y + x * y
end a
p = [1.5]
u0 = [1.0, 1.0]
tspan = (0.0, 10.0)
prob1 = ODEProblem(f1, u0, tspan, p)
σ = 0.01 # noise, fixed for now
t = collect(1.0:10.0) # observation times
sol = solve(prob1, Tsit5())
priors = [Normal(1.5, 1)]
randomized = VectorOfArray([(sol(t[i]) + σ * randn(2)) for i in 1:length(t)])
data = convert(Array, randomized)
2×10 Matrix{Float64}:
2.78332 6.76884 0.934643 1.86953 … 4.33951 3.2407 1.02192
0.277205 2.02905 1.91783 0.310527 … 0.322532 4.54504 0.904309
Inference Methods
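All three backends sample from the same kind of target: the prior on `a` combined with a Gaussian likelihood for the noisy observations. The backends also sample the noise scales, which appear as `sigma` and `σ` in the results below. The following Julia code is a minimal, self-contained sketch of that target under simplifying assumptions and is not DiffEqBayes' own implementation. `rk4_at`, `normlogpdf` and `log_posterior` are hypothetical helpers, a fixed-step RK4 integrator stands in for `Tsit5()`, and σ is held fixed at 0.01 instead of being inferred.

```julia
using Random

# Lotka-Volterra right-hand side matching the Setup code: only `a` is free,
# the other coefficients are fixed at 1, 3 and 1.
lv(u, a) = [a * u[1] - u[1] * u[2], -3 * u[2] + u[1] * u[2]]

# Hypothetical helper: fixed-step classical RK4 (a stand-in for Tsit5),
# returning the state at each requested time in `ts` as a 2×length(ts) matrix.
function rk4_at(a, u0, ts; h = 1e-3)
    u = float.(copy(u0))
    t = 0.0
    out = Matrix{Float64}(undef, 2, length(ts))
    for (j, tj) in enumerate(ts)
        while t < tj - 1e-12
            dt = min(h, tj - t)           # shorten the last step to land on tj
            k1 = lv(u, a)
            k2 = lv(u .+ (dt / 2) .* k1, a)
            k3 = lv(u .+ (dt / 2) .* k2, a)
            k4 = lv(u .+ dt .* k3, a)
            u = u .+ (dt / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
            t += dt
        end
        out[:, j] = u
    end
    return out
end

# Gaussian log-density written out, so Distributions.jl is not needed here.
normlogpdf(x, μ, σ) = -0.5 * ((x - μ) / σ)^2 - log(σ) - 0.5 * log(2π)

# Hypothetical helper: unnormalized log-posterior of `a` with σ held fixed.
function log_posterior(a, data, ts, u0, σ)
    pred = rk4_at(a, u0, ts)
    lp = normlogpdf(a, 1.5, 1.0)          # prior: a ~ Normal(1.5, 1)
    for i in eachindex(data)              # independent Gaussian errors per entry
        lp += normlogpdf(data[i], pred[i], σ)
    end
    return lp
end

# Synthetic data generated the same way as in the Setup section.
Random.seed!(1)
ts = collect(1.0:10.0)
u0 = [1.0, 1.0]
σ = 0.01
data = rk4_at(1.5, u0, ts) .+ σ .* randn(2, length(ts))

# The true value a = 1.5 should score far higher than a nearby wrong value.
println(log_posterior(1.5, data, ts, u0, σ) > log_posterior(1.4, data, ts, u0, σ))
```

Evaluating `log_posterior` on a grid of `a` values should show a sharp peak near 1.5. This is consistent with the narrow posteriors the samplers report below.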
Stan
using StanSample # required for using the Stan backend
bayesian_result_stan = stan_inference(prob1, :rk45, t, data, priors)
1000×3 DataFrame
Row │ sigma1.1 sigma1.2 theta_1
│ Float64 Float64 Float64
──────┼─────────────────────────────
1 │ 0.393118 0.315827 1.50132
2 │ 0.282397 0.305268 1.49512
3 │ 0.150378 0.316076 1.5015
4 │ 0.244083 0.206386 1.50834
5 │ 0.205996 0.20982 1.50453
6 │ 0.214056 0.265647 1.49495
7 │ 0.277271 0.2535 1.50538
8 │ 0.327741 0.177529 1.49507
⋮ │ ⋮ ⋮ ⋮
994 │ 0.423166 0.19338 1.488
995 │ 0.290926 0.194671 1.49034
996 │ 0.309321 0.175777 1.50054
997 │ 0.274246 0.31086 1.50191
998 │ 0.311756 0.286311 1.49756
999 │ 0.216036 0.223427 1.50319
1000 │ 0.287365 0.241683 1.5071
985 rows omitted
Turing
bayesian_result_turing = turing_inference(prob1, Tsit5(), t, data, priors)
╭─FlexiChain (1000 iterations, 1 chain) ───────────────────────────────────────╮
│ ↓ iter = 501:1500 │
│ → chain = 1:1 │
│ │
│ Parameters (2) ── AbstractPPL.VarName │
│ Float64 theta[1] │
│ Vector{Float64} σ │
│ │
│ Extras (14) │
│ Int64 n_steps, tree_depth │
│ Bool is_accept, numerical_error │
│ Float64 acceptance_rate, log_density, hamiltonian_energy, │
│ hamiltonian_energy_error, max_hamiltonian_energy_error, step_size, │
│ nom_step_size, logprior, loglikelihood, logjoint │
╰──────────────────────────────────────────────────────────────────────────────╯
DynamicHMC
We can use DynamicHMC.jl as the sampling backend with the dynamichmc_inference function. It is called the same way:
bayesian_result_hmc = dynamichmc_inference(prob1, Tsit5(), t, data, priors)
(posterior = [(parameters = [1.5000213177037522], σ = [0.023253751383379234, 0.013997387652595767]), (parameters = [1.5000955830002216], σ = [0.01738258884846736, 0.013610722787935523]), (parameters = [1.5006999107506833], σ = [0.01649713206591352, 0.009288438732174746]), (parameters = [1.5004764215185882], σ = [0.016302986391185468, 0.012800654919947644]), (parameters = [1.50033078737305], σ = [0.012446615831101331, 0.01043013102234103]), (parameters = [1.5007243281974187], σ = [0.024212100195889895, 0.013592458184098889]), (parameters = [1.500242277873348], σ = [0.010840909610318807, 0.013551173337686432]), (parameters = [1.5004041617271529], σ = [0.008861753265797117, 0.013251647991876583]), (parameters = [1.5005528319156425], σ = [0.011873491350602054, 0.0129286511244259]), (parameters = [1.5005987501834916], σ = [0.01618706305402008, 0.014235478722677258]) … (parameters = [1.5000744927438012], σ = [0.01436852759529074, 0.017122152739777137]), (parameters = [1.5005800050724218], σ = [0.017026563675849815, 0.012083428876292702]), (parameters = [1.5006878752538255], σ = [0.01580307414722355, 0.014706230143475454]), (parameters = [1.5006893166950168], σ = [0.03233669894510641, 0.01345736826202947]), (parameters = [1.5001428696211723], σ = [0.017140684783862006, 0.009465518672288994]), (parameters = [1.5006203307692596], σ = [0.01815544959877971, 0.018154500525228448]), (parameters = [1.5002087285947125], σ = [0.01638939877226835, 0.009559551944444277]), (parameters = [1.5006710273906985], σ = [0.01515888080217509, 0.014300168691643942]), (parameters = [1.4998132795289638], σ = [0.01745551414418913, 0.012334380933703119]), (parameters = [1.500048525239224], σ = [0.026164243194566412, 0.010165383999595542])], posterior_matrix = [0.4054793198096791 0.40552882807815166 … 0.405340620045823 0.40549745774439183; -3.7612888101414916 -4.052286214629658 … -4.0480996834485055 -3.6433615639192443;
-4.268884563021283 -4.296897356606521 … -4.395364718042745 -4.588767055970433], tree_statistics = DynamicHMC.TreeStatisticsNUTS[DynamicHMC.TreeStatisticsNUTS(41.06793543511067, 2, turning at positions 4:7, 0.9745665064252346, 7, DynamicHMC.Directions(0x48d3d6af)), DynamicHMC.TreeStatisticsNUTS(41.78079621883924, 2, turning at positions -2:1, 0.9858309819429635, 3, DynamicHMC.Directions(0x66d5f0d9)), DynamicHMC.TreeStatisticsNUTS(40.29413934664708, 2, turning at positions -1:-4, 0.7654774601277642, 7, DynamicHMC.Directions(0xd7adb643)), DynamicHMC.TreeStatisticsNUTS(42.84236886859498, 3, turning at positions -3:4, 0.9856923484368415, 7, DynamicHMC.Directions(0x1aba21dc)), DynamicHMC.TreeStatisticsNUTS(42.09262425572443, 2, turning at positions -1:2, 0.837852608032969, 3, DynamicHMC.Directions(0x8bfb3812)), DynamicHMC.TreeStatisticsNUTS(42.01408276211967, 3, turning at positions -5:2, 0.9478865267732374, 7, DynamicHMC.Directions(0xacc86a62)), DynamicHMC.TreeStatisticsNUTS(41.63256893938838, 2, turning at positions 0:3, 0.9186055699560702, 3, DynamicHMC.Directions(0xbbe7d643)), DynamicHMC.TreeStatisticsNUTS(38.443301679986234, 1, turning at positions -1:0, 0.3268793018370503, 1, DynamicHMC.Directions(0xc4932a8e)), DynamicHMC.TreeStatisticsNUTS(40.56243936106751, 3, turning at positions 0:7, 0.9999999999999999, 7, DynamicHMC.Directions(0x6d881c37)), DynamicHMC.TreeStatisticsNUTS(42.49200560233446, 2, turning at positions -1:2, 0.9717473526942718, 3, DynamicHMC.Directions(0x9440bade)) … DynamicHMC.TreeStatisticsNUTS(41.523584968538756, 1, turning at positions -1:0, 1.0, 1, DynamicHMC.Directions(0xd09c594e)), DynamicHMC.TreeStatisticsNUTS(42.294523364484014, 2, turning at positions -1:2, 0.9999999999999999, 3, DynamicHMC.Directions(0xc4e45bce)), DynamicHMC.TreeStatisticsNUTS(43.44401369811263, 3, turning at positions -1:6, 0.9694938368794516, 7, DynamicHMC.Directions(0x3c63c46e)), DynamicHMC.TreeStatisticsNUTS(40.52798239896049, 2, turning at positions 0:3, 
0.7070469510452084, 3, DynamicHMC.Directions(0x0670932b)), DynamicHMC.TreeStatisticsNUTS(39.935437912000936, 2, turning at positions 4:7, 0.992376351908459, 7, DynamicHMC.Directions(0x46d110b7)), DynamicHMC.TreeStatisticsNUTS(39.19291899292402, 3, turning at positions -2:5, 0.7505715124342273, 7, DynamicHMC.Directions(0x2ab767a5)), DynamicHMC.TreeStatisticsNUTS(42.09207444581346, 3, turning at positions 0:7, 0.990251179631354, 7, DynamicHMC.Directions(0x634b2447)), DynamicHMC.TreeStatisticsNUTS(42.924492689144216, 3, turning at positions -4:3, 0.9999999999999999, 7, DynamicHMC.Directions(0xe1f5429b)), DynamicHMC.TreeStatisticsNUTS(41.158264749209636, 3, turning at positions 0:7, 0.8085464271901267, 7, DynamicHMC.Directions(0xc120279f)), DynamicHMC.TreeStatisticsNUTS(40.249351323611094, 2, turning at positions 4:7, 0.9505342679527207, 7, DynamicHMC.Directions(0xf61f7f17))], logdensities = [42.27246089440389, 43.282792404779975, 43.17864267980074, 44.06707148006278, 43.07846192356003, 42.642177051537054, 41.79843091693122, 39.265162066659215, 43.081525676864544, 43.820141934673664 … 42.30725431523882, 44.03543952380514, 43.60036380054761, 40.979877059766885, 42.73916128176054, 42.71837594204497, 43.05114130391705, 43.67401777334369, 42.052188859173796, 41.431326897022316], κ = Gaussian kinetic energy (Diagonal), √diag(M⁻¹): [0.00023816733174924365, 0.2811056681686239, 0.21815742051855916], ϵ = 0.6947482357803328)
More Information
For summary statistics and plotting of these results, you can take a look at the benchmarks.
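Here is a minimal sketch of such summaries using only the Statistics standard library. It is not the benchmarks' own code. It assumes, as the printed DynamicHMC numbers suggest, that the rows of `posterior_matrix` hold the draws of log(a), log(σ₁) and log(σ₂). For example, exponentiating the first column (0.4054793198096791, -3.7612888101414916, -4.268884563021283) reproduces the first draw's `parameters = [1.5000213177037522]` and `σ = [0.023253751383379234, 0.013997387652595767]`. The two-column matrix below is copied from the first and last printed columns. With the real result, you would use the full `bayesian_result_hmc.posterior_matrix`, which is a field of the returned named tuple. Two draws only demonstrate the mechanics; they do not give meaningful statistics.

```julia
using Statistics

# First and last columns of the printed `posterior_matrix` above.
# Assumption: rows are log(a), log(σ₁), log(σ₂); columns are draws.
# With the real result, use `bayesian_result_hmc.posterior_matrix` instead.
logdraws = [ 0.4054793198096791   0.40549745774439183;
            -3.7612888101414916  -3.6433615639192443;
            -4.268884563021283   -4.588767055970433]

# Map every draw back to the positive scale, assuming a log transform.
draws = exp.(logdraws)

# Per-parameter posterior mean, standard deviation and central 90% interval.
labels = ["a", "σ₁", "σ₂"]
for (k, label) in enumerate(labels)
    row = draws[k, :]
    lo, hi = quantile(row, [0.05, 0.95])
    println(rpad(label, 3), " mean=", round(mean(row), digits = 5),
            " sd=", round(std(row), digits = 5),
            " 90% interval=(", round(lo, digits = 5), ", ", round(hi, digits = 5), ")")
end
```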