using DiffEqBayes, OrdinaryDiffEq, RecursiveArrayTools, Distributions, Plots, StatsPlots, BenchmarkTools, TransformVariables, CmdStan, DynamicHMC
Let's define our simple pendulum problem. Here our pendulum has a drag term ω and a length L.
We get first-order equations by defining the first component as the position (the angle) and the second component as the angular velocity, getting:
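# In-place right-hand side: u = [x, y] with x the angle and y the angular velocity
function pendulum(du,u,p,t)
    ω,L = p                           # drag coefficient and pendulum length
    x,y = u
    du[1] = y                         # dx/dt = y
    du[2] = - ω*y -(9.8/L)*sin(x)     # dy/dt = drag term + gravity term
end

u0 = [1.0,0.1]                        # initial angle and angular velocity
tspan = (0.0,10.0)
prob1 = ODEProblem(pendulum,u0,tspan,[1.0,2.5])   # true parameters: ω = 1.0, L = 2.5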
ODEProblem with uType Array{Float64,1} and tType Float64. In-place: true
timespan: (0.0, 10.0)
u0: [1.0, 0.1]
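For reference, the function above can be read as the following first-order system, with x the angle and y the angular velocity. The symbol g is introduced here only for illustration and stands for the constant 9.8 that appears in the code; the initial values are those of `u0`.

```latex
\begin{aligned}
\frac{dx}{dt} &= y, \\
\frac{dy}{dt} &= -\omega\, y - \frac{g}{L}\,\sin(x), \qquad g = 9.8,
\end{aligned}
\qquad\qquad
x(0) = 1.0, \quad y(0) = 0.1 .
```

Eliminating y gives the equivalent second-order form $\ddot{x} + \omega\,\dot{x} + (g/L)\sin(x) = 0$, which makes the role of ω as the coefficient of the velocity-dependent drag term explicit.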
To understand the model and generate data, let's solve and visualize the solution with the known parameters:
sol = solve(prob1,Tsit5())
plot(sol)
It's a pendulum, so you know what it looks like. It oscillates, but since we have not made a small-angle assumption, the solution is not exactly sin or cos. Because the true damping parameter ω is 1, the drag term -ω*y removes energy and the oscillation decays over time. The length L determines the period.
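As a rough check on how ω and L shape the solution, the model can be linearized. This is only a sketch under the small-angle assumption sin(x) ≈ x, which the model itself does not make (the initial angle of 1.0 is not small), so the numbers are approximate. The symbols g = 9.8, Ω, A, φ, and T are introduced here for illustration.

```latex
\ddot{x} + \omega\,\dot{x} + \frac{g}{L}\,x = 0
\quad\Longrightarrow\quad
x(t) \approx A\,e^{-\omega t/2}\cos\!\left(\Omega t + \varphi\right),
\qquad
\Omega = \sqrt{\frac{g}{L} - \frac{\omega^{2}}{4}} .

% With g = 9.8, L = 2.5, \omega = 1:
\Omega = \sqrt{3.92 - 0.25} = \sqrt{3.67} \approx 1.92,
\qquad
T = \frac{2\pi}{\Omega} \approx 3.28 .
```

In this approximation the drag term -ω*y contributes the envelope $e^{-\omega t/2}$, which for ω = 1 shrinks to about $e^{-5} \approx 0.007$ by t = 10, while L enters only through the oscillation frequency Ω.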
We now generate some dummy data to use for estimation:
t = collect(range(1,stop=10,length=10))
randomized = VectorOfArray([(sol(t[i]) + .01randn(2)) for i in 1:length(t)])
data = convert(Array,randomized)
2×10 Array{Float64,2}:
  0.0569058  -0.375845  0.136638   0.0720362  …  0.00859386  -0.000752689
 -1.20656     0.334965  0.293891  -0.255306      0.0281103   -0.00753411
Let's see what our data looks like on top of the real solution:
scatter!(data')
This data captures both the damping and the true period, making it well suited for attempting Bayesian inference.
Now let's fit the pendulum to the data. Since we know our model is correct, this should give us back the parameters that we used to generate the data! Define priors on our parameters. In this case, let's assume we don't have much information, but have a prior belief that ω is between 0.1 and 3.0, while the length of the pendulum L is probably around 3.0:
priors = [Uniform(0.1,3.0), Normal(3.0,1.0)]
2-element Array{Distributions.Distribution{Distributions.Univariate,Distributions.Continuous},1}:
 Distributions.Uniform{Float64}(a=0.1, b=3.0)
 Distributions.Normal{Float64}(μ=3.0, σ=1.0)
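To see how these priors sit relative to the parameters that generated the data, a small check like the following may help. It is a minimal sketch that assumes only Distributions.jl; the names `truth`, `prior_means`, `prior_logdens`, and `truth_in_range` are made up for illustration and are not part of DiffEqBayes.jl.

```julia
using Distributions

# Priors from the text: ω ~ Uniform(0.1, 3.0), L ~ Normal(3.0, 1.0)
priors = [Uniform(0.1, 3.0), Normal(3.0, 1.0)]
truth  = [1.0, 2.5]   # parameters used to generate the data (ω, L)

prior_means    = mean.(priors)              # centre of each prior
prior_logdens  = logpdf.(priors, truth)     # log prior density at the true values
truth_in_range = insupport.(priors, truth)  # is each true value allowed by its prior?

println("prior means:            ", prior_means)
println("log density at truth:   ", prior_logdens)
println("truth inside support:   ", truth_in_range)
```

Neither prior is centred on the true value, but both assign it non-zero density, so the data can still pull the posterior toward ω = 1.0 and L = 2.5.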
Finally, let's run the estimation routine from DiffEqBayes.jl with the Turing.jl backend to check whether we indeed recover the parameters!
bayesian_result = turing_inference(prob1,Tsit5(),t,data,priors;num_samples=10_000, syms = [:omega,:L])
Chains MCMC chain (9000×15×1 Array{Float64,3}):

Iterations        = 1:9000
Thinning interval = 1
Chains            = 1
Samples per chain = 9000
parameters        = L, omega, σ[1]
internals         = acceptance_rate, hamiltonian_energy, hamiltonian_energy_error, is_accept, log_density, lp, max_hamiltonian_energy_error, n_steps, nom_step_size, numerical_error, step_size, tree_depth

Summary Statistics
  parameters     mean      std   naive_se     mcse         ess     rhat
      Symbol  Float64  Float64    Float64  Float64     Float64  Float64

           L   2.4780   0.2113     0.0022   0.0032   3799.5678   1.0001
       omega   1.0897   0.2204     0.0023   0.0042   2820.5866   0.9999
        σ[1]   0.1596   0.0382     0.0004   0.0006   4229.9547   1.0000

Quantiles
  parameters     2.5%    25.0%    50.0%    75.0%    97.5%
      Symbol  Float64  Float64  Float64  Float64  Float64

           L   2.0473   2.3485   2.4795   2.6056   2.9082
       omega   0.7743   0.9501   1.0518   1.1849   1.6446
        σ[1]   0.1006   0.1332   0.1537   0.1796   0.2518
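The mean, std, naive_se, and quantile columns of a summary like this are plain sample statistics of the draws. The sketch below illustrates that with the Julia standard library only. The vector `draws` is a hypothetical stand-in generated from a normal distribution, not the tutorial's actual chain, and this is not a description of how the chain object computes its own summary.

```julia
using Statistics, Random

rng = MersenneTwister(1)
# Hypothetical stand-in for 9000 posterior draws of L; NOT the tutorial's chain.
draws = 2.5 .+ 0.2 .* randn(rng, 9000)

post_mean = mean(draws)
post_std  = std(draws)
naive_se  = post_std / sqrt(length(draws))   # standard error if draws were independent
q         = quantile(draws, [0.025, 0.25, 0.5, 0.75, 0.975])

println((mean = post_mean, std = post_std, naive_se = naive_se))
println("quantiles (2.5%, 25%, 50%, 75%, 97.5%): ", q)
```

The same arithmetic is consistent with the table: for L, 0.2113 / √9000 ≈ 0.0022, which matches the reported naive_se. The mcse column is larger because it accounts for autocorrelation between successive draws.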
Notice that while our priors had the wrong means, the posterior means (L ≈ 2.48, ω ≈ 1.09) are close to the true values (L = 2.5, ω = 1.0), meaning that the sampler learned good posterior distributions for the parameters. To look at these posterior distributions on the parameters, we can examine the chains:
plot(bayesian_result)
As a diagnostic, we will also check the parameter chains. The chain is the record of the MCMC sampling process. It should explore parameter space and converge reasonably well, and we should take many samples after it converges, since it is these samples that form the posterior distribution.
plot(bayesian_result, colordim = :parameter)
Notice that after a while these chains converge to a "fuzzy line", meaning the sampler has found the region of highest likelihood and is sampling around it, which builds a posterior distribution around the true mean.
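The "fuzzy line" check can also be made numerical. The sketch below computes a classical split R-hat (the Gelman-Rubin statistic applied to the two halves of a single chain) using only the standard library. The function name `split_rhat` and the synthetic chains are assumptions made for illustration; the rhat column reported by the chain summary may be computed with a different, more refined estimator.

```julia
using Statistics, Random

# Classical split R-hat for one chain: compare the two halves as if they
# were separate chains. Values near 1 suggest the halves agree.
function split_rhat(chain::AbstractVector{<:Real})
    n = length(chain) ÷ 2                    # samples per half
    halves = [chain[1:n], chain[n+1:2n]]     # drops one sample if length is odd
    W = mean(var.(halves))                   # mean within-half variance
    B = n * var(mean.(halves))               # between-half variance, scaled by n
    var_plus = (n - 1) / n * W + B / n       # pooled variance estimate
    return sqrt(var_plus / W)
end

rng = MersenneTwister(42)
stationary = randn(rng, 4000)                             # hypothetical converged chain
drifting   = stationary .+ range(0, 3, length = 4000)     # hypothetical chain still moving

println("stationary: ", split_rhat(stationary))
println("drifting:   ", split_rhat(drifting))
```

A chain that has settled gives a value very close to 1, while a chain whose level is still drifting gives a clearly larger value, which is the numerical counterpart of a trace that has not yet flattened into a fuzzy line.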
DiffEqBayes.jl allows the choice of Stan.jl, Turing.jl, and DynamicHMC.jl for MCMC; you can also use ApproxBayes.jl for Approximate Bayesian Computation algorithms. Let's compare the timings across the different MCMC backends. We'll stick with the default arguments and 10,000 samples in each. There is a lot of room for micro-optimization specific to each package and algorithm combination, so you might want to run your own experiments on specific problems to get a better understanding of the performance.
@btime bayesian_result = turing_inference(prob1,Tsit5(),t,data,priors;syms = [:omega,:L],num_samples=10_000)
24.699 s (102709003 allocations: 6.73 GiB)
Chains MCMC chain (9000×15×1 Array{Float64,3}):

Iterations        = 1:9000
Thinning interval = 1
Chains            = 1
Samples per chain = 9000
parameters        = L, omega, σ[1]
internals         = acceptance_rate, hamiltonian_energy, hamiltonian_energy_error, is_accept, log_density, lp, max_hamiltonian_energy_error, n_steps, nom_step_size, numerical_error, step_size, tree_depth

Summary Statistics
  parameters     mean      std   naive_se     mcse         ess     rhat
      Symbol  Float64  Float64    Float64  Float64     Float64  Float64

           L   0.0390   0.0004     0.0000   0.0000   2733.6119   0.9999
       omega   2.5574   0.3303     0.0035   0.0076   1793.3140   1.0000
        σ[1]   0.2604   0.0506     0.0005   0.0011   2006.2537   1.0023

Quantiles
  parameters     2.5%    25.0%    50.0%    75.0%    97.5%
      Symbol  Float64  Float64  Float64  Float64  Float64

           L   0.0383   0.0387   0.0390   0.0393   0.0399
       omega   1.8335   2.3299   2.6239   2.8334   2.9846
        σ[1]   0.1841   0.2243   0.2534   0.2885   0.3758
@btime bayesian_result = stan_inference(prob1,t,data,priors;num_samples=10_000,printsummary=false)
ERROR: MethodError: no method matching iterate(::ModelingToolkit.ODESystem)
Closest candidates are:
  iterate(!Matched::Core.SimpleVector) at essentials.jl:603
  iterate(!Matched::Core.SimpleVector, !Matched::Any) at essentials.jl:603
  iterate(!Matched::ExponentialBackOff) at error.jl:253
  ...
@btime bayesian_result = dynamichmc_inference(prob1,Tsit5(),t,data,priors;num_samples = 10_000)
8.536 s (60240867 allocations: 3.23 GiB)
(posterior = NamedTuple{(:parameters, :σ),Tuple{Array{Float64,1},Array{Float64,1}}}[(parameters = [0.9928357991299962, 2.4814752963445783], σ = [0.012042227011386932, 0.02277300206307592]), …],
 chain = [[-0.00718998698882292, 0.9088532608656742, -4.419335888482715, -3.782179564686333], …],
 tree_statistics = DynamicHMC.TreeStatisticsNUTS[DynamicHMC.TreeStatisticsNUTS(42.28214409741327, 3, turning at positions -1:6, 0.9813386313433596, 7, DynamicHMC.Directions(0x08188c46)), …],
 κ = Gaussian kinetic energy (Diagonal), √diag(M⁻¹): [0.02368980581117611, 0.020794666330459215, 0.2553796674078944, 0.2469673379826888],
 ϵ = 0.23467322955225975)