Bayesian Inference of ODE
In this tutorial, we show how to use Bayesian inference to estimate the parameters of the Lotka-Volterra equations with each of the three backends:
- Turing.jl
- Stan.jl
- DynamicHMC.jl
Setup
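First, let's set up our ODE and the data. The Lotka-Volterra model below has a single free parameter, a. To create the dataset, we solve the ODE at the known value a = 1.5, sample the solution at t = 1, 2, …, 10, and add Gaussian noise with standard deviation σ = 0.01 to each observation. This looks like the following: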
using DiffEqBayes, ParameterizedFunctions, OrdinaryDiffEq, RecursiveArrayTools,
Distributions
f1 = @ode_def LotkaVolterra begin
    dx = a * x - x * y
    dy = -3 * y + x * y
end a
p = [1.5]
u0 = [1.0, 1.0]
tspan = (0.0, 10.0)
prob1 = ODEProblem(f1, u0, tspan, p)
σ = 0.01 # noise, fixed for now
t = collect(1.0:10.0) # observation times
sol = solve(prob1, Tsit5())
priors = [Normal(1.5, 1)]
randomized = VectorOfArray([(sol(t[i]) + σ * randn(2)) for i in 1:length(t)])
data = convert(Array, randomized)
2×10 Matrix{Float64}:
 2.7793    6.76565  0.977344  1.89394  …  4.36172   3.25082  1.02621
 0.241136  2.0184   1.90375   0.31899     0.316162  4.5356   0.899286
Inference Methods
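All three backends target a posterior built the same way: the prior on a combined with a Gaussian likelihood that compares the data with the ODE solution at the observation times (each backend also puts a prior on the noise scale and samples it). As a rough, dependency-free sketch of that target, the Julia code below writes the unnormalized log posterior out by hand. It assumes a hand-written fixed-step RK4 integrator in place of Tsit5, a noise level held fixed at σ = 0.01, and noise-free synthetic data; `lv`, `rk4_observe`, `normal_logpdf`, and `log_posterior` are hypothetical helpers for this illustration, not DiffEqBayes functions.

```julia
# Lotka-Volterra right-hand side with the single free parameter a
# (same equations as the @ode_def model above)
lv(u, a) = [a * u[1] - u[1] * u[2], -3 * u[2] + u[1] * u[2]]

# Fixed-step classic RK4 starting at t = 0, recording the state at each time in ts.
# Assumption: ts is increasing and every gap is a whole number of steps h.
function rk4_observe(a, u0, ts; h = 1e-3)
    u = float.(u0)
    t = 0.0
    out = Matrix{Float64}(undef, length(u0), length(ts))
    for (j, tj) in enumerate(ts)
        for _ in 1:round(Int, (tj - t) / h)
            k1 = lv(u, a)
            k2 = lv(u .+ (h / 2) .* k1, a)
            k3 = lv(u .+ (h / 2) .* k2, a)
            k4 = lv(u .+ h .* k3, a)
            u = u .+ (h / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
        end
        t = tj
        out[:, j] = u
    end
    return out
end

# Log-density of Normal(μ, s) at x, written out to avoid extra packages
normal_logpdf(x, μ, s) = -0.5 * log(2π) - log(s) - 0.5 * ((x - μ) / s)^2

# Unnormalized log posterior: log prior Normal(1.5, 1) on a plus the Gaussian
# log-likelihood of every observation, with σ treated as known (an assumption)
function log_posterior(a, data, ts; u0 = [1.0, 1.0], σ = 0.01)
    pred = rk4_observe(a, u0, ts)
    return normal_logpdf(a, 1.5, 1.0) + sum(normal_logpdf.(data, pred, σ))
end

# Usage: noise-free synthetic data at a = 1.5, then a coarse grid scan over a
ts = collect(1.0:10.0)
data = rk4_observe(1.5, [1.0, 1.0], ts)
grid = 1.40:0.01:1.60
lp = [log_posterior(a, data, ts) for a in grid]
a_map = grid[argmax(lp)]   # grid point with the highest log posterior
```

A grid scan like this is only practical because there is a single parameter and σ is held fixed; once the noise scales are unknown too, exploring the posterior is the job of the MCMC samplers in the sections that follow.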
Stan
using StanSample # required for the Stan backend
bayesian_result_stan = stan_inference(prob1, :rk45, t, data, priors)
Chains MCMC chain (1000×3×1 Array{Float64, 3}):
Iterations = 1:1:1000
Number of chains = 1
Samples per chain = 1000
parameters = sigma1.1, sigma1.2, theta_1
internals =
Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   e ⋯
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64     ⋯
    sigma1.1    0.2662    0.0768    0.0033   625.4736   664.9113    1.0035     ⋯
    sigma1.2    0.2543    0.0756    0.0029   795.9807   684.2835    0.9994     ⋯
     theta_1    1.5012    0.0057    0.0002   776.1390   419.0812    1.0008     ⋯
1 column omitted
Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64
    sigma1.1    0.1588    0.2121    0.2536    0.3092    0.4430
    sigma1.2    0.1522    0.1973    0.2400    0.2932    0.4350
     theta_1    1.4903    1.4976    1.5010    1.5043    1.5131
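The Summary Statistics and Quantiles tables are computed from the stored draws of each parameter. To show roughly what the mean, std, rhat, and quantile columns measure, here is a sketch using only Julia's Statistics standard library. The draws are synthetic rather than taken from the Stan run, `split_rhat` and `summarize` are hypothetical helpers, and `split_rhat` implements the classic split-R̂ rather than whatever variant MCMCChains reports (recent versions may use a rank-normalized estimator), so its values need not match the rhat column exactly. The mcse and ess columns are not reproduced.

```julia
using Random, Statistics

# Classic split-R̂: split one chain in half and compare the variance between the
# two halves with the variance within them. Values near 1 suggest both halves
# explore the same distribution; values well above 1 signal poor mixing.
function split_rhat(draws::AbstractVector{<:Real})
    n = length(draws) ÷ 2
    h1, h2 = draws[1:n], draws[end-n+1:end]   # the middle draw is dropped if odd
    W = (var(h1) + var(h2)) / 2               # mean within-half variance
    B = n * var([mean(h1), mean(h2)])         # between-half variance
    var_plus = (n - 1) / n * W + B / n        # pooled estimate of the posterior variance
    return sqrt(var_plus / W)
end

# Mean, standard deviation, split-R̂, and the same quantiles as the tables above
function summarize(draws)
    probs = [0.025, 0.25, 0.5, 0.75, 0.975]
    return (mean = mean(draws), std = std(draws), rhat = split_rhat(draws),
            quantiles = quantile(draws, probs))
end

Random.seed!(42)
good = 1.5 .+ 0.005 .* randn(1000)                                      # well-mixed synthetic draws
stuck = vcat(fill(1.49, 500), fill(1.51, 500)) .+ 0.001 .* randn(1000)  # halves disagree

summarize(good)
summarize(stuck).rhat
```

A chain whose two halves sit in different places, like `stuck`, gives an R̂ far above 1, which is the warning sign to look for in the rhat column.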
Turing
bayesian_result_turing = turing_inference(prob1, Tsit5(), t, data, priors)
Chains MCMC chain (1000×16×1 Array{Float64, 3}):
Iterations = 501:1:1500
Number of chains = 1
Samples per chain = 1000
Wall duration = 37.19 seconds
Compute duration = 37.19 seconds
parameters = theta[1], σ[1]
internals = n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size, logprior, loglikelihood, logjoint
Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   e ⋯
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64     ⋯
    theta[1]    1.5004    0.0035    0.0001   932.8267   491.2511    1.0123     ⋯
        σ[1]    0.1522    0.0343    0.0017   381.2644   534.3193    1.0007     ⋯
1 column omitted
Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64
    theta[1]    1.4937    1.4980    1.5004    1.5028    1.5072
        σ[1]    0.1023    0.1266    0.1477    0.1709    0.2313
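Judging by the parameter names in the output, this Turing run fits a single noise scale, σ[1], shared by both states, while the Stan run reports one scale per state (sigma1.1 and sigma1.2). The sketch below uses a hypothetical `loglik` helper, not anything from DiffEqBayes or Turing, to show how the two choices differ in a Gaussian log-likelihood: a scalar σ is shared across states, and a vector applies σ[i] to every observation of state i. The data and predictions are made-up numbers.

```julia
# Gaussian log-density, written out to keep the sketch dependency-free
normal_logpdf(x, μ, s) = -0.5 * log(2π) - log(s) - 0.5 * ((x - μ) / s)^2

# Log-likelihood of a (states × times) data matrix given the ODE prediction.
# σ may be a scalar (one scale shared by all states) or a vector with one entry
# per state; a vector broadcasts down the rows, so σ[i] applies to row i.
loglik(data, pred, σ) = sum(normal_logpdf.(data, pred, σ))

data = [2.0 3.0 4.0; 0.5 0.6 0.7]   # made-up 2×3 observations
pred = data .+ 0.01                  # made-up predictions, all off by 0.01

shared    = loglik(data, pred, 0.02)          # one σ for both states
per_state = loglik(data, pred, [0.02, 0.02])  # per-state σ, set equal
mixed     = loglik(data, pred, [0.01, 0.05])  # per-state σ, different values
```

With equal entries the per-state model reduces to the shared one; it only behaves differently when the states are allowed different noise levels.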
DynamicHMC
We can also use DynamicHMC.jl as the sampling backend via the dynamichmc_inference function, which takes the same arguments as turing_inference:
bayesian_result_hmc = dynamichmc_inference(prob1, Tsit5(), t, data, priors)
(posterior = [(parameters = [1.500055442201442], σ = [0.013010775771617914, 0.011312680472758463]), (parameters = [1.5001993307010475], σ = [0.012307011859823374, 0.010166826599515996]), (parameters = [1.5002427136423868], σ = [0.015103357274481791, 0.01655026530508954]), (parameters = [1.5000679498713714], σ = [0.012701593028661088, 0.01826757325330891]), (parameters = [1.4998749388776482], σ = [0.013349451286214198, 0.016998286783379717]), (parameters = [1.500209890171037], σ = [0.01149015736368376, 0.016112682026348628]), (parameters = [1.500625121579079], σ = [0.012837987689732504, 0.011428672752945209]), (parameters = [1.5000211743175609], σ = [0.01354881344377301, 0.013160119373396853]), (parameters = [1.5005973518738505], σ = [0.014961557857606535, 0.014916471828859827]), (parameters = [1.5000798229741328], σ = [0.021414203530910034, 0.012053884804998412]) … (parameters = [1.500368032567741], σ = [0.0156253247251225, 0.012298126747781883]), (parameters = [1.5002343737782387], σ = [0.013606400402872984, 0.015924148815228363]), (parameters = [1.5000363985695198], σ = [0.011950037543321235, 0.012641791790925814]), (parameters = [1.5006875278379839], σ = [0.015544919653251395, 0.012723843069226741]), (parameters = [1.4999817576651662], σ = [0.011491286788563605, 0.013729431156638916]), (parameters = [1.5005366902691148], σ = [0.021584399712293017, 0.016484062699189516]), (parameters = [1.5004522107933607], σ = [0.02223291100264938, 0.018672172118139995]), (parameters = [1.5008109675151315], σ = [0.018803669613752547, 0.02020531731404508]), (parameters = [1.5007509576902984], σ = [0.01595135778057034, 0.01916456978051946]), (parameters = [1.5006129840476632], σ = [0.01847168983963352, 0.017765541161208667])],
 posterior_matrix = [0.4055020688927342 0.405597986413483 … 0.40596562129074015 0.40587368066280527; -4.341977359363977 -4.3975861091275465 … -4.138211326065922 -3.9915159979092865; -4.481831016690333 -4.588625153060679 … -3.9546920285873663 -4.0304945877604075],
 tree_statistics = DynamicHMC.TreeStatisticsNUTS[DynamicHMC.TreeStatisticsNUTS(39.58738912269798, 2, turning at positions 0:3, 0.8143733823240016, 3, DynamicHMC.Directions(0x97ec6e63)), DynamicHMC.TreeStatisticsNUTS(43.19376183770021, 3, turning at positions -4:3, 0.9816283537776268, 7, DynamicHMC.Directions(0x8a08f91b)), DynamicHMC.TreeStatisticsNUTS(43.151851234713945, 2, turning at positions 0:3, 0.9920149943663749, 3, DynamicHMC.Directions(0x35341a77)), DynamicHMC.TreeStatisticsNUTS(42.71586274019171, 3, turning at positions -2:5, 0.9595536408109072, 7, DynamicHMC.Directions(0x4ec124a5)), DynamicHMC.TreeStatisticsNUTS(42.72344135210901, 2, turning at positions 3:6, 0.9756900798872218, 7, DynamicHMC.Directions(0x6e1c7b26)), DynamicHMC.TreeStatisticsNUTS(42.332341870012705, 2, turning at positions -1:2, 0.9282857874451699, 3, DynamicHMC.Directions(0x8491798a)), DynamicHMC.TreeStatisticsNUTS(43.01065390757874, 3, turning at positions -6:1, 0.9508473474842455, 7, DynamicHMC.Directions(0x1dc9b459)), DynamicHMC.TreeStatisticsNUTS(41.373179073198116, 2, turning at positions -3:0, 0.8651044619122832, 3, DynamicHMC.Directions(0x77cedbe8)), DynamicHMC.TreeStatisticsNUTS(43.35770505032152, 3, turning at positions -1:6, 0.9518019408023234, 7, DynamicHMC.Directions(0x324863be)), DynamicHMC.TreeStatisticsNUTS(42.2657560054173, 2, turning at positions 1:4, 0.8731281881256722, 7, DynamicHMC.Directions(0xb905f6cc)) … DynamicHMC.TreeStatisticsNUTS(43.722234556821434, 3, turning at positions -5:2, 0.9815394399099926, 7, DynamicHMC.Directions(0xc0b09bc2)), DynamicHMC.TreeStatisticsNUTS(42.60135157780331, 3, turning at positions -6:1, 0.8706136972385218, 7, DynamicHMC.Directions(0x776a7d09)), DynamicHMC.TreeStatisticsNUTS(43.556844543164274, 2, turning at positions 4:7, 0.9493034950804973, 7, DynamicHMC.Directions(0xa4bcb637)), DynamicHMC.TreeStatisticsNUTS(42.58105147095735, 3, turning at positions -6:1, 0.9456035895388828, 7, DynamicHMC.Directions(0x29bd3501)), DynamicHMC.TreeStatisticsNUTS(42.42360960885033, 3, turning at positions -6:1, 0.9364758205329953, 7, DynamicHMC.Directions(0xe3bf5001)), DynamicHMC.TreeStatisticsNUTS(42.21600992539078, 3, turning at positions -5:2, 0.9150671025729856, 7, DynamicHMC.Directions(0x3b7d964a)), DynamicHMC.TreeStatisticsNUTS(41.88325203019997, 2, turning at positions -2:1, 0.9836368360689506, 3, DynamicHMC.Directions(0xbf953bed)), DynamicHMC.TreeStatisticsNUTS(40.27222678883203, 2, turning at positions -2:-5, 0.9712302112925445, 7, DynamicHMC.Directions(0x675afdb2)), DynamicHMC.TreeStatisticsNUTS(41.60093157471195, 2, turning at positions -1:2, 0.9999999999999999, 3, DynamicHMC.Directions(0x1cc01f0e)), DynamicHMC.TreeStatisticsNUTS(42.08855121418456, 2, turning at positions -1:2, 0.9983140644636697, 3, DynamicHMC.Directions(0xa071e752))],
 logdensities = [43.836073161519224, 43.57558218950759, 44.09517853920816, 43.492172383510855, 43.12264445198452, 43.92855369925367, 43.61415154829532, 44.08690186175364, 43.958819186344485, 42.86441233748267 … 44.35273119504118, 44.287685943474685, 43.85274484738363, 43.74238754959198, 43.504304395173534, 42.67088080712449, 42.10613245066705, 41.805729881356555, 42.551918169822194, 42.954166567603174],
 κ = Gaussian kinetic energy (Diagonal), √diag(M⁻¹): [0.00021594953050623869, 0.2700898894046389, 0.2523398098452484],
 ϵ = 0.672971270618368)
More Information
For a better idea of the summary statistics and plotting, you can take a look at the benchmarks.
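Unlike the Stan and Turing backends, dynamichmc_inference returns a plain NamedTuple (with fields posterior, posterior_matrix, tree_statistics, logdensities, κ, and ϵ) rather than a Chains object, so its draws can be summarized straight from those fields. As a sketch, the snippet below rebuilds the first two printed draws as a literal stand-in for bayesian_result_hmc and summarizes them with the Statistics standard library. It also checks something visible in the printed numbers: each column of posterior_matrix matches the logarithm of that draw's parameters and σ, which suggests the sampler works on a log (unconstrained) scale. Treat that as an observation about this output, not documented behavior.

```julia
using Statistics

# Literal stand-in for `bayesian_result_hmc`, holding only the first two draws
# printed above (the real result has many more draws and extra fields).
result = (
    posterior = [
        (parameters = [1.500055442201442], σ = [0.013010775771617914, 0.011312680472758463]),
        (parameters = [1.5001993307010475], σ = [0.012307011859823374, 0.010166826599515996]),
    ],
    posterior_matrix = [0.4055020688927342 0.405597986413483;
                        -4.341977359363977 -4.3975861091275465;
                        -4.481831016690333 -4.588625153060679],
)

# Gather each quantity across draws
a_draws = [d.parameters[1] for d in result.posterior]    # draws of a
σ_draws = reduce(hcat, [d.σ for d in result.posterior])  # 2 × ndraws, one row per state

a_mean = mean(a_draws)
a_interval = quantile(a_draws, [0.025, 0.975])

# Exponentiating posterior_matrix recovers [a; σ₁; σ₂] for each draw (column)
constrained = exp.(result.posterior_matrix)
```

On the full result, the same comprehensions work unchanged; only the number of draws grows.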