Bayesian Inference on a Pendulum using DiffEqBayes.jl

Vaibhav Dixit

Set up simple pendulum problem

using DiffEqBayes, OrdinaryDiffEq, RecursiveArrayTools, Distributions, Plots, StatsPlots, BenchmarkTools, TransformVariables, CmdStan, DynamicHMC

Let's define our simple pendulum problem. Here our pendulum has a drag term ω and a length L.

Equation of motion (g = 9.8): x'' + ω x' + (g/L) sin(x) = 0, where x is the angle of the pendulum.

We get a first-order system by taking the first state variable to be the position (the angle x) and the second to be the velocity y = x', giving:

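function pendulum(du,u,p,t)
    ω,L = p                          # drag coefficient ω and length L
    x,y = u                          # x: angle (position), y: angular velocity
    du[1] = y                        # x' = y
    du[2] = - ω*y -(9.8/L)*sin(x)    # y' = -ω y - (g/L) sin(x) with g = 9.8
end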

u0 = [1.0,0.1]                                     # initial angle x and velocity y
tspan = (0.0,10.0)
prob1 = ODEProblem(pendulum,u0,tspan,[1.0,2.5])    # p = [ω, L]
ODEProblem with uType Array{Float64,1} and tType Float64. In-place: true
timespan: (0.0, 10.0)
u0: [1.0, 0.1]
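To make the first-order system concrete, here is a minimal plain-Julia sketch that integrates the same problem (u0 = [1.0, 0.1], tspan = (0.0, 10.0), p = [1.0, 2.5]) with a hand-written fixed-step classical Runge-Kutta (RK4) scheme. This is a swapped-in technique for illustration only: `Tsit5()` is an adaptive Runge-Kutta method, and the names `pendulum_rhs`, `rk4` and `energy` are hypothetical. The scaled energy E = y²/2 + (g/L)(1 - cos x) satisfies dE/dt = -ω y², so any positive drag ω makes it decrease.

```julia
# Out-of-place form of the same pendulum: u = [x, y] (angle, velocity), p = [ω, L]
pendulum_rhs(u, p, t) = [u[2], -p[1] * u[2] - (9.8 / p[2]) * sin(u[1])]

# Classical fixed-step fourth-order Runge-Kutta (RK4); unlike Tsit5 it does not adapt the step
function rk4(f, u0, p, tspan; n = 10_000)
    t0, t1 = tspan
    h = (t1 - t0) / n
    ts = [t0 + i * h for i in 0:n]
    us = Vector{Vector{Float64}}(undef, n + 1)
    us[1] = copy(u0)
    for i in 1:n
        u, t = us[i], ts[i]
        k1 = f(u, p, t)
        k2 = f(u .+ (h / 2) .* k1, p, t + h / 2)
        k3 = f(u .+ (h / 2) .* k2, p, t + h / 2)
        k4 = f(u .+ h .* k3, p, t + h)
        us[i + 1] = u .+ (h / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
    end
    return ts, us
end

ts, us = rk4(pendulum_rhs, [1.0, 0.1], [1.0, 2.5], (0.0, 10.0))

# Scaled mechanical energy; along exact solutions dE/dt = -ω y² ≤ 0
energy(u; L = 2.5, g = 9.8) = 0.5 * u[2]^2 + (g / L) * (1 - cos(u[1]))
```

Comparing `us[end]` with `sol(10.0)` from the solve step is a quick way to check that the two integrators agree.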

Solve the model and plot

To understand the model and generate data, let's solve and visualize the solution with the known parameters:

sol = solve(prob1,Tsit5())
plot(sol)

It's the pendulum, so you know what it looks like. It oscillates, but since we have not made a small-angle assumption it's not exactly a sine or cosine. Because the true damping parameter ω is 1, the oscillation decays over time rather than persisting. The length L determines the period.
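A rough, quantitative version of "L determines the period" comes from linearizing sin(x) ≈ x, which is exactly the small-angle assumption this model avoids. The nonlinear pendulum's period is somewhat longer at large amplitude, so treat these numbers as approximations; the variable names are illustrative.

```julia
g = 9.8
ω, L = 1.0, 2.5

# Undamped small-angle period: T0 = 2π √(L/g)
T0 = 2π * sqrt(L / g)

# The linearized damped equation x'' + ω x' + (g/L) x = 0 oscillates at √(g/L - ω²/4)
ωd = sqrt(g / L - ω^2 / 4)
Td = 2π / ωd

# Its amplitude envelope decays like exp(-ω t / 2)
envelope(t) = exp(-ω * t / 2)

println("T0 ≈ ", round(T0, digits = 3), ", damped period ≈ ", round(Td, digits = 3),
        ", envelope at t = 10: ", round(envelope(10.0), digits = 4))
```

With ω = 1 and L = 2.5 the linearized period is a little over 3 time units, so the span 0 to 10 covers roughly three oscillations while the envelope shrinks to below 1% of its starting value.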

Create some dummy data to use for estimation

We now sample the solution at ten evenly spaced times from 1 to 10 and add Gaussian noise with standard deviation 0.01:

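t = collect(range(1,stop=10,length=10))                                          # observation times 1, 2, ..., 10
randomized = VectorOfArray([(sol(t[i]) + .01randn(2)) for i in 1:length(t)])   # true state plus noise with sd 0.01
data = convert(Array,randomized)                                                 # 2×10: rows are (x, y), columns are times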
2×10 Array{Float64,2}:
  0.0703023  -0.370584  0.120411  …  -0.0200411  0.00329105  0.023678
 -1.21217     0.327857  0.310005     -0.0259792  0.0127349   0.0122178
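`convert(Array, randomized)` stacks the ten noisy 2-vectors as the columns of a 2×10 matrix, as the printed output shows. The same layout can be built in plain Julia with `reduce(hcat, ...)`. In this sketch `clean_state` is a hypothetical stand-in for `sol(t[i])` (a made-up damped oscillation, not the pendulum solution), and a fixed seed makes the noise reproducible.

```julia
using Random

rng = MersenneTwister(1)
t = collect(range(1, stop = 10, length = 10))

# Hypothetical stand-in for sol(t[i]): any function returning a 2-vector [x, y]
clean_state(ti) = [exp(-ti / 2) * cos(2ti), -exp(-ti / 2) * sin(2ti)]

# One noisy 2-vector per time point, noise standard deviation 0.01
noisy = [clean_state(ti) .+ 0.01 .* randn(rng, 2) for ti in t]

# Stack the 2-vectors as columns: rows are states (x, y), columns are time points
data = reduce(hcat, noisy)
```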

Let's see what our data looks like on top of the real solution:

scatter!(t, data')   # plot each state (x, y) against its observation time

This data captures both the damping and the period of the oscillation, which makes it well suited to Bayesian inference.

Perform Bayesian Estimation

Now let's fit the pendulum to the data. Since we know our model is correct, this should give us back the parameters that we used to generate the data! Define priors on our parameters. In this case, let's assume we don't have much information, but have a prior belief that ω is between 0.1 and 3.0, while the length of the pendulum L is probably around 3.0:

priors = [Uniform(0.1,3.0), Normal(3.0,1.0)]
2-element Array{Distribution{Univariate,Continuous},1}:
 Uniform{Float64}(a=0.1, b=3.0)
 Normal{Float64}(μ=3.0, σ=1.0)
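Conceptually, each backend samples from a posterior that combines these priors with a likelihood comparing the ODE solution to `data`. The sketch below writes such an unnormalized log posterior in plain Julia. It assumes independent Gaussian observation noise with a single scale σ (the Turing output below reports one noise parameter, σ[1]); the fixed-step RK4 solver, the flat prior on σ > 0, and all function names are hypothetical choices for illustration, not DiffEqBayes.jl internals.

```julia
# Pendulum right-hand side, out of place: u = [x, y], p = [ω, L]
rhs(u, p) = [u[2], -p[1] * u[2] - (9.8 / p[2]) * sin(u[1])]

# Fixed-step RK4 from t = 0, returning the state at each (sorted, nonnegative) time in ts
function states_at(p, u0, ts; h = 1e-3)
    u, t = copy(u0), 0.0
    out = Vector{Vector{Float64}}()
    for target in ts
        while t < target - 1e-12
            dt = min(h, target - t)
            k1 = rhs(u, p)
            k2 = rhs(u .+ (dt / 2) .* k1, p)
            k3 = rhs(u .+ (dt / 2) .* k2, p)
            k4 = rhs(u .+ dt .* k3, p)
            u = u .+ (dt / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
            t += dt
        end
        push!(out, copy(u))
    end
    return reduce(hcat, out)          # 2 × length(ts), same layout as `data`
end

# Log densities of the priors used above: Uniform(0.1, 3.0) for ω, Normal(3.0, 1.0) for L
logpdf_uniform(x, a, b) = a <= x <= b ? -log(b - a) : -Inf
logpdf_normal(x, μ, s) = -0.5 * ((x - μ) / s)^2 - log(s * sqrt(2π))

# Unnormalized log posterior of θ = (ω, L, σ); flat (improper) prior on σ > 0 (hypothetical)
function log_posterior(θ, data, ts; u0 = [1.0, 0.1])
    ω, L, σ = θ
    (σ <= 0 || L <= 0) && return -Inf  # L must be positive for 9.8 / L to make sense
    lp = logpdf_uniform(ω, 0.1, 3.0) + logpdf_normal(L, 3.0, 1.0)
    isfinite(lp) || return -Inf
    pred = states_at([ω, L], u0, ts)
    return lp + sum(logpdf_normal.(data, pred, σ))
end
```

A sampler only ever needs to evaluate (and, for NUTS, differentiate) a function like `log_posterior`; the priors enter as the first two terms and the data as the sum over all 2×10 observations.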

Finally, let's run the estimation routine from DiffEqBayes.jl with the Turing.jl backend to check whether we indeed recover the parameters!

bayesian_result = turing_inference(prob1,Tsit5(),t,data,priors;num_samples=10_000,
                                   syms = [:omega,:L])
Object of type Chains, with data of type 9000×15×1 Array{Float64,3}

Iterations        = 1:9000
Thinning interval = 1
Chains            = 1
Samples per chain = 9000
internals         = acceptance_rate, hamiltonian_energy, hamiltonian_energy_error, is_accept, log_density, lp, max_hamiltonian_energy_error, n_steps, nom_step_size, numerical_error, step_size, tree_depth
parameters        = L, omega, σ[1]

2-element Array{MCMCChains.ChainDataFrame,1}

Summary Statistics
  parameters    mean     std  naive_se    mcse        ess   r_hat
  ──────────  ──────  ──────  ────────  ──────  ─────────  ──────
           L  2.5139  0.2168    0.0023  0.0031  4364.8427  1.0003
       omega  1.0897  0.2144    0.0023  0.0031  3537.6842  1.0007
        σ[1]  0.1597  0.0378    0.0004  0.0006  3753.7469  1.0009

Quantiles
  parameters    2.5%   25.0%   50.0%   75.0%   97.5%
  ──────────  ──────  ──────  ──────  ──────  ──────
           L  2.0700  2.3851  2.5138  2.6458  2.9444
       omega  0.7737  0.9492  1.0545  1.1896  1.6195
        σ[1]  0.1024  0.1327  0.1538  0.1807  0.2473
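The Summary Statistics and Quantiles tables are standard sample statistics of the post-warmup draws. A rough plain-Julia sketch with the `Statistics` standard library, applied to hypothetical independent draws rather than the actual chain: `naive_se` is `std / sqrt(n)`, which ignores autocorrelation, whereas `mcse` and `ess` account for it (consistent with `mcse` exceeding `naive_se` in the table). As a check, 0.2168 / √9000 ≈ 0.0023 reproduces the `naive_se` reported for L.

```julia
using Random, Statistics

# Hypothetical stand-in for one parameter's 9000 post-warmup draws (independent here)
rng = MersenneTwister(42)
draws = 2.5 .+ 0.2 .* randn(rng, 9000)

n = length(draws)
post_mean = mean(draws)
post_std = std(draws)
naive_se = post_std / sqrt(n)      # assumes independent draws; ignores autocorrelation
qs = quantile(draws, [0.025, 0.25, 0.5, 0.75, 0.975])

# The same formula applied to the std of L from the table above
naive_se_L = 0.2168 / sqrt(9000)
println("naive_se for L ≈ ", round(naive_se_L, digits = 4))
```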

Notice that although our priors were not centered on the true values (the prior means are 1.55 for ω and 3.0 for L), the posterior means (ω ≈ 1.09, L ≈ 2.51) are close to the values used to generate the data (ω = 1.0, L = 2.5), so the sampler has learned sensible posterior distributions for the parameters. To look at these posterior distributions, we can plot the chains:

plot(bayesian_result)

As a diagnostic, we will also check the parameter chains. A chain is the sequence of samples produced by the MCMC sampler. It should explore parameter space and converge reasonably well, and we should take many samples after it converges: it is these post-convergence samples that form the posterior distribution.

plot(bayesian_result, colordim = :parameter)

Notice that after a while these chains settle into a "fuzzy line": the sampler has found the region of highest posterior probability and keeps sampling around it, and these samples build up a posterior distribution centered near the true parameter values.
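The "fuzzy line" behavior is easy to reproduce with a toy sampler. The sketch below is a random-walk Metropolis sampler, a much simpler algorithm than the NUTS samplers these backends use (the outputs below show NUTS settings and tree statistics), run on a hypothetical one-dimensional Normal(2.5, 0.2) target and started far away at 10.0. The first draws form a transient as the chain moves toward the high-probability region; afterwards it fluctuates around the target mean, and discarding an initial warmup segment removes the transient.

```julia
using Random, Statistics

# Random-walk Metropolis on a 1-D target (a simpler stand-in for NUTS)
function metropolis(logp, x0, n; step = 0.5, rng = MersenneTwister(0))
    xs = Vector{Float64}(undef, n)
    x, lx = x0, logp(x0)
    for i in 1:n
        y = x + step * randn(rng)          # symmetric random-walk proposal
        ly = logp(y)
        if log(rand(rng)) < ly - lx        # accept with probability min(1, p(y)/p(x))
            x, lx = y, ly
        end
        xs[i] = x                          # a rejected move repeats the current value
    end
    return xs
end

# Hypothetical target: unnormalized log density of Normal(2.5, 0.2)
logp(x) = -0.5 * ((x - 2.5) / 0.2)^2

chain = metropolis(logp, 10.0, 20_000; step = 0.3)
warmup = 1_000
kept = chain[warmup+1:end]   # drop the transient before the chain reaches the "fuzzy line"
```

Plotting `chain` against its index shows the same picture as the trace plots above: a descent from 10 followed by a noisy band around 2.5.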

DiffEqBayes.jl supports Stan (via CmdStan.jl), Turing.jl and DynamicHMC.jl as MCMC backends, and ApproxBayes.jl for approximate Bayesian computation (ABC) algorithms. Let's compare the timings of the different MCMC backends. We stick with the default arguments and 10,000 samples for each, since there is a lot of room for micro-optimization specific to each package and algorithm combination; you may want to run your own experiments on specific problems to get a better understanding of the performance.
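`@btime` from BenchmarkTools.jl runs an expression several times and reports the minimum time, so one-time compilation costs are excluded; repeated evaluation is presumably also why the Stan run prints its "will be updated" message more than once. A crude plain-Julia approximation of that idea with `@elapsed` (the workload and `min_time` are hypothetical):

```julia
# Crude approximation of @btime: warm up once, then keep the fastest of several runs
function min_time(f; runs = 5)
    f()                                   # first call triggers compilation; not timed
    return minimum(@elapsed(f()) for _ in 1:runs)
end

# Hypothetical workload standing in for an inference call such as turing_inference(...)
work() = sum(sin, 1:1_000_000)

t_min = min_time(work)
println("fastest of 5 runs: ", t_min, " s")
```

For real comparisons, BenchmarkTools' `@btime` or `@benchmark` are the better tools; this sketch only illustrates why the reported numbers exclude compilation.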

@btime bayesian_result = turing_inference(prob1,Tsit5(),t,data,priors;syms = [:omega,:L],num_samples=10_000)
7.008 s (43359449 allocations: 3.01 GiB)
Object of type Chains, with data of type 9000×15×1 Array{Float64,3}

Iterations        = 1:9000
Thinning interval = 1
Chains            = 1
Samples per chain = 9000
internals         = acceptance_rate, hamiltonian_energy, hamiltonian_energy_error, is_accept, log_density, lp, max_hamiltonian_energy_error, n_steps, nom_step_size, numerical_error, step_size, tree_depth
parameters        = L, omega, σ[1]

2-element Array{MCMCChains.ChainDataFrame,1}

Summary Statistics
  parameters    mean     std  naive_se    mcse        ess   r_hat
  ──────────  ──────  ──────  ────────  ──────  ─────────  ──────
           L  2.5198  0.2173    0.0023  0.0028  4326.0420  1.0001
       omega  1.0837  0.2105    0.0022  0.0029  4913.9382  1.0000
        σ[1]  0.1586  0.0374    0.0004  0.0006  4066.8603  1.0002

Quantiles
  parameters    2.5%   25.0%   50.0%   75.0%   97.5%
  ──────────  ──────  ──────  ──────  ──────  ──────
           L  2.0949  2.3891  2.5165  2.6464  2.9681
       omega  0.7726  0.9458  1.0503  1.1838  1.5935
        σ[1]  0.1020  0.1319  0.1528  0.1791  0.2483
@btime bayesian_result = stan_inference(prob1,t,data,priors;num_samples=10_000,printsummary=false)
File /Users/vaibhav/tmp/parameter_estimation_model.stan will be updated.
  46.536 s (1071406 allocations: 43.14 MiB)
DiffEqBayes.StanModel{Stanmodel,Int64,Array{Float64,3},Array{String,1}}(
  name =                    "parameter_estimation_model"
  nchains =                 1
  num_samples =             10000
  num_warmup =                1000
  thin =                    1
  monitors =                String[]
  model_file =              "parameter_estimation_model.stan"
  data_file =               "parameter_estimation_model_1.data.R"
  output =                  Output()
    file =                    "parameter_estimation_model_samples_1.csv"
    diagnostics_file =        ""
    refresh =                 100
  pdir =                   "/Users/vaibhav"
  tmpdir =                 "/Users/vaibhav/tmp"
  output_format =           :array
  method =                  Sample()
    num_samples =             10000
    num_warmup =              1000
    save_warmup =             false
    thin =                    1
    algorithm =               HMC()
      engine =                  NUTS()
        max_depth =               10
      metric =                  CmdStan.diag_e
      stepsize =                1.0
      stepsize_jitter =         1.0
    adapt =                   Adapt()
      gamma =                   0.05
      delta =                   0.8
      kappa =                   0.75
      t0 =                      10.0
      init_buffer =             75
      term_buffer =             50
      window =                  25
, 0, [10.5547 0.955816 … 1.02989 2.84037; 6.29168 0.858056 … 1.43967 1.91611; … ; 11.0201 0.98966 … 0.959069 2.69565; 10.5695 0.903599 … 1.1711 2.38167],
["lp__", "accept_stat__", "stepsize__", "treedepth__", "n_leapfrog__", "divergent__", "energy__", "sigma1.1", "sigma1.2", "theta1", "theta2", "theta.1", "theta.2"])
@btime bayesian_result = dynamichmc_inference(prob1,Tsit5(),t,data,priors;num_samples = 10_000)
8.268 s (44242911 allocations: 3.75 GiB)
(posterior = NamedTuple{(:parameters, :σ),Tuple{Array{Float64,1},Array{Float64,1}}}[
 (parameters = [1.0172158777722564, 2.50704809535446], σ = [0.009968205451886447, 0.0075205215583774715]),
 (parameters = [1.0097965987523527, 2.5100373880788345], σ = [0.011196387188434489, 0.006931132404964102]),
 (parameters = [1.0020901399972686, 2.5185928284022414], σ = [0.01061675828637262, 0.006840530313635442]),
 (parameters = [1.0061772865233098, 2.509889511858886], σ = [0.011063496758070494, 0.00813397329442808]),
 (parameters = [0.9908870191723436, 2.525621897690642], σ = [0.011010255392486179, 0.012251356960055679]),
 (parameters = [1.0154297849033573, 2.5057573064930763], σ = [0.00917883966015829, 0.010473459456303213]),
 (parameters = [0.9996604852053136, 2.5213383512140126], σ = [0.009029267493849469, 0.010066485853233158]),
 (parameters = [0.9960510702645878, 2.5221456126266304], σ = [0.009477161266722985, 0.00826339949842958]),
 (parameters = [0.9888276412119258, 2.539355763957219], σ = [0.00841954192515621, 0.008054084594712404]),
 (parameters = [1.0017962159580898, 2.521484149378387], σ = [0.008458295996061942, 0.007683310673681056])
 …
 (parameters = [1.0011056892260037, 2.496876318230962], σ = [0.011051372090944728, 0.00902188744454618]),
 (parameters = [1.005507924344299, 2.5165336845082495], σ = [0.010539051732110245, 0.008901499270180682]),
 (parameters = [1.0054018685078112, 2.508602699644201], σ = [0.00886523247595829, 0.009147126116141347]),
 (parameters = [0.9714602983375155, 2.5358205719047104], σ = [0.008703356350472913, 0.01024994048505162]),
 (parameters = [1.0519498971393437, 2.491208315462291], σ = [0.01571531674444325, 0.010988512869588338]),
 (parameters = [1.0153150963707929, 2.4996928597111205], σ = [0.013188344681046144, 0.009089815492377761]),
 (parameters = [1.0013809858502667, 2.5231976106407084], σ = [0.010370971423114259, 0.010371808643788514]),
 (parameters = [1.0193679004802936, 2.5205998145032824], σ = [0.01280420076868382, 0.009371496129911025]),
 (parameters = [1.0098496534135617, 2.5156617118469518], σ = [0.013482123111319048, 0.006602428659544425]),
 (parameters = [1.0134774395801138, 2.498293426936223], σ = [0.010057285113832474, 0.006530887137143007])],
 chain = [
 [0.017069363736430455, 0.9191060034175306, -4.60835470600514, -4.890119787258199],
 [0.009748923198099696, 0.9202976486816914, -4.492164125179982, -4.971732072929938],
 [0.002087958693622264, 0.9237003441315097, -4.545321555869692, -4.984890019119817],
 [0.006158285299434456, 0.9202387329950934, -4.504104170282617, -4.811705754705336],
 [-0.009154758041091933, 0.9264873285165199, -4.508928132105484, -4.4021185758767505],
 [0.01531195627344853, 0.9185910068092693, -4.690854481014291, -4.558910892600256],
 [-0.00033957244288296317, 0.924789852295613, -4.707284034018612, -4.598543605030089],
 [-0.003956747346023469, 0.9251099728413704, -4.658870451996148, -4.795919214564258],
 [-0.011235238369338155, 0.9319104126220193, -4.777199855395907, -4.8215759131672105],
 [0.0017946046913724662, 0.9248476763274516, -4.772607544556725, -4.868704747407462]
 …
 [0.0011050784018843808, 0.9150404779246882, -4.505200687600752, -4.708101715739219],
 [0.005492811198320356, 0.9228824326039619, -4.552667708406817, -4.721535559107712],
 [0.005387330746672939, 0.9197259047635631, -4.725618115777442, -4.694315534697553],
 [-0.02895487735325546, 0.9305172816684047, -4.744046540307749, -4.580483379751004],
 [0.0506454868879508, 0.9127678600264715, -4.153119453387317, -4.5109048364116955],
 [0.01519900409147444, 0.9161678682111726, -4.328421818153761, -4.700600668866987],
 [0.0013800331663017828, 0.9255269902702231, -4.568744584839308, -4.5686638607828405],
 [0.019182729787938037, 0.924496894828965, -4.357981976844077, -4.6700825231289835],
 [0.009801461766340804, 0.9225358750475297, -4.306390685060239, -5.020317718918469],
 [0.013387426751075711, 0.9156078695512292, -4.599458020132194, -5.031212489309849]],
 tree_statistics = DynamicHMC.TreeStatisticsNUTS[
 DynamicHMC.TreeStatisticsNUTS(50.109169641983016, 3, turning at positions 12:15, 0.9858955035871283, 15, DynamicHMC.Directions(0x83caddbf)),
 DynamicHMC.TreeStatisticsNUTS(51.77092409163728, 3, turning at positions -11:-14, 0.9650362163896669, 15, DynamicHMC.Directions(0xacb70091)),
 DynamicHMC.TreeStatisticsNUTS(49.31593523832962, 3, turning at positions -9:-12, 0.9153716628301408, 15, DynamicHMC.Directions(0xb2b5b953)),
 DynamicHMC.TreeStatisticsNUTS(51.78395467550736, 3, turning at positions 8:15, 0.9995379273936363, 15, DynamicHMC.Directions(0xc38bc19f)),
 DynamicHMC.TreeStatisticsNUTS(47.89500758240069, 4, turning at positions -11:4, 0.8124042217369696, 15, DynamicHMC.Directions(0x9354d504)),
 DynamicHMC.TreeStatisticsNUTS(48.578879628528085, 3, turning at positions -11:-14, 0.9859911477937244, 15, DynamicHMC.Directions(0xdcaef251)),
 DynamicHMC.TreeStatisticsNUTS(50.199064669895, 3, turning at positions -2:5, 0.9247503971761548, 7, DynamicHMC.Directions(0x203a57ad)),
 DynamicHMC.TreeStatisticsNUTS(50.0587953084984, 3, turning at positions -4:-7, 0.899093242660993, 11, DynamicHMC.Directions(0x7de07a64)),
 DynamicHMC.TreeStatisticsNUTS(45.17947487214493, 4, turning at positions -3:12, 0.7830015635358215, 15, DynamicHMC.Directions(0xe834117c)),
 DynamicHMC.TreeStatisticsNUTS(46.868960060238614, 2, turning at positions -4:-7, 0.8627989794177997, 7, DynamicHMC.Directions(0x4b86ecb0))
 …
 DynamicHMC.TreeStatisticsNUTS(49.91259769310237, 2, turning at positions 1:4, 0.858129610778053, 7, DynamicHMC.Directions(0xbe0e95cc)),
 DynamicHMC.TreeStatisticsNUTS(50.47656789419056, 3, turning at positions -9:-12, 0.9839580389453496, 15, DynamicHMC.Directions(0x3f777223)),
 DynamicHMC.TreeStatisticsNUTS(51.36092773588299, 4, turning at positions 13:28, 0.9309462682122461, 31, DynamicHMC.Directions(0x13104edc)),
 DynamicHMC.TreeStatisticsNUTS(44.98265127473394, 3, turning at positions -6:1, 0.8155974911304378, 7, DynamicHMC.Directions(0x98052af9)),
 DynamicHMC.TreeStatisticsNUTS(43.25247644556322, 4, turning at positions -6:-9, 0.9783237337623772, 19, DynamicHMC.Directions(0x258a920a)),
 DynamicHMC.TreeStatisticsNUTS(42.17364962066691, 4, turning at positions -5:10, 0.939127687461761, 15, DynamicHMC.Directions(0xf5cf67ba)),
 DynamicHMC.TreeStatisticsNUTS(47.98182917827681, 4, turning at positions 0:15, 0.9951356291985248, 15, DynamicHMC.Directions(0x4f8bec8f)),
 DynamicHMC.TreeStatisticsNUTS(49.51435872328723, 4, turning at positions -11:-26, 0.9037430955720961, 31, DynamicHMC.Directions(0x2e4008e5)),
 DynamicHMC.TreeStatisticsNUTS(49.94898271826714, 4, turning at positions -3:12, 0.9792263716348508, 15, DynamicHMC.Directions(0xfa68937c)),
 DynamicHMC.TreeStatisticsNUTS(48.396379904074365, 3, turning at positions -6:-9, 0.8284990356624209, 15, DynamicHMC.Directions(0xde1b03b6))],
 κ = Gaussian kinetic energy (LinearAlgebra.Diagonal), √diag(M⁻¹): [0.02176385942823913, 0.01930503258940134, 0.23594811171252025, 0.3229637297802619],
 ϵ = 0.16980415394949835)
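The DynamicHMC result stores its draws in `posterior`, a vector of named tuples with fields `parameters` (here [ω, L]) and `σ` (one noise scale per state), as the printed type shows. Posterior means are then element-wise averages over draws. The sketch below uses only the first five draws printed above, rounded, so its numbers are illustrative rather than full-chain estimates; if the result is bound to a variable, for example `res = dynamichmc_inference(...)`, the same expressions would presumably apply to `res.posterior`. Note that the sampled σ values sit near 0.01, the noise level used to generate the data.

```julia
# First five draws printed above (rounded): parameters = [ω, L], σ = noise scales for (x, y)
posterior = [
    (parameters = [1.0172, 2.5070], σ = [0.00997, 0.00752]),
    (parameters = [1.0098, 2.5100], σ = [0.01120, 0.00693]),
    (parameters = [1.0021, 2.5186], σ = [0.01062, 0.00684]),
    (parameters = [1.0062, 2.5099], σ = [0.01106, 0.00813]),
    (parameters = [0.9909, 2.5256], σ = [0.01101, 0.01225]),
]

# Element-wise posterior means: sum the vectors across draws, then divide by the count
mean_params = sum(d.parameters for d in posterior) ./ length(posterior)
mean_sigma = sum(d.σ for d in posterior) ./ length(posterior)

println("mean [ω, L] ≈ ", round.(mean_params, digits = 3),
        ", mean σ ≈ ", round.(mean_sigma, digits = 4))
```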