Bayesian Inference on a Pendulum using DiffEqBayes.jl

Vaibhav Dixit

Set up simple pendulum problem

using DiffEqBayes, OrdinaryDiffEq, RecursiveArrayTools, Distributions, Plots, StatsPlots, BenchmarkTools, TransformVariables, CmdStan, DynamicHMC

Let's define our simple pendulum problem. Here our pendulum has a drag term ω and a length L.

[Figure: diagram of a simple pendulum]

Starting from the damped pendulum equation x'' = -ω*x' - (9.8/L)*sin(x), we get first-order equations by taking the first state to be the angle x and the second to be the angular velocity y, giving:

function pendulum(du,u,p,t)
    ω,L = p
    x,y = u
    du[1] = y                          # dx/dt: angular velocity
    du[2] = -ω*y - (9.8/L)*sin(x)      # dy/dt: drag plus gravitational restoring force
end

u0 = [1.0,0.1]
tspan = (0.0,10.0)
prob1 = ODEProblem(pendulum,u0,tspan,[1.0,2.5])
ODEProblem with uType Array{Float64,1} and tType Float64. In-place: true
timespan: (0.0, 10.0)
u0: [1.0, 0.1]

Solve the model and plot

To understand the model and generate data, let's solve and visualize the solution with the known parameters:

sol = solve(prob1,Tsit5())
plot(sol)

It's the pendulum, so you know what it looks like. It's periodic, but since we have not made a small-angle assumption it's not exactly a sine or cosine. Because the drag coefficient ω is positive, the oscillation damps over time, and the length L determines the period.
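To see what dropping the small-angle assumption changes, we can compare against the linearized pendulum. This is a minimal sketch, not part of the original tutorial: linear_pendulum, lin_prob and lin_sol are names introduced here, and the plot uses the vars keyword of the DifferentialEquations solution plot recipe to show the angle over time.

function linear_pendulum(du,u,p,t)
    ω,L = p
    x,y = u
    du[1] = y
    du[2] = -ω*y - (9.8/L)*x   # small-angle approximation: sin(x) ≈ x
end

lin_prob = ODEProblem(linear_pendulum,u0,tspan,[1.0,2.5])
lin_sol  = solve(lin_prob,Tsit5())
plot(sol, vars=(0,1), label="full pendulum")
plot!(lin_sol, vars=(0,1), label="small-angle approximation")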

Create some dummy data to use for estimation

We now generate some dummy data to use for estimation:

t = collect(range(1,stop=10,length=10))
randomized = VectorOfArray([(sol(t[i]) + .01randn(2)) for i in 1:length(t)])
data = convert(Array,randomized)
2×10 Array{Float64,2}:
  0.0569058  -0.375845  0.136638   0.0720362  …  0.00859386  -0.000752689
 -1.20656     0.334965  0.293891  -0.255306      0.0281103   -0.00753411

Let's see what our data looks like on top of the real solution:

scatter!(data')

This data captures the damping and the true period, making it well suited for attempting Bayesian inference.

Perform Bayesian Estimation

Now let's fit the pendulum to the data. Since we know our model is correct, this should give us back the parameters that we used to generate the data! First, we define priors on our parameters. In this case, let's assume we don't have much information, but we have a prior belief that ω is between 0.1 and 3.0, while the length of the pendulum L is probably around 3.0:

priors = [Uniform(0.1,3.0), Normal(3.0,1.0)]
2-element Array{Distributions.Distribution{Distributions.Univariate,Distributions.Continuous},1}:
 Distributions.Uniform{Float64}(a=0.1, b=3.0)
 Distributions.Normal{Float64}(μ=3.0, σ=1.0)

Finally, let's run the estimation routine from DiffEqBayes.jl with the Turing.jl backend to check if we indeed recover the parameters!

bayesian_result = turing_inference(prob1,Tsit5(),t,data,priors;num_samples=10_000,
                                   syms = [:omega,:L])
Chains MCMC chain (9000×15×1 Array{Float64,3}):

Iterations        = 1:9000
Thinning interval = 1
Chains            = 1
Samples per chain = 9000
parameters        = L, omega, σ[1]
internals         = acceptance_rate, hamiltonian_energy, hamiltonian_energy_error, is_accept, log_density, lp, max_hamiltonian_energy_error, n_steps, nom_step_size, numerical_error, step_size, tree_depth

Summary Statistics
  parameters      mean       std   naive_se      mcse         ess      rhat
      Symbol   Float64   Float64    Float64   Float64     Float64   Float64

           L    2.4780    0.2113     0.0022    0.0032   3799.5678    1.0001
       omega    1.0897    0.2204     0.0023    0.0042   2820.5866    0.9999
        σ[1]    0.1596    0.0382     0.0004    0.0006   4229.9547    1.0000

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%  
      Symbol   Float64   Float64   Float64   Float64   Float64  
                                                                
           L    2.0473    2.3485    2.4795    2.6056    2.9082  
       omega    0.7743    0.9501    1.0518    1.1849    1.6446  
        σ[1]    0.1006    0.1332    0.1537    0.1796    0.2518

Notice that while our priors had the wrong means, the posterior means converged to the true values, meaning the inference learned good posterior distributions for the parameters. To look at these posterior distributions, we can examine the chains:

plot(bayesian_result)
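If you'd rather read numbers than plots, the same information can be pulled from the chain object directly. This is a minimal sketch, assuming the returned object is an MCMCChains.Chains (as the printout above suggests) and that MCMCChains is available to load:

using MCMCChains
describe(bayesian_result)   # summary statistics and quantiles, as printed above
mean(bayesian_result)       # per-parameter posterior means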

As a diagnostic, we will also check the parameter chains. The chain is the MCMC sampling process. The chain should explore parameter space and converge reasonably well, and we should be taking a lot of samples after it converges (it is these samples that form the posterior distribution!).

plot(bayesian_result, colordim = :parameter)

Notice that after a while these chains converge to a "fuzzy line", meaning the sampler found the region of highest likelihood and then samples around it, which builds a posterior distribution around the true mean.
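The visual check can also be backed with numeric diagnostics. A minimal sketch, again assuming the MCMCChains API; these are the same ess and rhat columns shown in the summary table above:

using MCMCChains
# rhat values near 1 and a large ess indicate the chains mixed and converged well
summarystats(bayesian_result)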

DiffEqBayes.jl allows the choice of Stan.jl, Turing.jl, and DynamicHMC.jl for MCMC; you can also use ApproxBayes.jl for Approximate Bayesian Computation algorithms. Let's compare the timings across the different MCMC backends. We'll stick with the default arguments and 10,000 samples for each, since there is a lot of room for micro-optimization specific to each package and algorithm combination; you may want to run your own experiments on specific problems to get a better understanding of the performance.

@btime bayesian_result = turing_inference(prob1,Tsit5(),t,data,priors;syms = [:omega,:L],num_samples=10_000)
24.699 s (102709003 allocations: 6.73 GiB)
Chains MCMC chain (9000×15×1 Array{Float64,3}):

Iterations        = 1:9000
Thinning interval = 1
Chains            = 1
Samples per chain = 9000
parameters        = L, omega, σ[1]
internals         = acceptance_rate, hamiltonian_energy, hamiltonian_energy_error, is_accept, log_density, lp, max_hamiltonian_energy_error, n_steps, nom_step_size, numerical_error, step_size, tree_depth

Summary Statistics
  parameters      mean       std   naive_se      mcse         ess      rhat
      Symbol   Float64   Float64    Float64   Float64     Float64   Float64

           L    0.0390    0.0004     0.0000    0.0000   2733.6119    0.9999
       omega    2.5574    0.3303     0.0035    0.0076   1793.3140    1.0000
        σ[1]    0.2604    0.0506     0.0005    0.0011   2006.2537    1.0023

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%  
      Symbol   Float64   Float64   Float64   Float64   Float64  
                                                                
           L    0.0383    0.0387    0.0390    0.0393    0.0399  
       omega    1.8335    2.3299    2.6239    2.8334    2.9846  
        σ[1]    0.1841    0.2243    0.2534    0.2885    0.3758
@btime bayesian_result = stan_inference(prob1,t,data,priors;num_samples=10_000,printsummary=false)
ERROR: MethodError: no method matching iterate(::ModelingToolkit.ODESystem)
Closest candidates are:
  iterate(!Matched::Core.SimpleVector) at essentials.jl:603
  iterate(!Matched::Core.SimpleVector, !Matched::Any) at essentials.jl:603
  iterate(!Matched::ExponentialBackOff) at error.jl:253
  ...
@btime bayesian_result = dynamichmc_inference(prob1,Tsit5(),t,data,priors;num_samples = 10_000)
8.536 s (60240867 allocations: 3.23 GiB)
(posterior = NamedTuple{(:parameters, :σ),Tuple{Array{Float64,1},Array{Float64,1}}}[(parameters = [0.9928357991299962, 2.4814752963445783], σ = [0.012042227011386932, 0.02277300206307592]), (parameters = [1.0035536037225914, 2.4842109006609734], σ = [0.010217641905223972, 0.025125875466338137]), …], chain = [[-0.00718998698882292, 0.9088532608656742, -4.419335888482715, -3.782179564686333], …], tree_statistics = DynamicHMC.TreeStatisticsNUTS[DynamicHMC.TreeStatisticsNUTS(42.28214409741327, 3, turning at positions -1:6, 0.9813386313433596, 7, DynamicHMC.Directions(0x08188c46)), …], κ = Gaussian kinetic energy (Diagonal), √diag(M⁻¹): [0.02368980581117611, 0.020794666330459215, 0.2553796674078944, 0.2469673379826888], ϵ = 0.23467322955225975)
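Unlike the Turing.jl backend, dynamichmc_inference returns a NamedTuple rather than a Chains object, so summarizing the posterior is done by hand. A minimal sketch, assuming the field layout printed above (posterior, parameters) and using mean from the already-loaded Distributions; dhmc_result and the *_samples names are introduced here:

dhmc_result   = dynamichmc_inference(prob1,Tsit5(),t,data,priors;num_samples = 10_000)
omega_samples = [s.parameters[1] for s in dhmc_result.posterior]
L_samples     = [s.parameters[2] for s in dhmc_result.posterior]
mean(omega_samples), mean(L_samples)   # should land near the true (ω, L) = (1.0, 2.5)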