Uncertainty Quantified Deep Bayesian Model Discovery
In this tutorial, we show how the SciML differential equation solvers compose seamlessly with Bayesian estimation libraries like AdvancedHMC.jl and Turing.jl. This lets us convert a Neural ODE into a Bayesian Neural ODE, which quantifies the uncertainty in the Neural ODE's estimation and forecasting. We work through a full example of a Bayesian Neural ODE sampled with NUTS.
Step 1: Import Libraries
For this example, we will need the following libraries:
# SciML Libraries
import SciMLSensitivity as SMS
import DifferentialEquations as DE
# ML Tools
import Lux
import Zygote
# External Tools
import Random
import Plots
import AdvancedHMC
import MCMCChains
import StatsPlots
import ComponentArrays
Setup: Get the data from the Spiral ODE example
We will also need data to fit against. As a demonstration, we generate our data from a simple cubic ODE, u' = A*u^3, as follows:
u0 = [2.0; 0.0]
datasize = 40
tspan = (0.0, 1.0)
tsteps = range(tspan[1], tspan[2], length = datasize)
function trueODEfunc(du, u, p, t)
true_A = [-0.1 2.0; -2.0 -0.1]
du .= ((u .^ 3)'true_A)'
end
prob_trueode = DE.ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(DE.solve(prob_trueode, DE.Tsit5(), saveat = tsteps))
2×40 Matrix{Float64}:
2.0 1.97895 1.94728 1.87998 1.74775 … 0.353996 0.53937 0.718119
0.0 0.403905 0.79233 1.15176 1.45561 -1.54217 -1.52816 -1.50614
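The adjoint trick in trueODEfunc may look cryptic: for a vector u, the row-vector product `((u .^ 3)' * A)'` is equivalent to multiplying `u .^ 3` by the transpose of A. A quick stdlib-only sanity check of that identity (the names A and u here are illustrative, reusing the matrix from trueODEfunc):

```julia
A = [-0.1 2.0; -2.0 -0.1]      # same matrix as true_A above
u = [2.0, 1.0]                 # an arbitrary state vector
lhs = ((u .^ 3)' * A)'         # the expression used in trueODEfunc
rhs = transpose(A) * (u .^ 3)  # equivalent matrix-vector product
@assert lhs ≈ rhs
```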
We will want to train a neural network to capture the dynamics that fit ode_data.
Step 2: Define the Neural ODE architecture.
Note that this step offers a lot of flexibility in the number of layers and the number of units per layer. A 100-unit architecture is not necessarily better at prediction/forecasting than a 50-unit one, while a more complicated architecture can take far more computation time without improving performance.
dudt2 = Lux.Chain(x -> x .^ 3,
Lux.Dense(2, 50, tanh),
Lux.Dense(50, 2))
rng = Random.default_rng()
p, st = Lux.setup(rng, dudt2)
const _st = st
function neuralodefunc(u, p, t)
dudt2(u, p, _st)[1]
end
function prob_neuralode(u0, p)
prob = DE.ODEProblem(neuralodefunc, u0, tspan, p)
sol = DE.solve(prob, DE.Tsit5(), saveat = tsteps)
end
p = ComponentArrays.ComponentArray{Float64}(p)
const _p = p
ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [0.43389758467674255 1.3316842317581177; -0.07050459831953049 1.0818043947219849; … ; 1.6888415813446045 0.23415780067443848; -0.24656350910663605 0.9033612608909607], bias = [-0.393450528383255, -0.32592958211898804, -0.2671773135662079, 0.3475736081600189, -0.692634642124176, 0.04984201863408089, 0.3450500965118408, 0.039619892835617065, 0.37911510467529297, 0.4979235529899597 … -0.6054325699806213, -0.21339572966098785, -0.1932043582201004, -0.694136917591095, -0.22318889200687408, -0.4025947153568268, 0.5187579989433289, 0.36455145478248596, -0.3977013826370239, 0.27677851915359497]), layer_3 = (weight = [-0.15822584927082062 0.012607942335307598 … 0.20132644474506378 -0.21170659363269806; -0.13463136553764343 0.017602255567908287 … -0.2327677309513092 -0.08211147040128708], bias = [0.12374095618724823, -0.07740290462970734]))
Note that the ComponentArray{Float64} conversion is required to put the Lux neural network parameters into Float64 precision.
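For intuition, here is the plain-array analogue of that conversion (p32 and p64 are illustrative names, not part of the tutorial): Lux initializes parameters as Float32, while the HMC machinery below works in Float64.

```julia
p32 = Float32[0.1, -0.2, 0.3]   # Lux-style Float32 initialization
p64 = Float64.(p32)             # element-wise promotion to Float64
@assert eltype(p64) == Float64
@assert p64 ≈ p32               # values unchanged up to Float32 precision
```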
Step 3: Define the loss function for the Neural ODE.
function predict_neuralode(p)
p = p isa ComponentArrays.ComponentArray ? p : convert(typeof(_p), p)
Array(prob_neuralode(u0, p))
end
function loss_neuralode(p)
pred = predict_neuralode(p)
loss = sum(abs2, ode_data .- pred)
return loss, pred
end
loss_neuralode (generic function with 1 method)
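The loss is a sum of squared errors: `sum(abs2, x)` computes Σᵢ xᵢ² without allocating an intermediate array, equivalent to `sum(x .^ 2)`. A tiny check (x is an illustrative vector):

```julia
x = [1.0, -2.0, 2.0]
@assert sum(abs2, x) == sum(x .^ 2) == 9.0   # 1 + 4 + 4
```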
Step 4: Now we integrate the Bayesian estimation workflow prescribed by the AdvancedHMC interface with the Neural ODE defined above
The AdvancedHMC interface requires us to specify (a) the Hamiltonian log density and its gradient, (b) the sampler, and (c) the step size adaptor function.
For the Hamiltonian log density, we use the loss function. The θ*θ term corresponds to Gaussian priors on the parameters.
The user can make several modifications to Step 4: different acceptance ratios, numbers of warmup samples, and numbers of posterior samples can all be tried. One can also use the Variational Inference (ADVI) framework, which doesn't work quite as well as NUTS, whereas the SGLD (Stochastic Gradient Langevin Dynamics) sampler has been observed to outperform NUTS. Have a look at https://sebastiancallh.github.io/post/langevin/ for a brief introduction to SGLD.
l(θ) = -sum(abs2, ode_data .- predict_neuralode(θ)) - sum(θ .* θ)
function dldθ(θ)
x, lambda = Zygote.pullback(l, θ)
grad = first(lambda(1))
return x, grad
end
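To see why the -sum(θ .* θ) term is a Gaussian prior: the log-density of Normal(0, σ) at θ is -θ²/(2σ²) minus a normalizing constant, so -θ² matches σ = 1/√2 up to that additive constant, which HMC ignores. A small stdlib-only check (logprior is an illustrative helper, not part of the tutorial):

```julia
σ = 1 / sqrt(2)
logprior(x) = -x^2 / (2σ^2) - log(σ * sqrt(2π))   # Normal(0, σ) log-pdf
θ = [0.3, -1.2, 0.7]
const_term = -length(θ) * log(σ * sqrt(2π))       # the additive constant
@assert sum(logprior, θ) - const_term ≈ -sum(θ .* θ)
```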
metric = AdvancedHMC.DiagEuclideanMetric(oneunit.(p))
h = AdvancedHMC.Hamiltonian(metric, l, dldθ)
Hamiltonian(metric=DiagEuclideanMetric((layer_1 = Float64[], layer ...]), kinetic=AdvancedHMC.GaussianKinetic())
We use the NUTS sampler with an acceptance ratio of δ = 0.45 in this example. In addition, we use Nesterov dual averaging for the step size adaptation.
We sample using 500 warmup samples and 500 posterior samples.
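To build intuition for the Leapfrog integrator configured below, here is a minimal sketch of a single leapfrog step on a 1D standard-normal target (everything here — the leapfrog and H functions, the numbers — is illustrative, not the AdvancedHMC implementation). The key property is approximate conservation of the Hamiltonian H(θ, r) = -logπ(θ) + r²/2:

```julia
# Target: logπ(θ) = -θ^2/2, so ∇logπ(θ) = -θ.
function leapfrog(θ, r, ϵ)
    r += ϵ / 2 * (-θ)   # half step on the momentum
    θ += ϵ * r          # full step on the position
    r += ϵ / 2 * (-θ)   # half step on the momentum
    return θ, r
end
H(θ, r) = θ^2 / 2 + r^2 / 2      # potential + kinetic energy
θ, r = 1.0, 0.5
θ′, r′ = leapfrog(θ, r, 0.1)
@assert abs(H(θ′, r′) - H(θ, r)) < 1e-3   # energy nearly conserved
```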
integrator = AdvancedHMC.Leapfrog(AdvancedHMC.find_good_stepsize(h, p))
kernel = AdvancedHMC.HMCKernel(AdvancedHMC.Trajectory{AdvancedHMC.MultinomialTS}(integrator, AdvancedHMC.GeneralisedNoUTurn()))
adaptor = AdvancedHMC.StanHMCAdaptor(AdvancedHMC.MassMatrixAdaptor(metric), AdvancedHMC.StepSizeAdaptor(0.45, integrator))
samples, stats = AdvancedHMC.sample(h, kernel, p, 500, adaptor, 500; progress = true)
(ComponentArrays.ComponentVector{Float64, Vector{Float64}, …}[(layer_1 = Float64[], layer_2 = (weight = [0.427173829766288 1.3792771679814837; … ], bias = […]), layer_3 = (weight = […], bias = [0.08175067535455613, -0.09713525542260891])), …], NamedTuple[(n_steps = 7, is_accept = true, acceptance_rate = 1.0, log_density = -403.8361174890106, hamiltonian_energy = 716.2907100110124, hamiltonian_energy_error = -69.01995025759959, max_hamiltonian_energy_error = -69.01995025759959, tree_depth = 3, numerical_error = false, step_size = 0.025, nom_step_size = 0.025, is_adapt = true), …])
Step 5: Plot diagnostics
Now let's make sure the fit is good. This can be done by looking at the chain mixing plot and the autocorrelation plot. First, let's create the chain mixing plot using the plot recipes from MCMCChains.jl (which hook into Plots via StatsPlots):
samples = hcat(samples...)
samples_reduced = samples[1:5, :]
samples_reshape = reshape(permutedims(samples_reduced), (500, 5, 1)) # draws × params × chains
Chain_Spiral = MCMCChains.Chains(samples_reshape)
Plots.plot(Chain_Spiral)
Now we check the autocorrelation plot:
MCMCChains.autocorplot(Chain_Spiral)
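The quantity autocorplot displays is the lag-k sample autocorrelation of each parameter's chain; well-mixed chains decay toward zero quickly as k grows. A stdlib-only sketch of the computation (autocor_lag is an illustrative helper, not the MCMCChains implementation):

```julia
function autocor_lag(x, k)
    n = length(x)
    μ = sum(x) / n
    c0 = sum((x .- μ) .^ 2) / n                                # lag-0 autocovariance
    ck = sum((x[1:(n - k)] .- μ) .* (x[(1 + k):n] .- μ)) / n   # lag-k autocovariance
    return ck / c0
end
x = [1.0, 2.0, 3.0, 4.0, 5.0]
@assert autocor_lag(x, 0) == 1.0   # any series has autocorrelation 1 at lag 0
@assert autocor_lag(x, 1) > 0      # a trending series stays positively correlated
```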
As another diagnostic, let's check the result on retrodicted data. To do this, we generate solutions of the Neural ODE on samples of the neural network parameters, and check the results of the predictions against the data. Let's start by looking at the time series:
pl = Plots.scatter(tsteps, ode_data[1, :], color = :red, label = "Data: Var1", xlabel = "t",
title = "Spiral Neural ODE")
Plots.scatter!(tsteps, ode_data[2, :], color = :blue, label = "Data: Var2")
for k in 1:300
resol = predict_neuralode(samples[:, 100:end][:, rand(1:400)])
Plots.plot!(tsteps, resol[1, :], alpha = 0.04, color = :red, label = "")
Plots.plot!(tsteps, resol[2, :], alpha = 0.04, color = :blue, label = "")
end
losses = map(x -> loss_neuralode(x)[1], eachcol(samples))
idx = findmin(losses)[2]
prediction = predict_neuralode(samples[:, idx])
Plots.plot!(tsteps, prediction[1, :], color = :black, w = 2, label = "")
Plots.plot!(tsteps, prediction[2, :], color = :black, w = 2, label = "Best fit prediction",
ylims = (-2.5, 3.5))
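The best-fit curve above is selected with findmin, which returns the minimum value together with its index, so we retrodict with the parameter sample attaining the lowest loss. For reference (losses here is a toy vector, not the tutorial's):

```julia
losses = [3.2, 1.1, 2.4]
val, idx = findmin(losses)   # returns (minimum, index of minimum)
@assert (val, idx) == (1.1, 2)
```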
That showed the time series form. We can similarly do a phase-space plot:
pl = Plots.scatter(ode_data[1, :], ode_data[2, :], color = :red, label = "Data", xlabel = "Var1",
ylabel = "Var2", title = "Spiral Neural ODE")
for k in 1:300
resol = predict_neuralode(samples[:, 100:end][:, rand(1:400)])
Plots.plot!(resol[1, :], resol[2, :], alpha = 0.04, color = :red, label = "")
end
Plots.plot!(prediction[1, :], prediction[2, :], color = :black, w = 2,
label = "Best fit prediction", ylims = (-2.5, 3))