Automatically Discover Missing Physics by Embedding Machine Learning into Differential Equations

One of the most time-consuming parts of modeling is building the model. How do you know when your model is correct? When you solve an inverse problem to calibrate your model to data, who you gonna call if there are no parameters that make the model fit the data? This is the problem that the Universal Differential Equation (UDE) approach solves: the ability to start from the model you have and suggest minimal mechanistic extensions that would allow the model to fit the data. In this showcase, we will show how to take a partially correct model and auto-complete it to find the missing physics.

Note

For a scientific background on the universal differential equation approach, check out Universal Differential Equations for Scientific Machine Learning

Starting Point: The Packages To Use

There are many packages which are used as part of this showcase. Let's detail what they are and how they are used. For the neural network training:

Module                                         Description
OrdinaryDiffEq.jl (DifferentialEquations.jl)   The numerical differential equation solvers
SciMLSensitivity.jl                            The adjoint methods that define gradients of ODE solvers
Optimization.jl                                The optimization library
OptimizationOptimisers.jl                      The optimization solver package with Adam
OptimizationOptimJL.jl                         The optimization solver package with BFGS
Zygote.jl                                      The automatic differentiation library for fast gradients

For the symbolic model discovery:

Module                 Description
ModelingToolkit.jl     The symbolic modeling environment
DataDrivenDiffEq.jl    The symbolic regression interface
DataDrivenSparse.jl    The sparse regression solvers for symbolic regression

Julia standard libraries:

Module          Description
LinearAlgebra   Required for the norm function
Statistics      Required for the mean function

And external libraries:

Module               Description
Lux.jl               The deep learning (neural network) framework
ComponentArrays.jl   For the ComponentArray type to match Lux to SciML
Plots.jl             The plotting and visualization library
StableRNGs.jl        Stable random seeding

Note

The deep learning framework Flux.jl could be used in place of Lux, though most tutorials in SciML generally prefer Lux.jl due to its explicit parameter interface, leading to nicer code. Both share the same internal implementations of core kernels, and thus have very similar feature support and performance.

# SciML Tools
using OrdinaryDiffEq, ModelingToolkit, DataDrivenDiffEq, SciMLSensitivity, DataDrivenSparse
using Optimization, OptimizationOptimisers, OptimizationOptimJL

# Standard Libraries
using LinearAlgebra, Statistics

# External Libraries
using ComponentArrays, Lux, Zygote, Plots, StableRNGs
gr()

# Set a random seed for reproducible behaviour
rng = StableRNG(1111)
StableRNGs.LehmerRNG(state=0x000000000000000000000000000008af)

Problem Setup

In order to know that we have automatically discovered the correct model, we will use generated data from a known model. This model will be the Lotka-Volterra equations. These equations are given by:

\[\begin{aligned} \frac{dx}{dt} &= \alpha x - \beta x y \\ \frac{dy}{dt} &= -\delta y + \gamma x y \\ \end{aligned}\]

This is a model of rabbits and wolves. $\alpha x$ is the exponential growth of rabbits in isolation, $-\beta x y$ and $\gamma x y$ are the interaction effects of wolves eating rabbits, and $-\delta y$ is the term for how wolves die hungry in isolation.

Now assume that we have never seen rabbits and wolves in the same room. We only know the two effects $\alpha x$ and $-\delta y$. Can we use Scientific Machine Learning to automatically discover an extension to what we already know? That is what we will solve with the universal differential equation.
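Written out, the hybrid model we are about to build keeps the two known terms and lets a neural network U(x, y; θ) with weights θ absorb everything else. If training succeeds, U should approximate the interaction terms of the true model:

\[\begin{aligned} \frac{dx}{dt} &= \alpha x + U_1(x, y; \theta) \\ \frac{dy}{dt} &= -\delta y + U_2(x, y; \theta) \\ \end{aligned} \qquad U(x, y; \theta) \approx \begin{pmatrix} -\beta x y \\ \gamma x y \end{pmatrix}\]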

Generating the Training Data

First, let's generate training data from the Lotka-Volterra equations. This is straightforward and standard DifferentialEquations.jl usage. Our sample data is thus generated as follows:

function lotka!(du, u, p, t)
    α, β, γ, δ = p
    du[1] = α * u[1] - β * u[2] * u[1]
    du[2] = γ * u[1] * u[2] - δ * u[2]
end

# Define the experimental parameters
tspan = (0.0, 5.0)
u0 = 5.0f0 * rand(rng, 2)
p_ = [1.3, 0.9, 0.8, 1.8]
prob = ODEProblem(lotka!, u0, tspan, p_)
solution = solve(prob, Vern7(), abstol = 1e-12, reltol = 1e-12, saveat = 0.25)

# Add noise in terms of the mean
X = Array(solution)
t = solution.t

x̄ = mean(X, dims = 2)
noise_magnitude = 5e-3
Xₙ = X .+ (noise_magnitude * x̄) .* randn(rng, eltype(X), size(X))

plot(solution, alpha = 0.75, color = :black, label = ["True Data" nothing])
scatter!(t, transpose(Xₙ), color = :red, label = ["Noisy Data" nothing])
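To make the data-generation step concrete without any packages, here is a minimal sketch that assumes a hand-written fixed-step classical RK4 integrator in place of Vern7 and the stdlib `MersenneTwister` in place of `StableRNG`, so its numbers will not match the tutorial's. It samples every 0.25 time units on [0, 5] and adds noise proportional to the mean of each state, as above. The helper names (`lotka_rhs`, `rk4_trajectory`) are hypothetical.

```julia
using Random, Statistics

# Right-hand side of the Lotka-Volterra system (same form as lotka! above)
function lotka_rhs(u, p)
    α, β, γ, δ = p
    return [α * u[1] - β * u[1] * u[2],
            γ * u[1] * u[2] - δ * u[2]]
end

# Classical fixed-step RK4; keeps the initial state and every `save_every`-th step
function rk4_trajectory(f, u0, p, dt, nsteps; save_every = 1)
    u = copy(u0)
    saved = [copy(u)]
    for n in 1:nsteps
        k1 = f(u, p)
        k2 = f(u .+ dt / 2 .* k1, p)
        k3 = f(u .+ dt / 2 .* k2, p)
        k4 = f(u .+ dt .* k3, p)
        u = u .+ dt / 6 .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
        n % save_every == 0 && push!(saved, copy(u))
    end
    return reduce(hcat, saved)  # states × samples matrix
end

rng_demo = MersenneTwister(1111)   # stdlib stand-in for StableRNG(1111)
u0_demo = 5.0 * rand(rng_demo, 2)
p_demo = [1.3, 0.9, 0.8, 1.8]

# dt = 0.005 and saving every 50 steps gives samples every 0.25 on [0, 5]
X_demo = rk4_trajectory(lotka_rhs, u0_demo, p_demo, 0.005, 1000; save_every = 50)

# Noise scaled by the mean of each state, as in the tutorial
x̄_demo = mean(X_demo, dims = 2)
Xₙ_demo = X_demo .+ (5e-3 .* x̄_demo) .* randn(rng_demo, Float64, size(X_demo))
```

A fixed-step integrator keeps the sketch short; the tutorial's adaptive Vern7 solve with tight tolerances is the better choice when the generated data should be essentially exact.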

Definition of the Universal Differential Equation

Now let's define our UDE. We will use Lux.jl to define the neural network as follows:

rbf(x) = exp.(-(x .^ 2))

# Multilayer FeedForward
const U = Lux.Chain(Lux.Dense(2, 5, rbf), Lux.Dense(5, 5, rbf), Lux.Dense(5, 5, rbf),
              Lux.Dense(5, 2))
# Get the initial parameters and state variables of the model
p, st = Lux.setup(rng, U)
const _st = st
(layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple())

We then define the UDE as a dynamical system of the form u' = known(u) + NN(u):

# Define the hybrid model
function ude_dynamics!(du, u, p, t, p_true)
    û = U(u, p, _st)[1] # Network prediction
    du[1] = p_true[1] * u[1] + û[1]
    du[2] = -p_true[4] * u[2] + û[2]
end

# Closure with the known parameter
nn_dynamics!(du, u, p, t) = ude_dynamics!(du, u, p, t, p_)
# Define the problem
prob_nn = ODEProblem(nn_dynamics!, Xₙ[:, 1], tspan, p)
ODEProblem with uType Vector{Float64} and tType Float64. In-place: true
timespan: (0.0, 5.0)
u0: 2-element Vector{Float64}:
 3.1463924566781167
 1.5423300037202512

Notice that the most important part of this is that the neural network does not have hard-coded weights. The weights of the neural network are the parameters of the ODE system. This means that if we change the parameters of the ODE system, then we will have updated the internal neural networks to new weights. Keep that in mind for the next part.
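The same idea can be seen without Lux. In the sketch below, a hypothetical 2 → 5 → 2 network (`tiny_nn`, not the Lux chain above) unpacks all of its weights from one flat vector θ, and the hybrid right-hand side passes θ through exactly as the ODE parameters would be passed. Changing θ changes the vector field at the same state.

```julia
using Random

rbf_scalar(x) = exp(-x^2)   # scalar version of the Gaussian activation `rbf` above

const H = 5                 # hidden width of this hypothetical network
const NPARAMS = 5H + 2      # 2H (W1) + H (b1) + 2H (W2) + 2 (b2) = 27

# Hypothetical 2 → H → 2 network whose weights are unpacked from a flat vector θ
function tiny_nn(u, θ)
    W1 = reshape(θ[1:2H], H, 2)
    b1 = θ[(2H + 1):(3H)]
    W2 = reshape(θ[(3H + 1):(5H)], 2, H)
    b2 = θ[(5H + 1):(5H + 2)]
    return W2 * rbf_scalar.(W1 * u .+ b1) .+ b2
end

# Hybrid right-hand side: known terms α x and -δ y plus the network correction
function hybrid_rhs!(du, u, θ, t; α = 1.3, δ = 1.8)
    û = tiny_nn(u, θ)
    du[1] = α * u[1] + û[1]
    du[2] = -δ * u[2] + û[2]
    return du
end

# Same state, two different weight vectors → two different vector fields
θa = randn(MersenneTwister(0), NPARAMS)
θb = θa .+ 0.1
du_a = hybrid_rhs!(zeros(2), [1.0, 2.0], θa, 0.0)
du_b = hybrid_rhs!(zeros(2), [1.0, 2.0], θb, 0.0)
```

With all weights set to zero the network outputs zero and only the known physics remains, which is a handy sanity check when wiring up a UDE.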

... and tada: now we have a neural network integrated into our dynamical system!

Note

Even if the known physics is only approximately correct, it can help improve the fitting process! Check out this JuliaCon talk, which dives into this issue.

Setting Up the Training Loop

Now let's build a training loop around our UDE. First, let's make a function predict which runs our simulation at new neural network weights. Recall that the weights of the neural network are the parameters of the ODE, so what we need to do in predict is update our ODE to our new parameters and then run it.

For this update step, we will use the remake function from the SciMLProblem interface. remake works by specifying key = value pairs to update in the problem fields. Thus to update u0, we would add a keyword argument u0 = .... To update the parameters, we'd do p = .... The field names can be acquired from the problem documentation (or the docstrings!).
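To illustrate the keyword-override behaviour in isolation, here is a toy analogue. `ToyProblem` and `toy_remake` are hypothetical and only mimic the idea that every field not named in the keywords is carried over unchanged; they are not SciMLBase's implementation of `remake`.

```julia
# Hypothetical stand-in for an ODEProblem: only the fields `predict` touches
struct ToyProblem{U, T, P}
    u0::U
    tspan::T
    p::P
end

# Rebuild the problem, overriding only the fields passed as keywords
function toy_remake(prob::ToyProblem; u0 = prob.u0, tspan = prob.tspan, p = prob.p)
    return ToyProblem(u0, tspan, p)
end

prob_toy = ToyProblem([1.0, 2.0], (0.0, 5.0), [0.1, 0.2])
prob_new = toy_remake(prob_toy, p = [0.3, 0.4])   # only the parameters change
```

Because a new object is returned, the original problem is left untouched, which is what lets `predict` be called repeatedly with different θ.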

Knowing this, our predict function looks like:

function predict(θ, X = Xₙ[:, 1], T = t)
    _prob = remake(prob_nn, u0 = X, tspan = (T[1], T[end]), p = θ)
    Array(solve(_prob, Vern7(), saveat = T,
                abstol = 1e-6, reltol = 1e-6,
                sensealg=QuadratureAdjoint(autojacvec=ReverseDiffVJP(true))))
end
predict (generic function with 3 methods)

There are many choices for the combination of sensitivity algorithm and automatic differentiation library (see Choosing a Sensitivity Algorithm). For example, you could have used sensealg=ForwardDiffSensitivity().

Now, for our loss function, we solve the ODE at our new parameters and compute the mean squared (L2) error against the dataset. Using our predict function, this looks like:

function loss(θ)
    X̂ = predict(θ)
    mean(abs2, Xₙ .- X̂)
end
loss (generic function with 1 method)
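Written as a formula, with N saved time points (N = 21 for saveat = 0.25 on [0, 5]) and X̂(θ) the UDE solution at weights θ, the loss averages the squared error over both states and all time points:

\[ \mathcal{L}(\theta) = \frac{1}{2N} \sum_{i=1}^{2} \sum_{j=1}^{N} \left( (X_n)_{ij} - \hat{X}_{ij}(\theta) \right)^2 \]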

Lastly, to track our optimization we define a callback, as specified by the solve interface of OptimizationProblem. Because our loss function returns only one value, the loss l, the callback is a function of the current parameters θ and l. Let's set up a callback that prints the current loss every 50 steps:

losses = Float64[]

callback = function (p, l)
    push!(losses, l)
    if length(losses) % 50 == 0
        println("Current loss after $(length(losses)) iterations: $(losses[end])")
    end
    return false
end
#1 (generic function with 1 method)
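The `return false` at the end matters: Optimization.jl treats a callback that returns `true` as a request to halt. The sketch below shows that convention in a hypothetical plain gradient-descent loop (`toy_minimize` is made up for illustration and is not Optimization.jl's solver), with a callback that records the loss and stops once it is tiny.

```julia
# Hypothetical minimal optimizer loop following the same callback convention:
# callback(θ, loss) is called after each step; returning `true` halts the loop.
function toy_minimize(loss, grad, θ0; η = 0.1, maxiters = 1000,
                      callback = (θ, l) -> false)
    θ = copy(θ0)
    for _ in 1:maxiters
        θ .-= η .* grad(θ)              # plain gradient-descent step
        callback(θ, loss(θ)) && break   # stop early if the callback says so
    end
    return θ
end

quad_loss(θ) = sum(abs2, θ .- [1.0, -2.0])   # minimum at [1, -2]
quad_grad(θ) = 2 .* (θ .- [1.0, -2.0])

toy_losses = Float64[]
stop_when_small = function (θ, l)
    push!(toy_losses, l)
    return l < 1e-8   # halt once the loss is tiny
end

θ_opt = toy_minimize(quad_loss, quad_grad, [0.0, 0.0]; callback = stop_when_small)
```

Here each step shrinks the error by a factor of 0.8, so the loss drops below 1e-8 after 45 steps and the loop exits long before `maxiters`.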

Training

Now we're ready to train! To run the training process, we need to build an OptimizationProblem. Because we have many parameters, we will use reverse-mode automatic differentiation via Zygote.jl. Optimization.jl makes the choice of automatic differentiation easy: we just specify an adtype when constructing the OptimizationFunction.

Knowing this, we can build our OptimizationProblem as follows:

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ComponentVector{Float64}(p))
OptimizationProblem. In-place: true
u0: ComponentVector{Float64}(layer_1 = (weight = [0.49426865577697754 0.5692564249038696; 0.40171918272972107 -0.8665286302566528; … ; 0.47097498178482056 -0.7521204352378845; -0.20216092467308044 -0.3197280168533325], bias = [0.0; 0.0; … ; 0.0; 0.0;;]), layer_2 = (weight = [0.6822634935379028 -0.6952740550041199 … 0.5011160969734192 0.24313241243362427; -0.39863723516464233 -0.17176461219787598 … -0.6159946322441101 0.18968746066093445; … ; -0.7200708389282227 -0.6787673234939575 … -0.5633968114852905 0.1658746749162674; 0.0014851824380457401 -0.10373303294181824 … 0.09008530527353287 -0.043933477252721786], bias = [0.0; 0.0; … ; 0.0; 0.0;;]), layer_3 = (weight = [0.0011237671133130789 0.006483868230134249 … 0.2754976451396942 -0.2874394953250885; 0.04383227229118347 -0.32253962755203247 … 0.09472294896841049 -0.4210013747215271; … ; -0.5179172158241272 -0.6043259501457214 … -0.18625909090042114 0.06577149033546448; -0.2150842249393463 0.2565661072731018 … 0.5849692821502686 0.2193499207496643], bias = [0.0; 0.0; … ; 0.0; 0.0;;]), layer_4 = (weight = [0.7144760489463806 0.4398183226585388 … -0.8286120891571045 0.04256918653845787; -0.5670844316482544 -0.3962741494178772 … 0.1667986661195755 0.8446723818778992], bias = [0.0; 0.0;;]))

Now... we optimize it. We will use a mixed strategy. First, let's do some iterations of ADAM because it's better at finding a good general area of parameter space, but then we will move to BFGS, which will quickly home in on a local minimum. Note that if we only use ADAM it will take a ton of iterations, and if we only use BFGS we normally end up in a bad local minimum, so this combination tends to be a good one for UDEs.

Thus we first solve the optimization problem with ADAM. Here we call ADAM() with its default learning rate; a different rate can be passed as ADAM(η), ideally tuned to be as high as possible without making the loss shoot up. We see:

res1 = Optimization.solve(optprob, ADAM(), callback = callback, maxiters = 5000)
println("Training loss after $(length(losses)) iterations: $(losses[end])")
Current loss after 50 iterations: 240907.39527636164
Current loss after 100 iterations: 171079.08407050956
Current loss after 150 iterations: 130867.55479750884
Current loss after 200 iterations: 108175.78322678556
Current loss after 250 iterations: 92683.0202358074
Current loss after 300 iterations: 80367.46213206719
Current loss after 350 iterations: 70059.52980713078
Current loss after 400 iterations: 61257.66807379558
Current loss after 450 iterations: 53651.070915226555
Current loss after 500 iterations: 47008.052552430934
Current loss after 550 iterations: 41158.82453863906
Current loss after 600 iterations: 35981.33890570629
Current loss after 650 iterations: 31385.755136724416
Current loss after 700 iterations: 27302.94152308521
Current loss after 750 iterations: 23677.217134984945
Current loss after 800 iterations: 20461.938111980275
Current loss after 850 iterations: 17616.78826869674
Current loss after 900 iterations: 15106.083338565531
Current loss after 950 iterations: 12897.68572817918
Current loss after 1000 iterations: 10962.29719320891
Current loss after 1050 iterations: 9272.991987918193
Current loss after 1100 iterations: 7804.903750333238
Current loss after 1150 iterations: 6535.01424387241
Current loss after 1200 iterations: 5442.008180921293
Current loss after 1250 iterations: 4506.17085728281
Current loss after 1300 iterations: 3709.3127759741087
Current loss after 1350 iterations: 3034.7097241543042
Current loss after 1400 iterations: 2467.0510196263317
Current loss after 1450 iterations: 1992.3899982085038
Current loss after 1500 iterations: 1598.0933737822477
Current loss after 1550 iterations: 1272.7869188266336
Current loss after 1600 iterations: 1006.2957398836103
Current loss after 1650 iterations: 789.5787601967985
Current loss after 1700 iterations: 614.6596742689085
Current loss after 1750 iterations: 474.56147626679444
Current loss after 1800 iterations: 363.2457929218382
Current loss after 1850 iterations: 275.54067813483533
Current loss after 1900 iterations: 207.0510286768174
Current loss after 1950 iterations: 154.06804715043148
Current loss after 2000 iterations: 113.48626198065796
Current loss after 2050 iterations: 82.7272847925857
Current loss after 2100 iterations: 59.66963998279686
Current loss after 2150 iterations: 42.58462732053778
Current loss after 2200 iterations: 30.078200141473843
Current loss after 2250 iterations: 21.038747645775523
Current loss after 2300 iterations: 14.590593523979056
Current loss after 2350 iterations: 10.052932968129664
Current loss after 2400 iterations: 6.903877023397368
Current loss after 2450 iterations: 4.749307883135153
Current loss after 2500 iterations: 3.296259609434076
Current loss after 2550 iterations: 2.33059473438387
Current loss after 2600 iterations: 1.698538434362526
Current loss after 2650 iterations: 1.2913041470380844
Current loss after 2700 iterations: 1.0329730649060505
Current loss after 2750 iterations: 0.871489295863934
Current loss after 2800 iterations: 0.7719550879241281
Current loss after 2850 iterations: 0.711458659118542
Current loss after 2900 iterations: 0.6752072380011963
Current loss after 2950 iterations: 0.653793520840879
Current loss after 3000 iterations: 0.6413244058263694
Current loss after 3050 iterations: 0.6341643539423514
Current loss after 3100 iterations: 0.6301054736206746
Current loss after 3150 iterations: 0.6278283341059617
Current loss after 3200 iterations: 0.6265574060455696
Current loss after 3250 iterations: 0.6258445582231248
Current loss after 3300 iterations: 0.6254353473179832
Current loss after 3350 iterations: 0.625187889095077
Current loss after 3400 iterations: 0.6250245202328235
Current loss after 3450 iterations: 0.6249035310811221
Current loss after 3500 iterations: 0.6248029509405473
Current loss after 3550 iterations: 0.6247114298186817
Current loss after 3600 iterations: 0.6246232147405189
Current loss after 3650 iterations: 0.624535439713398
Current loss after 3700 iterations: 0.6244466952063696
Current loss after 3750 iterations: 0.6243562897318248
Current loss after 3800 iterations: 0.624263877350417
Current loss after 3850 iterations: 0.6241692741588646
Current loss after 3900 iterations: 0.6240723700527201
Current loss after 3950 iterations: 0.6239730873477777
Current loss after 4000 iterations: 0.623871361864226
Current loss after 4050 iterations: 0.6237671345016246
Current loss after 4100 iterations: 0.6236603475828046
Current loss after 4150 iterations: 0.6235509433079527
Current loss after 4200 iterations: 0.6234388631147162
Current loss after 4250 iterations: 0.6233240474156136
Current loss after 4300 iterations: 0.6232064354895591
Current loss after 4350 iterations: 0.6230859654343467
Current loss after 4400 iterations: 0.6229625741402467
Current loss after 4450 iterations: 0.6228361972764498
Current loss after 4500 iterations: 0.6227067692761545
Current loss after 4550 iterations: 0.6225742233270604
Current loss after 4600 iterations: 0.6224384913609643
Current loss after 4650 iterations: 0.6222995040439978
Current loss after 4700 iterations: 0.6221571907692574
Current loss after 4750 iterations: 0.6220114796493855
Current loss after 4800 iterations: 0.6218622975081145
Current loss after 4850 iterations: 0.6217095698756032
Current loss after 4900 iterations: 0.62155322098354
Current loss after 4950 iterations: 0.6213931737619987
Current loss after 5000 iterations: 0.6212293498360424
Training loss after 5001 iterations: 0.6212293498360424
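For reference, ADAM keeps exponential moving averages of the gradient and of its square and scales each coordinate's step by them. The sketch below implements the standard textbook update with commonly used defaults (β₁ = 0.9, β₂ = 0.999, ϵ = 1e-8); `AdamState` and `adam_step!` are hypothetical names, and Optimisers.jl's implementation may differ in details such as where ϵ enters. The usage example minimises a badly scaled quadratic, where the per-coordinate normalisation lets both coordinates progress at a similar pace.

```julia
# Minimal Adam optimizer on a plain parameter vector (standard update rule)
mutable struct AdamState
    m::Vector{Float64}   # first-moment (mean) estimate
    v::Vector{Float64}   # second-moment (uncentred variance) estimate
    t::Int               # step counter for bias correction
end
AdamState(n::Integer) = AdamState(zeros(n), zeros(n), 0)

function adam_step!(θ, g, s::AdamState; η = 1e-3, β1 = 0.9, β2 = 0.999, ϵ = 1e-8)
    s.t += 1
    @. s.m = β1 * s.m + (1 - β1) * g
    @. s.v = β2 * s.v + (1 - β2) * g^2
    m̂ = s.m ./ (1 - β1^s.t)   # bias-corrected moments
    v̂ = s.v ./ (1 - β2^s.t)
    @. θ -= η * m̂ / (sqrt(v̂) + ϵ)
    return θ
end

# Usage: a quadratic whose curvature differs by a factor of 100 between coordinates
target_q = [3.0, -1.0]
scales_q = [1.0, 100.0]
grad_q(θ) = 2 .* scales_q .* (θ .- target_q)

θ_adam = [0.0, 0.0]
state = AdamState(2)
for _ in 1:5000
    adam_step!(θ_adam, grad_q(θ_adam), state; η = 0.01)
end
```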

Now we use the optimization result of the first run as the initial guess of the second optimization, and run it with L-BFGS (the limited-memory variant of BFGS). This looks like:

optprob2 = Optimization.OptimizationProblem(optf, res1.u)
res2 = Optimization.solve(optprob2, Optim.LBFGS(), callback = callback, maxiters = 1000)
println("Final training loss after $(length(losses)) iterations: $(losses[end])")

# Rename the best candidate
p_trained = res2.u
ComponentVector{Float64}(layer_1 = (weight = [0.3692923226063261 0.46758176933255335; 0.08488342700219989 -0.7183115347969398; … ; -0.274850027470782 -1.2991822837398055; -0.3732479855630566 -0.5538472089475422], bias = [-0.687539993470236; 0.2592453737351555; … ; -0.5002472890172242; 0.24586516891480176;;]), layer_2 = (weight = [0.4798151915708458 -1.728830363995384 … 0.30964517482075843 0.13396616682233944; 0.3443673110407487 0.0960306913364615 … -0.2528383847866073 0.5102826208646842; … ; -0.4213787386322402 -0.9939163076015388 … -0.7870490038814332 -0.23632027783306617; 1.7991930184986202 0.3294710733553182 … 0.5926579263196514 0.7922781002331928], bias = [-1.271008685870409; -0.2691788134017319; … ; -0.5131689300095028; 0.39332445660212245;;]), layer_3 = (weight = [0.02494014774081038 0.0330845052434183 … 0.15736305746365845 -0.8623370581000009; 0.3882707880844404 0.03811186072718824 … 0.3520540366206427 0.10655966894000743; … ; -1.8177761301637982 -1.2009778904316168 … -0.04942890154109963 1.2186749690036984; -2.2580327873132515 -1.2936969898380632 … -0.757224477992818 -0.4258623721281885], bias = [-0.09427931096593128; 0.4740090250252679; … ; -0.9519662995678734; -1.4561141222283094;;]), layer_4 = (weight = [1.4338823127061604 0.06916708997819655 … -3.5390471091977083 -2.6412291887787593; -0.46717494454828556 -1.1304789212553283 … 3.572008675570136 3.58932603595197], bias = [-2.798843388819688; 2.1528882382569736;;]))

and bingo, we have a trained UDE.

Visualizing the Trained UDE

How well did our neural network do? Let's take a look:

# Plot the losses
pl_losses = plot(1:5000, losses[1:5000], yaxis = :log10, xaxis = :log10,
                 xlabel = "Iterations", ylabel = "Loss", label = "ADAM", color = :blue)
plot!(5001:length(losses), losses[5001:end], yaxis = :log10, xaxis = :log10,
      xlabel = "Iterations", ylabel = "Loss", label = "BFGS", color = :red)

Next, we compare the original data to the output of the UDE predictor. Note that we can even create more samples from the underlying model by simply adjusting the time steps!

## Analysis of the trained network
# Plot the data and the approximation
ts = first(solution.t):(mean(diff(solution.t)) / 2):last(solution.t)
X̂ = predict(p_trained, Xₙ[:, 1], ts)
# Trained on noisy data vs real solution
pl_trajectory = plot(ts, transpose(X̂), xlabel = "t", ylabel = "x(t), y(t)", color = :red,
                     label = ["UDE Approximation" nothing])
scatter!(solution.t, transpose(Xₙ), color = :black, label = ["Measurements" nothing])

Let's see how well the unknown term has been approximated:

# Ideal unknown interactions of the predictor
Ȳ = [-p_[2] * (X̂[1, :] .* X̂[2, :])'; p_[3] * (X̂[1, :] .* X̂[2, :])']
# Neural network guess
Ŷ = U(X̂, p_trained, st)[1]

pl_reconstruction = plot(ts, transpose(Ŷ), xlabel = "t", ylabel = "U(x,y)", color = :red,
                         label = ["UDE Approximation" nothing])
plot!(ts, transpose(Ȳ), color = :black, label = ["True Interaction" nothing])

And have a nice look at all the information:

# Plot the error
pl_reconstruction_error = plot(ts, norm.(eachcol(Ȳ - Ŷ)), yaxis = :log, xlabel = "t",
                               ylabel = "L2-Error", label = nothing, color = :red)
pl_missing = plot(pl_reconstruction, pl_reconstruction_error, layout = (2, 1))

pl_overall = plot(pl_trajectory, pl_missing)

That looks pretty good. And if we are happy with deep learning, we can leave it at that: we have trained a neural network to capture our missing dynamics.

But...

Can we also make it print out the LaTeX for what the missing equations were? Find out more after the break!

Symbolic regression via sparse regression (SINDy based)

This part of the showcase is still a work in progress... shame on us. But be back in a jiffy and we'll have it done.

Okay, that was a quick break, and that's good because this next part is pretty cool. Let's use DataDrivenDiffEq.jl to transform our trained neural network from machine learning mumbo jumbo into predictions of missing mechanistic equations. To do this, we first generate a symbolic basis that represents the space of mechanistic functions we believe this neural network should map to. Let's choose a bunch of polynomial functions:

@variables u[1:2]
b = polynomial_basis(u, 4)
basis = Basis(b, u);

\[ \begin{align} \varphi_1 =& 1 \\ \varphi_2 =& u_1 \\ \varphi_3 =& u_1^{2} \\ \varphi_4 =& u_1^{3} \\ \varphi_5 =& u_1^{4} \\ \varphi_6 =& u_2 \\ \varphi_7 =& u_1 u_2 \\ \varphi_8 =& u_1^{2} u_2 \\ \varphi_9 =& u_1^{3} u_2 \\ \varphi_{10} =& u_2^{2} \\ \varphi_{11} =& u_2^{2} u_1 \\ \varphi_{12} =& u_2^{2} u_1^{2} \\ \varphi_{13} =& u_2^{3} \\ \varphi_{14} =& u_2^{3} u_1 \\ \varphi_{15} =& u_2^{4} \end{align} \]
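To see what this basis contains, the sketch below enumerates every monomial u₁ⁱ u₂ʲ with i + j ≤ 4 in plain Julia. The helpers `monomial_exponents` and `eval_basis` are hypothetical and only reproduce the listing above (15 terms, in the same order as φ₁ to φ₁₅); they say nothing about how `polynomial_basis` is implemented.

```julia
# Exponent pairs (i, j) for monomials u₁^i * u₂^j with total degree i + j ≤ degree
function monomial_exponents(degree::Integer)
    exps = Tuple{Int, Int}[]
    for j in 0:degree, i in 0:(degree - j)
        push!(exps, (i, j))
    end
    return exps
end

# Evaluate every monomial at a state u = (u₁, u₂)
eval_basis(u, exps) = [u[1]^i * u[2]^j for (i, j) in exps]

exps4 = monomial_exponents(4)
```

In general two variables up to total degree d give (d + 1)(d + 2)/2 monomials, here 15, so the candidate library grows quickly with the degree.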

Now let's define our DataDrivenProblems for the sparse regressions. To assess the capability of the sparse regression, we will look at 3 cases:

  • What if we trained no neural network and tried to automatically uncover the equations from the original noisy data? This is the approach known in the literature as sparse identification of nonlinear dynamics (SINDy). We will call this the full problem. It will tell us whether incorporating prior information was helpful.
  • What if we had the ideal missing right-hand-side terms, i.e. the true interaction terms? This is the value computed in the plots above as Ȳ. It will tell us whether the symbolic discovery could work in ideal situations.
  • Do the symbolic regression directly on the function y = NN(x), i.e. the trained neural network. This is what we really want, and it will tell us how to extend our known equations.

To define the full problem, we need a DataDrivenProblem containing the time series of the noisy data Xₙ, the time points t, and the derivative at each time point. Since we pass only Xₙ and t, the derivatives are not supplied explicitly and are instead estimated from an interpolation of the data:

full_problem = ContinuousDataDrivenProblem(Xₙ, t)
Continuous DataDrivenProblem{Float64} ##DDProblem#12681 in 2 dimensions and 21 samples
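For a feel of what estimating derivatives from samples involves, the sketch below uses plain central finite differences. This is a simpler, swapped-in technique, not the interpolation DataDrivenDiffEq.jl applies; `finite_difference_derivative` is a hypothetical helper, checked here on sin(t), whose derivative is known.

```julia
# Central differences in the interior, one-sided differences at the ends.
# X is a (states × samples) matrix and t the vector of sample times.
function finite_difference_derivative(X::AbstractMatrix, t::AbstractVector)
    n = size(X, 2)
    DX = similar(X, float(eltype(X)))
    DX[:, 1] = (X[:, 2] .- X[:, 1]) ./ (t[2] - t[1])
    DX[:, n] = (X[:, n] .- X[:, n - 1]) ./ (t[n] - t[n - 1])
    for k in 2:(n - 1)
        DX[:, k] = (X[:, k + 1] .- X[:, k - 1]) ./ (t[k + 1] - t[k - 1])
    end
    return DX
end

# Check on a signal with a known derivative: x(t) = sin(t), x'(t) = cos(t)
t_fd = collect(0.0:0.25:5.0)
X_fd = reshape(sin.(t_fd), 1, :)
DX_fd = finite_difference_derivative(X_fd, t_fd)
```

With a sampling interval of 0.25 the interior error here is about 1%. On noisy data, differencing divides the noise by the sampling interval, which is one reason derivative estimation from noisy measurements is delicate.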

Now for the other two symbolic regressions, we are regressing input/outputs of the missing terms, and thus we directly define the datasets as the input/output mappings like:

ideal_problem = DirectDataDrivenProblem(X̂, Ȳ)
nn_problem = DirectDataDrivenProblem(X̂, Ŷ)
Direct DataDrivenProblem{Float64} ##DDProblem#12683 in 2 dimensions and 41 samples

Let's solve the data-driven problems using sparse regression. We will use the ADMM method, which requires a set of shrinking cutoff values λ that we define like:

λ = exp10.(-3:0.01:3)
opt = ADMM(λ)
DataDrivenSparse.ADMM{Vector{Float64}, Float64}([0.001, 0.0010232929922807535, 0.0010471285480508996, 0.001071519305237606, 0.0010964781961431851, 0.001122018454301963, 0.0011481536214968829, 0.001174897554939529, 0.001202264434617413, 0.0012302687708123812  …  812.8305161640995, 831.7637711026708, 851.1380382023768, 870.9635899560806, 891.2509381337459, 912.0108393559096, 933.2543007969915, 954.992586021436, 977.2372209558112, 1000.0], 1.0)
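For intuition about what ADMM does for a single cutoff, the sketch below solves the standard lasso problem min ½‖Θξ − y‖² + λ‖ξ‖₁ with the textbook ADMM splitting ξ = z. It is not DataDrivenSparse's implementation, which works over the whole λ grid above and may differ in details. The library matrix `Θ` (columns 1, x, y, xy) and the target, which depends only on the xy column, are synthetic and made up for illustration.

```julia
using LinearAlgebra, Random

soft_threshold(x, κ) = sign(x) * max(abs(x) - κ, 0.0)

# ADMM for  min_ξ ½‖Θξ - y‖² + λ‖ξ‖₁  with the splitting ξ = z
function admm_lasso(Θ, y, λ; ρ = 50.0, iters = 20_000)
    n = size(Θ, 2)
    z = zeros(n)
    w = zeros(n)                              # scaled dual variable
    F = cholesky(Symmetric(Θ' * Θ + ρ * I))   # factor once, reuse every iteration
    Θty = Θ' * y
    for _ in 1:iters
        ξ = F \ (Θty .+ ρ .* (z .- w))        # least-squares step
        z = soft_threshold.(ξ .+ w, λ / ρ)    # sparsity (shrinkage) step
        w .+= ξ .- z                          # dual update
    end
    return z
end

# Synthetic test: the target depends only on the x*y column of the library
rng_admm = MersenneTwister(42)
xs = 0.5 .+ 2.5 .* rand(rng_admm, 100)
ys = 0.5 .+ 2.5 .* rand(rng_admm, 100)
Θ = hcat(ones(100), xs, ys, xs .* ys)         # columns: 1, x, y, x*y
y_target = -0.9 .* xs .* ys
ξ_hat = admm_lasso(Θ, y_target, 0.1)
```

The soft-thresholding step is what sets small coefficients exactly to zero; larger λ removes more terms, which is why sweeping a range of λ values yields candidate models of different sparsity.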

This is one of many methods for sparse regression; consult the DataDrivenDiffEq.jl documentation for more information on the algorithm choices. Taking this, let's solve each of the sparse regressions:

options = DataDrivenCommonOptions(maxiters = 10_000,
                                  normalize = DataNormalization(ZScoreTransform),
                                  selector = bic, digits = 1,
                                  data_processing = DataProcessing(split = 0.9,
                                                                   batchsize = 30,
                                                                   shuffle = true,
                                                                   rng = StableRNG(1111)))

full_res = solve(full_problem, basis, opt, options = options)
full_eqs = get_basis(full_res)
println(full_res)
┌ Warning: Number of observations less than batch-size, decreasing the batch-size to 19
└ @ MLUtils ~/.cache/julia-buildkite-plugin/depots/0183cc98-c3b4-4959-aaaa-6c0d5f351407/packages/MLUtils/n3C0h/src/batchview.jl:95
"DataDrivenSolution{Float64}" with 2 equations and 6 parameters.
Returncode: Success
Residual sum of squares: 197.89962654743113
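The `selector = bic` option uses the Bayesian information criterion to choose among the candidate sparse models, trading goodness of fit against the number of active terms. A common form, assuming Gaussian residuals, with m samples, k nonzero coefficients and residual sum of squares RSS, is shown below; DataDrivenDiffEq.jl's exact implementation may differ in constants or details.

\[ \mathrm{BIC} = m \ln\left(\frac{\mathrm{RSS}}{m}\right) + k \ln m \]

Lower is better, so an additional term is only worth keeping if it lowers m ln(RSS/m) by more than ln m.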
options = DataDrivenCommonOptions(maxiters = 10_000,
                                  normalize = DataNormalization(ZScoreTransform),
                                  selector = bic, digits = 1,
                                  data_processing = DataProcessing(split = 0.9,
                                                                   batchsize = 30,
                                                                   shuffle = true,
                                                                   rng = StableRNG(1111)))

ideal_res = solve(ideal_problem, basis, opt, options = options)
ideal_eqs = get_basis(ideal_res)
println(ideal_res)
"DataDrivenSolution{Float64}" with 2 equations and 2 parameters.
Returncode: Success
Residual sum of squares: 6.131259743239102
options = DataDrivenCommonOptions(maxiters = 10_000,
                                  normalize = DataNormalization(ZScoreTransform),
                                  selector = bic, digits = 1,
                                  data_processing = DataProcessing(split = 0.9,
                                                                   batchsize = 30,
                                                                   shuffle = true,
                                                                   rng = StableRNG(1111)))

nn_res = solve(nn_problem, basis, opt, options = options)
nn_eqs = get_basis(nn_res)
println(nn_res)
"DataDrivenSolution{Float64}" with 2 equations and 4 parameters.
Returncode: Success
Residual sum of squares: 152.5697196106988

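Note that each solve call receives an identical, freshly constructed options object, including a new StableRNG(1111). Because an RNG's state advances every time it is used, rebuilding the options resets the shuffling, so every call sees the same data split and batches.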
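The effect of reseeding can be checked with the standard library alone. In this sketch `MersenneTwister` and `randperm` stand in for the StableRNG-driven shuffling inside `DataProcessing` (an assumption for illustration, not its actual code): a shared RNG produces a different shuffle on each use, while a freshly seeded one reproduces the same shuffle every time. The 90/10 split uses rounding, which for 21 samples gives 19 training points.

```julia
using Random

# A shared RNG keeps advancing: two shuffles drawn from it differ
shared_rng = MersenneTwister(1111)
perm_a = randperm(shared_rng, 21)
perm_b = randperm(shared_rng, 21)

# A freshly seeded RNG per call reproduces the same shuffle every time
perm_c = randperm(MersenneTwister(1111), 21)
perm_d = randperm(MersenneTwister(1111), 21)

# 90/10 train/test split of the shuffled sample indices (cf. split = 0.9)
ntrain = round(Int, 0.9 * 21)
train_idx, test_idx = perm_c[1:ntrain], perm_c[(ntrain + 1):end]
```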

We already saw that the full problem has failed to identify the correct equations of motion. To have a closer look, we can inspect the corresponding equations:

for eqs in (full_eqs, ideal_eqs, nn_eqs)
    println(eqs)
    println(get_parameter_map(eqs))
    println()
end
Model ##Basis#12684 with 2 equations
States : u[1] u[2]
Parameters : 6
Independent variable: t
Equations
Differential(t)(u[1]) = p₁*(u[1]^2) + p₂*u[2] + p₄*(u[2]^4) + p₃*(u[2]^3)*u[1]
Differential(t)(u[2]) = p₅*(u[1]^2) + p₆*u[2]

Pair{SymbolicUtils.BasicSymbolic{Real}, Float64}[p₁ => 0.4, p₂ => -0.3, p₃ => 0.1, p₄ => -0.1, p₅ => 0.1, p₆ => -1.0]

Model ##Basis#12688 with 2 equations
States : u[1] u[2]
Parameters : p₁ p₂
Independent variable: t
Equations
φ₁ = p₁*u[1]*u[2]
φ₂ = p₂*u[1]*u[2]

Pair{SymbolicUtils.BasicSymbolic{Real}, Float64}[p₁ => -0.8, p₂ => 0.8]

Model ##Basis#12692 with 2 equations
States : u[1] u[2]
Parameters : p₁ p₂ p₃ p₄
Independent variable: t
Equations
φ₁ = p₁ + p₂*u[2] + p₃*u[1]*u[2]
φ₂ = p₄*u[1]*u[2]

Pair{SymbolicUtils.BasicSymbolic{Real}, Float64}[p₁ => -0.7, p₂ => -0.2, p₃ => -0.2, p₄ => 0.6]

Next, we want to predict with our model. To do so, we embed the basis into a function like before:

# Define the recovered, hybrid model
function recovered_dynamics!(du, u, p, t)
    û = nn_eqs(u, p) # Recovered equations
    du[1] = p_[1] * u[1] + û[1]
    du[2] = -p_[4] * u[2] + û[2]
end

estimation_prob = ODEProblem(recovered_dynamics!, u0, tspan, get_parameter_values(nn_eqs))
estimate = solve(estimation_prob, Tsit5(), saveat = solution.t)

# Plot
plot(solution)
plot!(estimate)

We are still a bit off, so we fine-tune the parameters by simply minimizing the residuals between the UDE predictor and our recovered parametrized equations:

function parameter_loss(p)
    Y = reduce(hcat, map(Base.Fix2(nn_eqs, p), eachcol(X̂)))
    sum(abs2, Ŷ .- Y)
end

optf = Optimization.OptimizationFunction((x, p) -> parameter_loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, get_parameter_values(nn_eqs))
parameter_res = Optimization.solve(optprob, Optim.LBFGS(), maxiters = 1000)
u: 4-element Vector{Float64}:
 -0.009122878092997957
  0.00047594146702109597
 -0.8974395933840341
  0.8004177239392938
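Since both recovered equations are linear in p₁ to p₄ (φ₁ = p₁ + p₂u₂ + p₃u₁u₂ and φ₂ = p₄u₁u₂), this fine-tuning is an ordinary linear least-squares problem, and because p₁ to p₃ appear only in φ₁ and p₄ only in φ₂, the joint sum of squares splits into two independent fits. The sketch below demonstrates this with Julia's `\` on synthetic data; the sample states, the "true" coefficients and the noise level are all made up, standing in for X̂ and Ŷ.

```julia
using LinearAlgebra, Random

rng_ls = MersenneTwister(7)
U_samples = 0.5 .+ 4.5 .* rand(rng_ls, 2, 41)   # synthetic stand-in for the states X̂
u1, u2 = U_samples[1, :], U_samples[2, :]

p_true = [-0.5, 0.2, -0.9, 0.8]                  # made-up coefficients
noise = 1e-3 .* randn(rng_ls, 2, 41)             # stand-in for network error
Y1 = p_true[1] .+ p_true[2] .* u2 .+ p_true[3] .* u1 .* u2 .+ noise[1, :]
Y2 = p_true[4] .* u1 .* u2 .+ noise[2, :]

# φ₁ depends on (p₁, p₂, p₃) and φ₂ only on p₄, so each output is fitted separately
A1 = hcat(ones(41), u2, u1 .* u2)
A2 = reshape(u1 .* u2, :, 1)
p_ls = vcat(A1 \ Y1, A2 \ Y2)                    # QR-based least-squares solves
```

The iterative L-BFGS route used above stays the more general option, since it also works when the recovered equations are nonlinear in their parameters.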

Simulation

# Look at long term prediction
t_long = (0.0, 50.0)
estimation_prob = ODEProblem(recovered_dynamics!, u0, t_long, parameter_res.u)
estimate_long = solve(estimation_prob, Tsit5(), saveat = 0.1) # Using higher tolerances here results in exit of julia
plot(estimate_long)
true_prob = ODEProblem(lotka!, u0, t_long, p_)
true_solution_long = solve(true_prob, Tsit5(), saveat = estimate_long.t)
plot!(true_solution_long)

Post Processing and Plots

c1 = 3       # third color of the default palette
c2 = :orange
c3 = :blue
c4 = :purple

p1 = plot(t, abs.(Array(solution) .- Array(estimate))' .+ eps(Float32),
          lw = 3, yaxis = :log, title = "Timeseries of UODE Error",
          color = [c1 c2], xlabel = "t",
          label = ["x(t)" "y(t)"],
          titlefont = "Helvetica", legendfont = "Helvetica",
          legend = :topright)

# Plot the neural network fit of the second missing term, U₂
p2 = plot3d(X̂[1, :], X̂[2, :], Ŷ[2, :], lw = 3,
            title = "Neural Network Fit of U2(t)", color = c1,
            label = "Neural Network", xaxis = "x", yaxis = "y",
            titlefont = "Helvetica", legendfont = "Helvetica",
            legend = :bottomright)
plot!(X̂[1, :], X̂[2, :], Ȳ[2, :], lw = 3, label = "True Missing Term", color = c2)

p3 = scatter(solution, color = [c1 c2], label = ["x data" "y data"],
             title = "Extrapolated Fit From Short Training Data",
             titlefont = "Helvetica", legendfont = "Helvetica",
             markersize = 5)

plot!(p3, true_solution_long, color = [c1 c2], linestyle = :dot, lw = 5,
      label = ["True x(t)" "True y(t)"])
plot!(p3, estimate_long, color = [c3 c4], lw = 1,
      label = ["Estimated x(t)" "Estimated y(t)"])
# Mark the end of the training interval t ∈ [0, 5]
plot!(p3, [4.99, 5.01], [0.0, 10.0], lw = 1, color = :black, label = nothing)
annotate!([(2.5, 13, text("Training \nData", 10, :center, :top, :black, "Helvetica"))])
l = @layout [grid(1, 2)
             grid(1, 1)]
plot(p1, p2, p3, layout = l)
[Figure: three panels: "Timeseries of UODE Error" (x(t), y(t)); "Neural Network Fit of U2(t)" (Neural Network vs. True Missing Term); "Extrapolated Fit From Short Training Data" (x data, y data, True x(t), True y(t), Estimated x(t), Estimated y(t), with a "Training Data" annotation)]