Uncertainty Quantified Deep Bayesian Model Discovery

In this tutorial, we show how the SciML ecosystem combines differential equation solvers seamlessly with Bayesian estimation libraries like AdvancedHMC.jl and Turing.jl. This allows converting Neural ODEs into Bayesian Neural ODEs, which lets us quantify the uncertainty in the Neural ODE's estimation and forecasting. Here we show a working example of a Bayesian Neural ODE using the NUTS sampler.

Note

For more details, have a look at this paper: https://arxiv.org/abs/2012.07244

Step 1: Import Libraries

For this example, we will need the following libraries:

# SciML Libraries
using DiffEqFlux, Flux, DifferentialEquations

# External Tools
using Random, Plots, AdvancedHMC, MCMCChains, StatsPlots, ComponentArrays

Setup: Get the data from the Spiral ODE example

We will also need data to fit against. As a demonstration, we will generate our data from a simple cubic ODE, u' = A*u^3, as follows:

u0 = [2.0; 0.0]
datasize = 40
tspan = (0.0, 1.0)
tsteps = range(tspan[1], tspan[2], length = datasize)
function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u .^ 3)'true_A)'
end
prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps))
2×40 Matrix{Float64}:
 2.0  1.97895   1.94728  1.87998  1.74775  …   0.353996   0.53937   0.718119
 0.0  0.403905  0.79233  1.15176  1.45561     -1.54217   -1.52816  -1.50614

We will want to train a neural network to capture the dynamics that fit ode_data.

Step 2: Define the Neural ODE architecture.

Note that this step offers a lot of flexibility in the number of layers and the number of units per layer. A 100-unit architecture is not necessarily better at prediction/forecasting than a 50-unit architecture; on the other hand, an over-complicated architecture can take substantial computational time without improving performance.

dudt2 = Flux.Chain(x -> x .^ 3,
                   Flux.Dense(2, 50, tanh),
                   Flux.Dense(50, 2)) |> f64
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps)
rng = Random.default_rng()
p = Float64.(prob_neuralode.p)
252-element Vector{Float64}:
  0.2954992353916168
  0.13152720034122467
 -0.2804926335811615
 -0.29425305128097534
  0.020911309868097305
  0.2705550789833069
 -0.033966172486543655
  0.2652934491634369
  0.05473713204264641
 -0.1368948072195053
  ⋮
 -0.25859174132347107
  0.22791129350662231
 -0.12906330823898315
  0.19743387401103973
 -0.1093311607837677
  0.3094402551651001
  0.13257285952568054
  0.0
  0.0

Note that f64 is required to put the Flux neural network into Float64 precision.
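To see what f64 does, here is a minimal standalone sketch (only Flux assumed, separate from the tutorial's model) checking the parameter element type before and after the conversion:

```julia
using Flux

# Flux initializes weights in Float32 by default.
m32 = Flux.Chain(Flux.Dense(2, 4, tanh), Flux.Dense(4, 2))

# `f64` re-types every parameter array of the chain to Float64.
m64 = m32 |> f64

eltype(first(Flux.params(m32)))  # Float32
eltype(first(Flux.params(m64)))  # Float64
```

Solving the ODE and adapting the HMC step size in Float64 avoids the precision mismatch that would otherwise arise from Float32 network parameters.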

Step 3: Define the loss function for the Neural ODE.

function predict_neuralode(p)
    Array(prob_neuralode(u0, p))
end
function loss_neuralode(p)
    pred = predict_neuralode(p)
    loss = sum(abs2, ode_data .- pred)
    return loss, pred
end
loss_neuralode (generic function with 1 method)

Step 4: Integrate the Bayesian estimation workflow, as prescribed by the AdvancedHMC interface, with the NeuralODE defined above.

The AdvancedHMC interface requires us to specify: (a) the Hamiltonian log density and its gradient, (b) the sampler, and (c) the step size adaptor function.

For the Hamiltonian log density, we use the negative loss function. The -sum(θ .* θ) term corresponds to a Gaussian prior on the parameters.
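As a sanity check on that prior interpretation: -sum(θ .* θ) matches the log-pdf of an isotropic Normal(0, sqrt(1/2)) prior up to an additive constant, so both induce the same posterior. A standalone sketch (assuming Distributions.jl is available, which this tutorial does not otherwise load):

```julia
using Distributions

θ = randn(5)

# Log-density of an isotropic Gaussian prior with variance 1/2 per component.
lp = sum(logpdf.(Normal(0.0, sqrt(0.5)), θ))

# It differs from -sum(θ .* θ) only by a θ-independent normalizing constant:
# 5 * (-0.5 * log(π)) for a 5-dimensional θ.
lp - (-sum(θ .* θ))
```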

The user can make several modifications to Step 4, such as trying different acceptance ratios, warmup samples, and posterior samples. One can also use the Variational Inference (ADVI) framework, which doesn't work quite as well as NUTS, while the SGLD (Stochastic Gradient Langevin Dynamics) sampler has been observed to perform better than NUTS. Have a look at https://sebastiancallh.github.io/post/langevin/ for a brief introduction to SGLD.
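To give a flavor of the SGLD alternative mentioned above, here is a minimal hand-rolled Langevin sampler on a toy 1-D standard-normal log density (a sketch only; the names sgld and logp_grad are hypothetical helpers, not part of this tutorial's NUTS workflow):

```julia
using Random

# Gradient of the log density of N(0, 1), up to an additive constant.
logp_grad(θ) = -θ

# One SGLD chain: a half-step along the log-density gradient plus
# Gaussian noise whose variance equals the step size.
function sgld(grad, θ0; steps = 5_000, ϵ = 1e-2, rng = Random.default_rng())
    θ = θ0
    chain = Vector{Float64}(undef, steps)
    for i in 1:steps
        θ += (ϵ / 2) * grad(θ) + sqrt(ϵ) * randn(rng)
        chain[i] = θ
    end
    return chain
end

chain = sgld(logp_grad, 5.0)
```

After a burn-in period, the chain samples approximately from N(0, 1). In the Neural ODE setting, grad would be the gradient of l(θ), evaluated on minibatches of the data for the "stochastic gradient" part.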

l(θ) = -sum(abs2, ode_data .- predict_neuralode(θ)) - sum(θ .* θ)
function dldθ(θ)
    x, lambda = Flux.Zygote.pullback(l, θ)
    grad = first(lambda(1))
    return x, grad
end

metric = DiagEuclideanMetric(length(p))
h = Hamiltonian(metric, l, dldθ)
Hamiltonian(metric=DiagEuclideanMetric([1.0, 1.0, 1.0, 1.0, 1.0, 1 ...]), kinetic=AdvancedHMC.GaussianKinetic())

We use the NUTS sampler with a target acceptance ratio of δ = 0.45 in this example. In addition, we use Nesterov Dual Averaging for step size adaptation.

We sample using 500 warmup samples and 500 posterior samples.

integrator = Leapfrog(find_good_stepsize(h, p))
kernel = HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn()))
adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), StepSizeAdaptor(0.45, integrator))
samples, stats = sample(h, kernel, p, 500, adaptor, 500; progress = true)
([[0.25882908909502755, 0.06140805102900124, -0.2817781084634905, -0.28977703630424095, 0.07109633687993228, …], …], NamedTuple[(n_steps = 27, is_accept = true, acceptance_rate = 0.9259259259259259, log_density = -223.87658038566735, hamiltonian_energy = 455.91586793085946, hamiltonian_energy_error = -33.70819401002393, max_hamiltonian_energy_error = 1011.8232653751784, tree_depth = 4, numerical_error = true, step_size = 0.025, nom_step_size = 0.025, is_adapt = true), …])

Step 5: Plot diagnostics

Now let's make sure the fit is good. This can be done by looking at the chain mixing plot and the autocorrelation plot. First, let's create the chain mixing plot using the plot recipes from MCMCChains.jl and StatsPlots.jl:

samples = hcat(samples...)
samples_reduced = samples[1:5, :]
samples_reshape = reshape(samples_reduced, (500, 5, 1))
Chain_Spiral = Chains(samples_reshape)
plot(Chain_Spiral)

Now we check the autocorrelation plot:

autocorplot(Chain_Spiral)

As another diagnostic, let's check the result on retrodicted data. To do this, we generate solutions of the Neural ODE on samples of the neural network parameters, and check the results of the predictions against the data. Let's start by looking at the time series:

pl = scatter(tsteps, ode_data[1, :], color = :red, label = "Data: Var1", xlabel = "t",
             title = "Spiral Neural ODE")
scatter!(tsteps, ode_data[2, :], color = :blue, label = "Data: Var2")
for k in 1:300
    resol = predict_neuralode(samples[:, 100:end][:, rand(1:400)])
    plot!(tsteps, resol[1, :], alpha = 0.04, color = :red, label = "")
    plot!(tsteps, resol[2, :], alpha = 0.04, color = :blue, label = "")
end

losses = map(x -> loss_neuralode(x)[1], eachcol(samples))
idx = findmin(losses)[2]
prediction = predict_neuralode(samples[:, idx])
plot!(tsteps, prediction[1, :], color = :black, w = 2, label = "")
plot!(tsteps, prediction[2, :], color = :black, w = 2, label = "Best fit prediction",
      ylims = (-2.5, 3.5))

That showed the time series form. We can similarly do a phase-space plot:

pl = scatter(ode_data[1, :], ode_data[2, :], color = :red, label = "Data", xlabel = "Var1",
             ylabel = "Var2", title = "Spiral Neural ODE")
for k in 1:300
    resol = predict_neuralode(samples[:, 100:end][:, rand(1:400)])
    plot!(resol[1, :], resol[2, :], alpha = 0.04, color = :red, label = "")
end
plot!(prediction[1, :], prediction[2, :], color = :black, w = 2,
      label = "Best fit prediction", ylims = (-2.5, 3))