```julia
using DiffEqBayes, OrdinaryDiffEq, RecursiveArrayTools, Distributions, Plots, StatsPlots, BenchmarkTools, TransformVariables, CmdStan, DynamicHMC
```
Let's define our simple pendulum problem. Here our pendulum has a drag term ω and a length L.
We get a system of first-order equations by taking the first component of the state as the angle x (the position) and the second as the angular velocity y:
```julia
function pendulum(du,u,p,t)
    ω,L = p                        # drag coefficient and pendulum length
    x,y = u                        # angle and angular velocity
    du[1] = y                      # dx/dt = y
    du[2] = - ω*y -(9.8/L)*sin(x)  # dy/dt: drag plus gravity (g = 9.8)
end
u0 = [1.0,0.1]
tspan = (0.0,10.0)
prob1 = ODEProblem(pendulum,u0,tspan,[1.0,2.5])
```

```
ODEProblem with uType Array{Float64,1} and tType Float64. In-place: true
timespan: (0.0, 10.0)
u0: [1.0, 0.1]
```
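In equation form, the `pendulum` function above integrates the following system, with x the angle, y the angular velocity, and the gravitational acceleration g = 9.8 hard-coded. Eliminating y gives the familiar second-order pendulum equation with a linear drag term:

$$
\begin{aligned}
\frac{dx}{dt} &= y, \\
\frac{dy}{dt} &= -\omega\, y - \frac{g}{L}\sin x,
\qquad\Longleftrightarrow\qquad
\frac{d^2 x}{dt^2} + \omega \frac{dx}{dt} + \frac{g}{L}\sin x = 0 .
\end{aligned}
$$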
To understand the model and generate data, let's solve and visualize the solution with the known parameters:
```julia
sol = solve(prob1,Tsit5())
plot(sol)
```
It's a pendulum, so you know what it looks like. It oscillates, but since we have not made a small-angle assumption the solution is not exactly a sine or cosine. Because the drag parameter ω is 1 (i.e. positive), the amplitude of the oscillation decays over time. The length L determines the period.
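A small-angle approximation gives rough numbers for how ω and L shape the solution. Replacing sin x by x (an assumption; our initial angle of 1 rad is not small, so the true period is somewhat longer) gives x'' + ω x' + (g/L) x = 0, whose solutions have an envelope decaying like exp(-ωt/2) and a damped angular frequency sqrt(g/L - ω²/4). The helper names `damped_period` and `envelope` are hypothetical, for illustration only:

```julia
# Small-angle (linearized) estimates for the damped pendulum
#   x'' + ω x' + (g/L) x = 0
# These are approximations; the tutorial solves the full nonlinear equation.
const g = 9.8  # gravitational acceleration, as hard-coded in `pendulum`

# Period of the damped oscillation, 2π / sqrt(g/L - ω^2/4), underdamped case only
function damped_period(ω, L)
    ωd_sq = g / L - ω^2 / 4
    ωd_sq > 0 || error("overdamped: no oscillation for these parameters")
    return 2π / sqrt(ωd_sq)
end

# Factor by which the envelope exp(-ωt/2) has shrunk after time t
envelope(ω, t) = exp(-ω * t / 2)

T = damped_period(1.0, 2.5)   # ≈ 3.28 s for the true parameters
decay = envelope(1.0, 10.0)   # ≈ 0.0067 at the end of tspan
println("period ≈ ", round(T, digits = 2), " s, envelope at t = 10: ", round(decay, digits = 4))
```

In this approximation the period grows roughly like sqrt(L), while ω mainly controls how fast the oscillation dies out.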
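We now generate some synthetic data to use for estimation: the solution sampled at ten evenly spaced times between 1 and 10, plus Gaussian noise with standard deviation 0.01.

```julia
t = collect(range(1,stop=10,length=10))
# sample the solution at each time and add N(0, 0.01²) noise to both components
randomized = VectorOfArray([(sol(t[i]) + .01randn(2)) for i in 1:length(t)])
data = convert(Array,randomized)
```

```
2×10 Array{Float64,2}:
  0.0703023  -0.370584  0.120411  …  -0.0200411  0.00329105  0.023678
 -1.21217     0.327857  0.310005     -0.0259792  0.0127349   0.0122178
```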
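`VectorOfArray` followed by `convert(Array, …)` stacks the ten noisy 2-element state vectors as the columns of a 2×10 matrix: row 1 holds the angle, row 2 the angular velocity, and each column is one time point. A base-Julia sketch of the same layout with `reduce(hcat, …)`, using a hypothetical stand-in `fake_solution` in place of the interpolating `sol` so that it runs without solving the ODE:

```julia
using Random

# Hypothetical stand-in for the interpolated ODE solution `sol(t)`:
# returns the 2-element state [angle, angular velocity] at time t.
fake_solution(t) = [exp(-t / 2) * cos(2t), -exp(-t / 2) * sin(2t)]

Random.seed!(1)
t = collect(range(1, stop = 10, length = 10))

# Same construction as in the tutorial: sample each state, add N(0, 0.01²) noise,
# then stack the vectors column by column into a 2×10 matrix.
noisy = [fake_solution(ti) .+ 0.01 .* randn(2) for ti in t]
data = reduce(hcat, noisy)

size(data)   # (2, 10): rows are state components, columns are time points
```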
Let's see what our data look like on top of the true solution:

```julia
scatter!(t, data')   # one series per state component, plotted at the sample times
```
These data capture the damping and the true period, which makes them well suited for Bayesian inference.
Now let's fit the pendulum to the data. Since we know our model is correct, this should give us back the parameters that we used to generate the data! Define priors on our parameters. In this case, let's assume we don't have much information, but have a prior belief that ω is between 0.1 and 3.0, while the length of the pendulum L is probably around 3.0:
```julia
priors = [Uniform(0.1,3.0), Normal(3.0,1.0)]
```

```
2-element Array{Distribution{Univariate,Continuous},1}:
 Uniform{Float64}(a=0.1, b=3.0)
 Normal{Float64}(μ=3.0, σ=1.0)
```
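Conceptually, the sampler explores the unnormalized log-posterior: the log-prior of (ω, L) plus the log-likelihood of the data given the model prediction at those parameters. The base-Julia sketch below spells this out under simplifying assumptions: independent Gaussian observation noise with a fixed, known standard deviation `σ`, and a hypothetical `predict` function standing in for solving the ODE at the data times. DiffEqBayes builds its own model (the Turing output below, for instance, also samples a noise parameter σ), so this illustrates the idea rather than the package's exact model:

```julia
# Log-densities written out by hand so the sketch needs only Base.
# Uniform(a, b): constant inside the support, -Inf outside.
logpdf_uniform(x, a, b) = a <= x <= b ? -log(b - a) : -Inf
# Normal(μ, s)
logpdf_normal(x, μ, s) = -0.5 * ((x - μ) / s)^2 - log(s) - 0.5 * log(2π)

# Priors matching the tutorial: ω ~ Uniform(0.1, 3.0), L ~ Normal(3.0, 1.0)
logprior(ω, L) = logpdf_uniform(ω, 0.1, 3.0) + logpdf_normal(L, 3.0, 1.0)

# Unnormalized log-posterior. `predict(θ)` is a hypothetical function returning a
# matrix the same size as `data` (e.g. the ODE solution sampled at the data times);
# σ is an assumed, fixed noise standard deviation.
function logposterior(θ, data, predict; σ = 0.01)
    ω, L = θ
    lp = logprior(ω, L)
    isfinite(lp) || return -Inf          # outside the prior support
    resid = data .- predict(θ)
    return lp + sum(logpdf_normal.(resid, 0.0, σ))
end
```

An MCMC sampler such as NUTS only needs this function (and, for gradient-based samplers, its gradient) up to an additive constant.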
Finally let's run the estimation routine from DiffEqBayes.jl with the Turing.jl backend to check if we indeed recover the parameters!
```julia
bayesian_result = turing_inference(prob1,Tsit5(),t,data,priors;num_samples=10_000, syms = [:omega,:L])
```

```
Object of type Chains, with data of type 9000×15×1 Array{Float64,3}

Iterations        = 1:9000
Thinning interval = 1
Chains            = 1
Samples per chain = 9000
internals         = acceptance_rate, hamiltonian_energy, hamiltonian_energy_error, is_accept, log_density, lp, max_hamiltonian_energy_error, n_steps, nom_step_size, numerical_error, step_size, tree_depth
parameters        = L, omega, σ[1]

2-element Array{MCMCChains.ChainDataFrame,1}

Summary Statistics
  parameters    mean     std  naive_se    mcse        ess   r_hat
  ──────────  ──────  ──────  ────────  ──────  ─────────  ──────
           L  2.5139  0.2168    0.0023  0.0031  4364.8427  1.0003
       omega  1.0897  0.2144    0.0023  0.0031  3537.6842  1.0007
        σ[1]  0.1597  0.0378    0.0004  0.0006  3753.7469  1.0009

Quantiles
  parameters    2.5%   25.0%   50.0%   75.0%   97.5%
  ──────────  ──────  ──────  ──────  ──────  ──────
           L  2.0700  2.3851  2.5138  2.6458  2.9444
       omega  0.7737  0.9492  1.0545  1.1896  1.6195
        σ[1]  0.1024  0.1327  0.1538  0.1807  0.2473
```
Notice that although our priors were centered on the wrong values, the posterior means are close to the true values ω = 1.0 and L = 2.5, which means the sampler learned good posterior distributions for the parameters. To look at these posterior distributions, we can plot the chains:
```julia
plot(bayesian_result)
```
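The Summary Statistics and Quantiles tables above are ordinary statistics of the retained draws of each parameter. As a minimal sketch with Julia's `Statistics` standard library (on a short made-up vector of draws, not the actual chain), `mean`, `std` and `quantile` reproduce the columns of the same names, and `naive_se` is `std / sqrt(n)`, which ignores autocorrelation between successive draws. For the chain above, 0.2168 / √9000 ≈ 0.0023, matching the L row; the `mcse` and `ess` columns do account for autocorrelation and are not reproduced here.

```julia
using Statistics

# Made-up draws for one parameter, for illustration only
draws = [2.31, 2.48, 2.52, 2.55, 2.44, 2.61, 2.50, 2.58, 2.39, 2.62]

n = length(draws)
post_mean = mean(draws)
post_std = std(draws)                    # sample standard deviation (n - 1 denominator)
naive_se = post_std / sqrt(n)            # ignores autocorrelation between draws
qs = quantile(draws, [0.025, 0.25, 0.5, 0.75, 0.975])

println("mean = ", round(post_mean, digits = 4), ", std = ", round(post_std, digits = 4))
println("naive_se = ", round(naive_se, digits = 4), ", quantiles = ", round.(qs, digits = 4))
```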
As a diagnostic, we will also check the parameter chains. A chain is the sequence of samples produced by the MCMC process. It should explore the parameter space and converge reasonably well, and we should take many samples after it converges (it is these samples that form the posterior distribution!):
```julia
plot(bayesian_result, colordim = :parameter)
```
Notice that after a while these chains settle into a "fuzzy line": the sampler has found the region of highest posterior probability and keeps sampling around it, and these samples build a posterior distribution around the true values.
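The `r_hat` column in the summary tables makes this visual check quantitative: values close to 1 (here between 1.0003 and 1.0009) indicate that different parts of the sampling run agree. A minimal base-Julia sketch of the split-R̂ idea for a single chain: split the draws into two halves and compare the between-half variance to the within-half variance. MCMCChains' exact estimator may differ in details, and `split_rhat` is a hypothetical helper name:

```julia
using Statistics

# Split-R̂ for a single chain: split it into two halves and treat the halves
# as separate chains. Values near 1 suggest the halves agree (no drift).
function split_rhat(x::AbstractVector{<:Real})
    n = length(x) ÷ 2                 # length of each half
    halves = (x[1:n], x[n+1:2n])      # drops the last draw if the length is odd
    means = [mean(h) for h in halves]
    W = mean(var(h) for h in halves)  # mean within-half variance
    B = n * var(means)                # between-half variance
    var_hat = (n - 1) / n * W + B / n # pooled variance estimate
    return sqrt(var_hat / W)
end

stationary = [sin(i) for i in 1:1000]  # oscillates around a fixed level
drifting = collect(1.0:1000.0)         # never settles: the halves disagree
split_rhat(stationary), split_rhat(drifting)
```

A chain that is still drifting, like the second example, gives an R̂ well above 1, whereas a chain that has settled into its "fuzzy line" gives a value near 1.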
DiffEqBayes.jl lets you choose Stan (through CmdStan.jl), Turing.jl, or DynamicHMC.jl as the MCMC backend, and you can also use ApproxBayes.jl for approximate Bayesian computation (ABC) algorithms. Let's compare the timings of the different MCMC backends. We'll stick with the default arguments and 10,000 samples for each. There is a lot of room for micro-optimization specific to each package and algorithm combination, so you may want to run your own experiments on specific problems to get a better understanding of the performance.
```julia
@btime bayesian_result = turing_inference(prob1,Tsit5(),t,data,priors;syms = [:omega,:L],num_samples=10_000)
```

```
7.008 s (43359449 allocations: 3.01 GiB)
Object of type Chains, with data of type 9000×15×1 Array{Float64,3}

Iterations        = 1:9000
Thinning interval = 1
Chains            = 1
Samples per chain = 9000
internals         = acceptance_rate, hamiltonian_energy, hamiltonian_energy_error, is_accept, log_density, lp, max_hamiltonian_energy_error, n_steps, nom_step_size, numerical_error, step_size, tree_depth
parameters        = L, omega, σ[1]

2-element Array{MCMCChains.ChainDataFrame,1}

Summary Statistics
  parameters    mean     std  naive_se    mcse        ess   r_hat
  ──────────  ──────  ──────  ────────  ──────  ─────────  ──────
           L  2.5198  0.2173    0.0023  0.0028  4326.0420  1.0001
       omega  1.0837  0.2105    0.0022  0.0029  4913.9382  1.0000
        σ[1]  0.1586  0.0374    0.0004  0.0006  4066.8603  1.0002

Quantiles
  parameters    2.5%   25.0%   50.0%   75.0%   97.5%
  ──────────  ──────  ──────  ──────  ──────  ──────
           L  2.0949  2.3891  2.5165  2.6464  2.9681
       omega  0.7726  0.9458  1.0503  1.1838  1.5935
        σ[1]  0.1020  0.1319  0.1528  0.1791  0.2483
```
```julia
@btime bayesian_result = stan_inference(prob1,t,data,priors;num_samples=10_000,printsummary=false)
```

```
File /Users/vaibhav/tmp/parameter_estimation_model.stan will be updated.
46.536 s (1071406 allocations: 43.14 MiB)
DiffEqBayes.StanModel{Stanmodel,Int64,Array{Float64,3},Array{String,1}}(
  name = "parameter_estimation_model"
  nchains = 1
  num_samples = 10000
  num_warmup = 1000
  thin = 1
  monitors = String[]
  model_file = "parameter_estimation_model.stan"
  data_file = "parameter_estimation_model_1.data.R"
  output = Output()
  file = "parameter_estimation_model_samples_1.csv"
  diagnostics_file = ""
  refresh = 100
  pdir = "/Users/vaibhav"
  tmpdir = "/Users/vaibhav/tmp"
  output_format = :array
  method = Sample()
  num_samples = 10000
  num_warmup = 1000
  save_warmup = false
  thin = 1
  algorithm = HMC()
  engine = NUTS()
  max_depth = 10
  metric = CmdStan.diag_e
  stepsize = 1.0
  stepsize_jitter = 1.0
  adapt = Adapt()
  gamma = 0.05
  delta = 0.8
  kappa = 0.75
  t0 = 10.0
  init_buffer = 75
  term_buffer = 50
  window = 25
, 0, [10.5547 0.955816 … 1.02989 2.84037; 6.29168 0.858056 … 1.43967 1.91611; … ; 11.0201 0.98966 … 0.959069 2.69565; 10.5695 0.903599 … 1.1711 2.38167], ["lp__", "accept_stat__", "stepsize__", "treedepth__", "n_leapfrog__", "divergent__", "energy__", "sigma1.1", "sigma1.2", "theta1", "theta2", "theta.1", "theta.2"])
```
```julia
@btime bayesian_result = dynamichmc_inference(prob1,Tsit5(),t,data,priors;num_samples = 10_000)
```

```
8.268 s (44242911 allocations: 3.75 GiB)
(posterior = NamedTuple{(:parameters, :σ),Tuple{Array{Float64,1},Array{Float64,1}}}[
    (parameters = [1.0172158777722564, 2.50704809535446], σ = [0.009968205451886447, 0.0075205215583774715]),
    (parameters = [1.0097965987523527, 2.5100373880788345], σ = [0.011196387188434489, 0.006931132404964102]),
    (parameters = [1.0020901399972686, 2.5185928284022414], σ = [0.01061675828637262, 0.006840530313635442]),
    …
    (parameters = [1.0098496534135617, 2.5156617118469518], σ = [0.013482123111319048, 0.006602428659544425]),
    (parameters = [1.0134774395801138, 2.498293426936223], σ = [0.010057285113832474, 0.006530887137143007])],
 chain = [[0.017069363736430455, 0.9191060034175306, -4.60835470600514, -4.890119787258199],
          [0.009748923198099696, 0.9202976486816914, -4.492164125179982, -4.971732072929938],
          …
          [0.013387426751075711, 0.9156078695512292, -4.599458020132194, -5.031212489309849]],
 tree_statistics = DynamicHMC.TreeStatisticsNUTS[
    DynamicHMC.TreeStatisticsNUTS(50.109169641983016, 3, turning at positions 12:15, 0.9858955035871283, 15, DynamicHMC.Directions(0x83caddbf)),
    DynamicHMC.TreeStatisticsNUTS(51.77092409163728, 3, turning at positions -11:-14, 0.9650362163896669, 15, DynamicHMC.Directions(0xacb70091)),
    …
    DynamicHMC.TreeStatisticsNUTS(48.396379904074365, 3, turning at positions -6:-9, 0.8284990356624209, 15, DynamicHMC.Directions(0xde1b03b6))],
 κ = Gaussian kinetic energy (LinearAlgebra.Diagonal), √diag(M⁻¹): [0.02176385942823913, 0.01930503258940134, 0.23594811171252025, 0.3229637297802619],
 ϵ = 0.16980415394949835)
```
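Unlike the other two backends, `dynamichmc_inference` here returns a NamedTuple whose `posterior` field is a vector of draws, each a NamedTuple with `parameters` (ω, L) and `σ`, judging from the printed output above (the layout may differ in other DiffEqBayes versions). These draws sit close to ω = 1.0 and L = 2.5, with σ on the order of the 0.01 noise level used to generate the data. A minimal sketch of summarizing such draws, using three draws copied from that output; the `chain` field appears to hold the same draws on the log scale, which the last line checks for the first draw:

```julia
using Statistics

# Three draws copied from the `posterior` field printed above
posterior = [
    (parameters = [1.0172158777722564, 2.50704809535446], σ = [0.009968205451886447, 0.0075205215583774715]),
    (parameters = [1.0097965987523527, 2.5100373880788345], σ = [0.011196387188434489, 0.006931132404964102]),
    (parameters = [1.0020901399972686, 2.5185928284022414], σ = [0.01061675828637262, 0.006840530313635442]),
]

# Collect each parameter across draws and summarize
ω_draws = [d.parameters[1] for d in posterior]
L_draws = [d.parameters[2] for d in posterior]
println("mean ω ≈ ", round(mean(ω_draws), digits = 3), ", mean L ≈ ", round(mean(L_draws), digits = 3))

# First entry of the printed `chain` field: apparently [log ω, log L, log σ₁, log σ₂]
first_chain_entry = [0.017069363736430455, 0.9191060034175306, -4.60835470600514, -4.890119787258199]
exp.(first_chain_entry)   # should match the first draw's parameters followed by its σ
```

With the full result, the same comprehension over `bayesian_result.posterior` would collect all 10,000 draws.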