Uncertainty Quantified Deep Bayesian Model Discovery

In this tutorial, we show how SciML can seamlessly combine its differential equation solvers with Bayesian estimation libraries like AdvancedHMC.jl and Turing.jl. This makes it possible to turn a Neural ODE into a Bayesian Neural ODE, which lets us estimate the uncertainty in the Neural ODE's fit and forecasts. Below, we give a working example of a Bayesian Neural ODE sampled with the NUTS sampler.

Note

For more details, have a look at this paper: https://arxiv.org/abs/2012.07244

Step 1: Import Libraries

For this example, we will need the following libraries:

# SciML Libraries
using SciMLSensitivity, DifferentialEquations

# ML Tools
using Lux, Zygote

# External Tools
using Random, Plots, AdvancedHMC, MCMCChains, StatsPlots, ComponentArrays

Setup: Get the data from the Spiral ODE example

We will also need data to fit against. As a demonstration, we will generate our data from a simple cubic ODE, u' = A*u^3, with the cube taken elementwise (strictly, trueODEfunc below computes transpose(A) * u .^ 3):

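u0 = [2.0; 0.0]
datasize = 40
tspan = (0.0, 1.0)
tsteps = range(tspan[1], tspan[2], length = datasize)
# In-place right-hand side: du = transpose(true_A) * u .^ 3 (cube taken elementwise)
function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u .^ 3)' * true_A)'
end
prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
# Save the solution at the 40 equally spaced time points in tsteps
ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps))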
2×40 Matrix{Float64}:
 2.0  1.97895   1.94728  1.87998  1.74775  …   0.353996   0.53937   0.718119
 0.0  0.403905  0.79233  1.15176  1.45561     -1.54217   -1.52816  -1.50614
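To see what trueODEfunc computes without any packages, here is a dependency-free sketch (not part of the tutorial) that swaps Tsit5 for a hand-written fixed-step fourth-order Runge-Kutta integrator. The names TRUE_A, cubic_rhs, rk4_step and generate_data are hypothetical helpers; with 100 RK4 substeps per save interval, the result should agree closely with ode_data.

```julia
using LinearAlgebra

# Same matrix as true_A in trueODEfunc.
const TRUE_A = [-0.1 2.0; -2.0 -0.1]

# ((u .^ 3)' * A)' in trueODEfunc is the same vector as transpose(A) * u .^ 3.
cubic_rhs(u) = transpose(TRUE_A) * (u .^ 3)

# One classical fourth-order Runge-Kutta step of size h
# (an assumed stand-in for the adaptive Tsit5 solver).
function rk4_step(f, u, h)
    k1 = f(u)
    k2 = f(u .+ (h / 2) .* k1)
    k3 = f(u .+ (h / 2) .* k2)
    k4 = f(u .+ h .* k3)
    return u .+ (h / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
end

# Integrate over tspan and record the state at `datasize` equally spaced
# times, mimicking saveat = tsteps.
function generate_data(u0, tspan, datasize; substeps = 100)
    ts = range(tspan[1], tspan[2], length = datasize)
    data = zeros(length(u0), datasize)
    data[:, 1] = u0
    u = copy(u0)
    for i in 2:datasize
        h = (ts[i] - ts[i - 1]) / substeps
        for _ in 1:substeps
            u = rk4_step(cubic_rhs, u, h)
        end
        data[:, i] = u
    end
    return ts, data
end

ts, data_rk4 = generate_data([2.0, 0.0], (0.0, 1.0), 40)
```

The second column of data_rk4 can be compared directly with the second column of the ode_data printout above.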

We will want to train a neural network to capture the dynamics that fit ode_data.

Step 2: Define the Neural ODE architecture.

Note that this step offers a lot of flexibility in the number of layers and the number of units in each layer. A 100-unit architecture is not necessarily better at prediction or forecasting than a 50-unit one, and a complicated architecture can take much more computation time without improving performance.
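Because every weight and bias becomes a dimension of the posterior that NUTS has to explore, it can help to count parameters before choosing an architecture. The hypothetical helper below assumes a plain chain of dense layers (weights plus biases, as in Lux.Dense); the parameter-free x -> x .^ 3 layer adds nothing.

```julia
# Number of trainable parameters in a chain of dense layers with the given
# widths: each layer has in * out weights plus out biases.
dense_param_count(widths) =
    sum(widths[i] * widths[i + 1] + widths[i + 1] for i in 1:(length(widths) - 1))

dense_param_count([2, 50, 2])       # 252, the length of p in this tutorial
dense_param_count([2, 100, 2])      # 502
dense_param_count([2, 50, 50, 2])   # 2802
```

Doubling the hidden width roughly doubles the dimension of the sampling problem, and each extra hidden layer of width 50 adds 2550 parameters.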

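# Cube the inputs, then apply a 2-50-2 multilayer perceptron
dudt2 = Lux.Chain(x -> x .^ 3,
    Lux.Dense(2, 50, tanh),
    Lux.Dense(50, 2))

rng = Random.default_rng()
p, st = Lux.setup(rng, dudt2)
# Keep the layer states in a constant so the ODE function can close over them
const _st = st
# Out-of-place ODE right-hand side; Lux returns (output, state), so keep only the output
function neuralodefunc(u, p, t)
    dudt2(u, p, _st)[1]
end
# Solve the Neural ODE for parameters p and save at the data time points
function prob_neuralode(u0, p)
    prob = ODEProblem(neuralodefunc, u0, tspan, p)
    sol = solve(prob, Tsit5(), saveat = tsteps)
    return sol
end
# Flatten the nested parameter NamedTuple into a Float64 ComponentArray
p = ComponentArray{Float64}(p)
const _p = p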

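Note that the conversion to ComponentArray{Float64} is required to put the Lux neural network's parameters into Float64 precision, since Lux.setup initializes them in Float32 by default.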

Step 3: Define the loss function for the Neural ODE.

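function predict_neuralode(p)
    # Samplers may pass plain vectors; restore the ComponentArray layout Lux expects
    p = p isa ComponentArray ? p : convert(typeof(_p), p)
    Array(prob_neuralode(u0, p))
end
function loss_neuralode(p)
    pred = predict_neuralode(p)
    # Sum of squared errors between the data and the Neural ODE prediction
    loss = sum(abs2, ode_data .- pred)
    return loss, pred
end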
loss_neuralode (generic function with 1 method)

Step 4: Integrate the Bayesian estimation workflow, as prescribed by the AdvancedHMC interface, with the Neural ODE defined above

The AdvancedHMC interface requires us to specify: (a) the Hamiltonian log density and its gradient, (b) the sampler, and (c) the step size adaptor function.

For the Hamiltonian log density, we use the negative of the loss function plus a log-prior term. The -sum(θ .* θ) term corresponds to independent zero-mean Gaussian priors on the network parameters.
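One way to read l(θ), defined below, is as an unnormalized log posterior. Writing y for ode_data and ŷ(θ) for predict_neuralode(θ), and using log N(x; m, σ²) = -(x - m)²/(2σ²) + const, both terms correspond to Gaussians with variance 1/2 (so that 2σ² = 1). This interpretation is ours rather than a statement from the original tutorial:

$$
\ell(\theta) = -\sum_{i,j}\bigl(y_{ij}-\hat{y}_{ij}(\theta)\bigr)^2 - \sum_k \theta_k^2
= \log p(y \mid \theta) + \log p(\theta) + \text{const},
\qquad
y_{ij}\mid\theta \sim \mathcal{N}\bigl(\hat{y}_{ij}(\theta), \tfrac{1}{2}\bigr),
\quad
\theta_k \sim \mathcal{N}\bigl(0, \tfrac{1}{2}\bigr).
$$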

The user can make several modifications to Step 4, for example trying different target acceptance ratios and different numbers of warmup and posterior samples. One can also use the Variational Inference (ADVI) framework, although it does not work quite as well as NUTS. The SGLD (Stochastic Gradient Langevin Dynamics) sampler has been seen to perform better than NUTS. Have a look at https://sebastiancallh.github.io/post/langevin/ for a brief introduction to SGLD.
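For readers curious about SGLD, here is a minimal sketch of the update rule of Welling and Teh (2011) on a hypothetical toy problem (inferring the mean of Gaussian data), not on the Neural ODE. Each step moves half a step size along a minibatch estimate of the log-posterior gradient and adds Gaussian noise with variance ε; there is no accept/reject test. The function sgld_mean and all of its settings (fixed step size, batch size, prior) are illustrative assumptions.

```julia
using Random

# Hypothetical toy problem: infer the mean μ of N unit-variance Gaussian
# observations x under a N(0, 10^2) prior, using SGLD with a fixed step size ε.
function sgld_mean(x::Vector{Float64}; ε = 1e-4, batch = 50, iters = 20_000, seed = 1)
    rng = MersenneTwister(seed)
    N = length(x)
    μ = 0.0
    draws = zeros(iters)
    for t in 1:iters
        idx = rand(rng, 1:N, batch)                 # minibatch, sampled with replacement
        grad_prior = -μ / 100                       # d/dμ of log N(μ; 0, 10^2)
        grad_lik = (N / batch) * sum(x[idx] .- μ)   # unbiased estimate of the full-data gradient
        # SGLD update: half a step along the stochastic gradient plus N(0, ε) noise
        μ += (ε / 2) * (grad_prior + grad_lik) + sqrt(ε) * randn(rng)
        draws[t] = μ
    end
    return draws
end

data = 2.0 .+ randn(MersenneTwister(42), 1000)   # synthetic observations with true mean 2
draws = sgld_mean(data)
post_mean = sum(draws[1001:end]) / length(draws[1001:end])   # discard the first 1000 draws
```

With this many observations the exact posterior mean is essentially the sample mean of data, so post_mean should land very close to it.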

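# Unnormalized log posterior: Gaussian log-likelihood (negative squared error) plus Gaussian log prior
l(θ) = -sum(abs2, ode_data .- predict_neuralode(θ)) - sum(θ .* θ)
# Return the log density together with its gradient, computed by reverse-mode AD with Zygote
function dldθ(θ)
    x, lambda = Zygote.pullback(l, θ)
    grad = first(lambda(1))
    return x, grad
end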
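Before sampling, it can be worth checking that a (value, gradient) function like dldθ is consistent with its log density. The sketch below does this with central finite differences on a hypothetical toy model, where a linear map X_toy * θ stands in for predict_neuralode so that the gradient can also be written by hand; none of these names come from the tutorial.

```julia
# Hypothetical toy data; X_toy * θ plays the role of the Neural ODE prediction.
X_toy = [1.0 0.5; 0.2 1.5; -0.3 0.8]
y_toy = [1.0, 2.0, 0.5]

# Same structure as l(θ): negative squared error minus θ⋅θ.
logdens(θ) = -sum(abs2, y_toy .- X_toy * θ) - sum(θ .* θ)

# (value, gradient) pair, like dldθ; gradient is 2Xᵀ(y - Xθ) - 2θ.
function logdens_and_grad(θ)
    r = y_toy .- X_toy * θ
    return logdens(θ), 2 .* (transpose(X_toy) * r) .- 2 .* θ
end

# Central finite differences, used only to check a gradient.
function fd_grad(f, θ; h = 1e-6)
    g = similar(θ)
    for i in eachindex(θ)
        e = zeros(length(θ))
        e[i] = h
        g[i] = (f(θ .+ e) - f(θ .- e)) / (2h)
    end
    return g
end

θ_test = [0.3, -0.7]
val, grad = logdens_and_grad(θ_test)
```

For the real model, the same check would compare dldθ(θ)[2] against fd_grad(l, θ) on a few coordinates; that is much slower, since every evaluation of l solves the ODE.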

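# Diagonal mass matrix with one entry per network parameter
metric = DiagEuclideanMetric(length(p))
# Hamiltonian built from the log density l and its (value, gradient) function dldθ
h = Hamiltonian(metric, l, dldθ)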
Hamiltonian(metric=DiagEuclideanMetric([1.0, 1.0, 1.0, 1.0, 1.0, 1 ...]), kinetic=AdvancedHMC.GaussianKinetic())

We use the NUTS sampler with a target acceptance ratio of δ = 0.45 in this example. In addition, we use Nesterov dual averaging for step-size adaptation.

We sample using 500 warmup samples and 500 posterior samples.
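The step-size adaptation can be illustrated with a small, self-contained sketch of dual averaging following Algorithm 5 of Hoffman and Gelman (2014), with their recommended constants γ = 0.05, t₀ = 10, κ = 0.75 and μ = log(10ε₀). This is not AdvancedHMC's StepSizeAdaptor: the type and function names are hypothetical, and the toy acceptance rate exp(-ε) is an assumption standing in for the acceptance statistics NUTS would report.

```julia
# Minimal dual-averaging step-size adapter (Hoffman & Gelman 2014, Alg. 5).
mutable struct DualAveraging
    δ::Float64          # target acceptance rate
    μ::Float64          # shrinkage point for log(ε)
    γ::Float64
    t0::Float64
    κ::Float64
    m::Int              # adaptation steps taken so far
    Hbar::Float64       # running average of (δ - α)
    logeps::Float64     # current log step size
    logepsbar::Float64  # averaged log step size, used after warmup
end

DualAveraging(ε0, δ; γ = 0.05, t0 = 10.0, κ = 0.75) =
    DualAveraging(δ, log(10 * ε0), γ, t0, κ, 0, 0.0, log(ε0), 0.0)

# Update with the acceptance statistic α observed at the current step size.
function adapt!(da::DualAveraging, α)
    da.m += 1
    m = da.m
    w = 1 / (m + da.t0)
    da.Hbar = (1 - w) * da.Hbar + w * (da.δ - α)
    da.logeps = da.μ - sqrt(m) / da.γ * da.Hbar
    η = m^(-da.κ)
    da.logepsbar = η * da.logeps + (1 - η) * da.logepsbar
    return exp(da.logeps)
end

current_stepsize(da::DualAveraging) = exp(da.logeps)
final_stepsize(da::DualAveraging) = exp(da.logepsbar)

# Toy usage: pretend the acceptance rate at step size ε is exp(-ε), so the
# adapter should settle where exp(-ε) equals the target δ = 0.45.
da = DualAveraging(1.0, 0.45)
for _ in 1:2000
    α = exp(-current_stepsize(da))
    adapt!(da, α)
end
final_stepsize(da)   # close to -log(0.45) ≈ 0.80
```

Too low an acceptance rate makes Hbar positive and shrinks the step size, and too high a rate grows it; the averaged iterate is what is kept after warmup.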

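# Leapfrog integrator with an initial step size found by a heuristic search
integrator = Leapfrog(find_good_stepsize(h, p))
# NUTS: multinomial sampling along the trajectory with the generalised no-U-turn criterion
kernel = HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn()))
# Adapt the mass matrix and the step size (target acceptance 0.45) during warmup
adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(0.45, integrator))
# 1000 iterations in total: the first 500 adapt and are dropped, the remaining 500 are kept
samples, stats = sample(h, kernel, p, 1000, adaptor, 500; drop_warmup = true, progress = true)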
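To make the roles of the log density, its gradient, the leapfrog integrator and the mass matrix concrete, here is a minimal sketch of a single static-trajectory HMC transition with a diagonal mass matrix. It is deliberately simpler than the NUTS kernel used above (a fixed number of leapfrog steps and a Metropolis accept/reject test instead of multinomial sampling over a no-U-turn tree), and every function name is a hypothetical helper rather than AdvancedHMC API. The logp_and_grad argument follows the same (value, gradient) convention as dldθ.

```julia
using Random

# L leapfrog steps of size ε; Minv holds the diagonal of the inverse mass matrix.
function leapfrog(logp_and_grad, θ, r, ε, L, Minv)
    θ = copy(θ)
    r = copy(r)
    _, g = logp_and_grad(θ)
    for _ in 1:L
        r = r .+ (ε / 2) .* g          # half step for the momentum
        θ = θ .+ ε .* (Minv .* r)      # full step for the position
        _, g = logp_and_grad(θ)
        r = r .+ (ε / 2) .* g          # second half step for the momentum
    end
    return θ, r
end

# Hamiltonian = potential energy (-log density) + kinetic energy ½ rᵀ M⁻¹ r.
hamiltonian(logp, r, Minv) = -logp + 0.5 * sum(Minv .* r .^ 2)

# One static HMC transition with a Metropolis accept/reject test.
function hmc_step(rng, logp_and_grad, θ, ε, L, Minv)
    r0 = randn(rng, length(θ)) ./ sqrt.(Minv)   # momentum ~ N(0, M)
    logp0, _ = logp_and_grad(θ)
    θ1, r1 = leapfrog(logp_and_grad, θ, r0, ε, L, Minv)
    logp1, _ = logp_and_grad(θ1)
    logα = hamiltonian(logp0, r0, Minv) - hamiltonian(logp1, r1, Minv)
    return log(rand(rng)) < logα ? θ1 : θ
end

# Toy target: 2-D standard normal, returned as (log density, gradient).
std_normal(θ) = (-0.5 * sum(abs2, θ), -θ)

function run_toy_chain(n; seed = 1, ε = 0.3, L = 5)
    rng = MersenneTwister(seed)
    θ = zeros(2)
    draws = zeros(2, n)
    for i in 1:n
        θ = hmc_step(rng, std_normal, θ, ε, L, ones(2))
        draws[:, i] = θ
    end
    return draws
end

draws = run_toy_chain(5000)
```

Because leapfrog nearly conserves the Hamiltonian for small ε, almost all proposals are accepted; NUTS additionally chooses the trajectory length adaptively, which is why the stats above report a varying number of steps and tree depth per sample.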

Step 5: Plot diagnostics

Now let's check how well the sampler has mixed. This can be done by looking at the chain mixing plot and the autocorrelation plot. First, let's create the chain mixing plot using the plot recipes from ????

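# Stack the draws into a (number of parameters) × (number of draws) matrix
samples = hcat(samples...)
# Keep the first five parameters for the diagnostics
samples_reduced = samples[1:5, :]
# Chains expects a draws × parameters × chains array, so transpose before reshaping
samples_reshape = reshape(permutedims(Array(samples_reduced)), (500, 5, 1))
Chain_Spiral = Chains(samples_reshape)
plot(Chain_Spiral)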
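hcat(samples...) stores parameters along rows and draws along columns, while the Chains constructor expects a draws × parameters × chains array. reshape alone only reinterprets Julia's column-major memory and does not transpose, so the matrix has to be permuted first. A stdlib-only illustration with made-up numbers:

```julia
# 3 parameters (rows) × 4 draws (columns), laid out like hcat(samples...)
S = [1 2 3 4; 10 20 30 40; 100 200 300 400]

# Reinterprets memory: the "first parameter" column mixes different parameters
wrong = reshape(S, (4, 3, 1))

# Transpose first: draws × parameters × chains
right = reshape(permutedims(S), (4, 3, 1))
```

Here right[:, 1, 1] is the trace of the first parameter, [1, 2, 3, 4], while wrong[:, 1, 1] is [1, 10, 100, 2].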

Now we check the autocorrelation plot:

autocorplot(Chain_Spiral)
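For intuition about what the autocorrelation plot shows, here is a stdlib-only sketch that computes the sample autocorrelation of a single chain and a crude effective sample size, truncating the sum at the first negative autocorrelation. The truncation rule is a simplification (MCMCChains uses its own estimators), and the helpers autocorrelation, crude_ess and ar1_chain are hypothetical. An AR(1) chain with coefficient 0.9, whose true lag-k autocorrelation is 0.9^k, serves as test input.

```julia
using LinearAlgebra, Random

# Sample autocorrelation of a 1-D chain at lags 0:maxlag.
function autocorrelation(x::AbstractVector{<:Real}, maxlag::Integer)
    n = length(x)
    d = x .- sum(x) / n
    c0 = sum(abs2, d) / n
    return [dot(view(d, 1:(n - k)), view(d, (1 + k):n)) / (n * c0) for k in 0:maxlag]
end

# Crude ESS = n / τ with τ = 1 + 2 Σ ρ_k, stopping at the first negative ρ_k.
function crude_ess(x::AbstractVector{<:Real}; maxlag = min(length(x) - 1, 1000))
    ρ = autocorrelation(x, maxlag)
    τ = 1.0
    for k in 2:length(ρ)   # ρ[k] holds the lag-(k - 1) autocorrelation
        ρ[k] < 0 && break
        τ += 2 * ρ[k]
    end
    return length(x) / τ
end

# Synthetic AR(1) chain x_t = ϕ x_{t-1} + noise.
function ar1_chain(n, ϕ; seed = 1)
    rng = MersenneTwister(seed)
    x = zeros(n)
    for t in 2:n
        x[t] = ϕ * x[t - 1] + randn(rng)
    end
    return x
end

x = ar1_chain(20_000, 0.9)
acf = autocorrelation(x, 50)
ess = crude_ess(x)
```

For this chain the theoretical ESS is about 20000 × 0.1 / 1.9 ≈ 1050, so 20000 highly autocorrelated draws carry roughly the information of a thousand independent ones. A slowly decaying autocorrelation plot signals the same effect in the Neural ODE chains.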

As another diagnostic, let's check the result on retrodicted data. To do this, we generate solutions of the Neural ODE on samples of the neural network parameters, and check the results of the predictions against the data. Let's start by looking at the time series:

pl = scatter(tsteps, ode_data[1, :], color = :red, label = "Data: Var1", xlabel = "t",
    title = "Spiral Neural ODE")
scatter!(tsteps, ode_data[2, :], color = :blue, label = "Data: Var2")
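# Overlay 300 trajectories, each computed from a randomly chosen posterior draw
# (one of columns 100 to 499 of samples)
for k in 1:300
    resol = predict_neuralode(samples[:, 100:end][:, rand(1:400)])
    plot!(tsteps, resol[1, :], alpha = 0.04, color = :red, label = "")
    plot!(tsteps, resol[2, :], alpha = 0.04, color = :blue, label = "")
end

# Highlight the draw with the smallest sum of squared errors as the "best fit"
losses = map(x -> loss_neuralode(x)[1], eachcol(samples))
idx = findmin(losses)[2]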
prediction = predict_neuralode(samples[:, idx])
plot!(tsteps, prediction[1, :], color = :black, w = 2, label = "")
plot!(tsteps, prediction[2, :], color = :black, w = 2, label = "Best fit prediction",
    ylims = (-2.5, 3.5))
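As an alternative to overlaying individual trajectories, the spread of the retrodictions can be summarized by pointwise quantiles over posterior draws. The helper credible_band below is a hypothetical sketch using the Statistics standard library; the 5% and 95% levels are an arbitrary choice, and the final comment only suggests how it might be fed from the tutorial's objects.

```julia
using Statistics

# preds[k, j]: one state variable at time index j under posterior draw k.
# Returns pointwise lower quantile, median and upper quantile over the draws.
function credible_band(preds::AbstractMatrix; lower = 0.05, upper = 0.95)
    lo = [quantile(view(preds, :, j), lower) for j in axes(preds, 2)]
    md = [median(view(preds, :, j)) for j in axes(preds, 2)]
    hi = [quantile(view(preds, :, j), upper) for j in axes(preds, 2)]
    return lo, md, hi
end

# Toy input: 101 "draws" taking the values 1, 2, ..., 101 at each of 3 times.
toy_preds = [Float64(k) for k in 1:101, j in 1:3]
lo, md, hi = credible_band(toy_preds)

# With the tutorial's objects (not run here), one could stack trajectories, e.g.
# preds = reduce(vcat, (predict_neuralode(samples[:, i])[1, :]' for i in 100:size(samples, 2)))
```

Plotting lo and hi as a ribbon around md gives a band whose width at each time reflects posterior uncertainty, which becomes especially informative when forecasting beyond the data.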

That showed the time series. We can similarly make a phase-space plot:

pl = scatter(ode_data[1, :], ode_data[2, :], color = :red, label = "Data", xlabel = "Var1",
    ylabel = "Var2", title = "Spiral Neural ODE")
for k in 1:300
    resol = predict_neuralode(samples[:, 100:end][:, rand(1:400)])
    plot!(resol[1, :], resol[2, :], alpha = 0.04, color = :red, label = "")
end
plot!(prediction[1, :], prediction[2, :], color = :black, w = 2,
    label = "Best fit prediction", ylims = (-2.5, 3))