Automatically Discover Missing Physics by Embedding Machine Learning into Differential Equations

One of the most time-consuming parts of modeling is building the model. How do you know when your model is correct? When you solve an inverse problem to calibrate your model to data, who you gonna call if there are no parameters that make the model fit the data? This is the problem that the Universal Differential Equation (UDE) approach solves: the ability to start from the model you have, and suggest minimal mechanistic extensions that would allow the model to fit the data. In this showcase, we will show how to take a partially correct model and auto-complete it to find the missing physics.

Note

For a scientific background on the universal differential equation approach, check out Universal Differential Equations for Scientific Machine Learning.

Starting Point: The Packages To Use

There are many packages which are used as part of this showcase. Let's detail what they are and how they are used. For the neural network training:

| Module | Description |
|:---|:---|
| OrdinaryDiffEq.jl (DifferentialEquations.jl) | The numerical differential equation solvers |
| SciMLSensitivity.jl | The adjoint methods, defines gradients of ODE solvers |
| Optimization.jl | The optimization library |
| OptimizationOptimisers.jl | The optimization solver package with Adam |
| OptimizationOptimJL.jl | The optimization solver package with BFGS |

For the symbolic model discovery:

| Module | Description |
|:---|:---|
| ModelingToolkit.jl | The symbolic modeling environment |
| DataDrivenDiffEq.jl | The symbolic regression interface |
| DataDrivenSparse.jl | The sparse regression symbolic regression solvers |
| Zygote.jl | The automatic differentiation library for fast gradients |

Julia standard libraries:

| Module | Description |
|:---|:---|
| LinearAlgebra | Required for the norm function |
| Statistics | Required for the mean function |

And external libraries:

| Module | Description |
|:---|:---|
| Lux.jl | The deep learning (neural network) framework |
| ComponentArrays.jl | For the ComponentArray type to match Lux to SciML |
| Plots.jl | The plotting and visualization library |
| StableRNGs.jl | Stable random seeding |

Note

The deep learning framework Flux.jl could be used in place of Lux, though most tutorials in SciML generally prefer Lux.jl due to its explicit parameter interface, leading to nicer code. Both share the same internal implementations of core kernels, and thus have very similar feature support and performance.

# SciML Tools
using OrdinaryDiffEq, ModelingToolkit, DataDrivenDiffEq, SciMLSensitivity, DataDrivenSparse
using Optimization, OptimizationOptimisers, OptimizationOptimJL

# Standard Libraries
using LinearAlgebra, Statistics

# External Libraries
using ComponentArrays, Lux, Zygote, Plots, StableRNGs
gr()

# Set a random seed for reproducible behaviour
rng = StableRNG(1111)
StableRNGs.LehmerRNG(state=0x000000000000000000000000000008af)

Problem Setup

In order to know that we have automatically discovered the correct model, we will use generated data from a known model. This model will be the Lotka-Volterra equations. These equations are given by:

\[\begin{aligned} \frac{dx}{dt} &= \alpha x - \beta x y \\ \frac{dy}{dt} &= -\delta y + \gamma x y \\ \end{aligned}\]

This is a model of rabbits and wolves. $\alpha x$ is the exponential growth of rabbits in isolation, $-\beta x y$ and $\gamma x y$ are the interaction effects of wolves eating rabbits, and $-\delta y$ is the term for how wolves die hungry in isolation.

Now assume that we have never seen rabbits and wolves in the same room. We only know the two effects $\alpha x$ and $-\delta y$. Can we use Scientific Machine Learning to automatically discover an extension to what we already know? That is what we will solve with the universal differential equation.
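
Written out, keeping only the two known terms and collecting everything we have not observed into unknown functions gives the structure below. The symbols $U_1$ and $U_2$ are placeholder names introduced here for the missing terms; they are not notation from the model itself.

```latex
\[\begin{aligned}
\frac{dx}{dt} &= \alpha x + U_1(x, y) \\
\frac{dy}{dt} &= -\delta y + U_2(x, y)
\end{aligned}\]
```

Comparing term by term with the full Lotka-Volterra equations, a perfect recovery would give $U_1(x, y) = -\beta x y$ and $U_2(x, y) = \gamma x y$.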

Generating the Training Data

First, let's generate training data from the Lotka-Volterra equations. This is straightforward and standard DifferentialEquations.jl usage. Our sample data is thus generated as follows:

function lotka!(du, u, p, t)
    α, β, γ, δ = p
    du[1] = α * u[1] - β * u[2] * u[1]
    du[2] = γ * u[1] * u[2] - δ * u[2]
end

# Define the experimental parameter
tspan = (0.0, 5.0)
u0 = 5.0f0 * rand(rng, 2)
p_ = [1.3, 0.9, 0.8, 1.8]
prob = ODEProblem(lotka!, u0, tspan, p_)
solution = solve(prob, Vern7(), abstol = 1e-12, reltol = 1e-12, saveat = 0.25)

# Add noise in terms of the mean
X = Array(solution)
t = solution.t

x̄ = mean(X, dims = 2)
noise_magnitude = 5e-3
Xₙ = X .+ (noise_magnitude * x̄) .* randn(rng, eltype(X), size(X))

plot(solution, alpha = 0.75, color = :black, label = ["True Data" nothing])
scatter!(t, transpose(Xₙ), color = :red, label = ["Noisy Data" nothing])
[Figure: true Lotka-Volterra solution (black lines) with the noisy measurements (red markers)]
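
The noise model used above scales Gaussian noise by the mean of each state, so every state gets its own noise level. The following is a minimal stdlib-only sketch of that broadcasting pattern; the matrix `X_demo`, the `_demo` names, and the use of `MersenneTwister` in place of `StableRNG` are assumptions made purely for illustration.

```julia
using Statistics, Random

# Hypothetical stand-in data: 2 states (rows) × 4 time points (columns)
X_demo = [1.0 2.0 3.0 4.0;
          10.0 20.0 30.0 40.0]

x̄_demo = mean(X_demo, dims = 2)           # 2×1 matrix holding the mean of each state
noise_magnitude_demo = 5e-3
σ_demo = noise_magnitude_demo .* x̄_demo   # per-state noise standard deviation

# Any RNG works here; the showcase itself uses a StableRNG
rng_demo = MersenneTwister(1234)

# The 2×1 column of standard deviations broadcasts across all time points
X_noisy_demo = X_demo .+ σ_demo .* randn(rng_demo, eltype(X_demo), size(X_demo))
```

Because `mean(X, dims = 2)` returns a column, the second state here receives noise that is ten times larger in absolute terms than the first, while the relative noise level is the same for both.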

Definition of the Universal Differential Equation

Now let's define our UDE. We will use Lux.jl to define the neural network as follows:

rbf(x) = exp.(-(x .^ 2))

# Multilayer FeedForward
const U = Lux.Chain(Lux.Dense(2, 5, rbf), Lux.Dense(5, 5, rbf), Lux.Dense(5, 5, rbf),
              Lux.Dense(5, 2))
# Get the initial parameters and state variables of the model
p, st = Lux.setup(rng, U)
const _st = st
(layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple())

We then define the UDE as a dynamical system that is u' = known(u) + NN(u) like:

# Define the hybrid model
function ude_dynamics!(du, u, p, t, p_true)
    û = U(u, p, _st)[1] # Network prediction
    du[1] = p_true[1] * u[1] + û[1]
    du[2] = -p_true[4] * u[2] + û[2]
end

# Closure with the known parameter
nn_dynamics!(du, u, p, t) = ude_dynamics!(du, u, p, t, p_)
# Define the problem
prob_nn = ODEProblem(nn_dynamics!, Xₙ[:, 1], tspan, p)
ODEProblem with uType Vector{Float64} and tType Float64. In-place: true
timespan: (0.0, 5.0)
u0: 2-element Vector{Float64}:
 3.1463924566781167
 1.5423300037202512

Notice that the most important part of this is that the neural network does not have hard-coded weights. The weights of the neural network are the parameters of the ODE system. This means that if we change the parameters of the ODE system, then we will have updated the internal neural networks to new weights. Keep that in mind for the next part.

... and tada: now we have a neural network integrated into our dynamical system!
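
To make the "weights are the ODE parameters" point concrete, here is a small sketch that needs no packages. It swaps the Lux network for a hand-written linear map (`toy_nn`, a hypothetical stand-in, as are all the `toy` names) so that the effect of changing `p` on the right-hand side can be checked by hand.

```julia
# Hypothetical stand-in for the neural network: a linear map whose weights live in p
toy_nn(u, p) = p.W * u .+ p.b

# Same structure as the hybrid model: known physics plus a learned correction
function toy_ude!(du, u, p, t, α, δ)
    û = toy_nn(u, p)
    du[1] = α * u[1] + û[1]
    du[2] = -δ * u[2] + û[2]
    return nothing
end

u_toy = [2.0, 3.0]

# With all weights at zero, only the known terms α*x and -δ*y remain
p_zero = (W = zeros(2, 2), b = zeros(2))
du_zero = zeros(2)
toy_ude!(du_zero, u_toy, p_zero, 0.0, 1.3, 1.8)

# Changing the parameters changes the dynamics without touching the function
p_new = (W = [0.0 1.0; 0.0 0.0], b = zeros(2))
du_new = zeros(2)
toy_ude!(du_new, u_toy, p_new, 0.0, 1.3, 1.8)
```

With zero weights the result is `[1.3 * 2.0, -1.8 * 3.0]`; the second parameter set adds `u[2] = 3.0` to the first component only.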

Note

Even if the known physics is only approximately correct, it can be helpful to improve the fitting process! Check out this JuliaCon talk which dives into this issue.

Setting Up the Training Loop

Now let's build a training loop around our UDE. First, let's make a function predict which runs our simulation at new neural network weights. Recall that the weights of the neural network are the parameters of the ODE, so what we need to do in predict is update our ODE to our new parameters and then run it.

For this update step, we will use the remake function from the SciMLProblem interface. remake works by specifying key = value pairs to update in the problem fields. Thus to update u0, we would add a keyword argument u0 = .... To update the parameters, we'd do p = .... The field names can be acquired from the problem documentation (or the docstrings!).

Knowing this, our predict function looks like:

function predict(θ, X = Xₙ[:, 1], T = t)
    _prob = remake(prob_nn, u0 = X, tspan = (T[1], T[end]), p = θ)
    Array(solve(_prob, Vern7(), saveat = T,
                abstol = 1e-6, reltol = 1e-6,
                sensealg=QuadratureAdjoint(autojacvec=ReverseDiffVJP(true))))
end
predict (generic function with 3 methods)

There are many choices for the combination of sensitivity algorithm and automatic differentiation library (see Choosing a Sensitivity Algorithm). For example, you could have used sensealg=ForwardDiffSensitivity().

Now, for our loss function, we solve the ODE at our new parameters and check its L2 loss against the dataset. Using our predict function, this looks like:

function loss(θ)
    X̂ = predict(θ)
    mean(abs2, Xₙ .- X̂)
end
loss (generic function with 1 method)
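
The expression `mean(abs2, A .- B)` is the mean of the squared elementwise differences over all states and time points. A tiny sketch with made-up matrices (`A_demo` and `B_demo` are illustrative names only) shows what is being computed:

```julia
using Statistics

A_demo = [1.0 2.0; 3.0 4.0]
B_demo = [1.0 1.0; 1.0 1.0]

# Elementwise differences are 0, 1, 2, 3, so the squares are 0, 1, 4, 9
mse_demo = mean(abs2, A_demo .- B_demo)

# The same quantity written out explicitly
mse_explicit = sum((A_demo .- B_demo) .^ 2) / length(A_demo)
```

Both expressions evaluate to `(0 + 1 + 4 + 9) / 4 = 3.5`.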

Lastly, to track our optimization we need to define a callback, as specified by the OptimizationProblem's solve interface. Because our loss function returns only one value, the loss l, the callback will be a function of the current parameters θ and l. Let's set up a callback that prints the current loss every 50 steps:

losses = Float64[]

callback = function (p, l)
    push!(losses, l)
    if length(losses) % 50 == 0
        println("Current loss after $(length(losses)) iterations: $(losses[end])")
    end
    return false
end
#1 (generic function with 1 method)
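
The return value of the callback matters: returning `false` lets the optimization continue, while returning `true` asks the solver to halt. The sketch below illustrates that contract with a hand-rolled gradient descent loop on a one-dimensional quadratic. The loop `toy_descent` is a hypothetical stand-in written for this example and is not how Optimization.jl is implemented internally.

```julia
toy_losses = Float64[]

# Record every loss and request a stop once the loss is small enough
toy_callback = function (θ, l)
    push!(toy_losses, l)
    return l < 1e-6   # true means "halt"
end

# Stand-in optimizer: plain gradient descent on (θ - 3)^2
function toy_descent(θ0, callback; η = 0.1, maxiters = 1000)
    θ = θ0
    for _ in 1:maxiters
        l = (θ - 3.0)^2
        callback(θ, l) && break       # honor the halt request
        θ -= η * 2 * (θ - 3.0)        # gradient step
    end
    return θ
end

θ_final = toy_descent(0.0, toy_callback)
```

The callback in the showcase always returns `false`, so training runs until `maxiters` is reached; an early-stopping criterion could be added in the same way as in this toy version.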

Training

Now we're ready to train! To run the training process, we will need to build an OptimizationProblem. Because we have a lot of parameters, we will use Zygote.jl for reverse-mode automatic differentiation. Optimization.jl makes the choice of automatic differentiation easy: we just specify an adtype when constructing the OptimizationFunction.

Knowing this, we can build our OptimizationProblem as follows:

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ComponentVector{Float64}(p))
OptimizationProblem. In-place: true
u0: ComponentVector{Float64}(layer_1 = (weight = [0.49426865577697754 0.5692564249038696; 0.40171918272972107 -0.8665286302566528; … ; 0.47097498178482056 -0.7521204352378845; -0.20216092467308044 -0.3197280168533325], bias = [0.0; 0.0; … ; 0.0; 0.0;;]), layer_2 = (weight = [0.6822634935379028 -0.6952740550041199 … 0.5011160969734192 0.24313241243362427; -0.39863723516464233 -0.17176461219787598 … -0.6159946322441101 0.18968746066093445; … ; -0.7200708389282227 -0.6787673234939575 … -0.5633968114852905 0.1658746749162674; 0.0014851824380457401 -0.10373303294181824 … 0.09008530527353287 -0.043933477252721786], bias = [0.0; 0.0; … ; 0.0; 0.0;;]), layer_3 = (weight = [0.0011237671133130789 0.006483868230134249 … 0.2754976451396942 -0.2874394953250885; 0.04383227229118347 -0.32253962755203247 … 0.09472294896841049 -0.4210013747215271; … ; -0.5179172158241272 -0.6043259501457214 … -0.18625909090042114 0.06577149033546448; -0.2150842249393463 0.2565661072731018 … 0.5849692821502686 0.2193499207496643], bias = [0.0; 0.0; … ; 0.0; 0.0;;]), layer_4 = (weight = [0.7144760489463806 0.4398183226585388 … -0.8286120891571045 0.04256918653845787; -0.5670844316482544 -0.3962741494178772 … 0.1667986661195755 0.8446723818778992], bias = [0.0; 0.0;;]))

Now... we optimize it. We will use a mixed strategy. First, let's do some iterations of ADAM because it's better at finding a good general area of parameter space, but then we will move to BFGS which will quickly hone in on a local minimum. Note that if we only use ADAM it will take a ton of iterations, and if we only use BFGS we normally end up in a bad local minimum, so this combination tends to be a good one for UDEs.

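Thus we first solve the optimization problem with ADAM. The learning rate is best tuned to be as high as possible without making the loss shoot up; a value of 0.1 was found in this way. Note that ADAM() as called below takes no argument and therefore runs with the optimizer's default learning rate; pass the value explicitly, e.g. ADAM(0.1), to change it. Running the optimization, we see: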

res1 = Optimization.solve(optprob, ADAM(), callback = callback, maxiters = 5000)
println("Training loss after $(length(losses)) iterations: $(losses[end])")
Current loss after 50 iterations: 240907.40812047705
Current loss after 100 iterations: 171079.09448674775
Current loss after 150 iterations: 130867.59181984967
Current loss after 200 iterations: 108175.819196963
Current loss after 250 iterations: 92683.05198785299
Current loss after 300 iterations: 80367.48811919929
Current loss after 350 iterations: 70059.55070195843
Current loss after 400 iterations: 61257.68558271598
Current loss after 450 iterations: 53651.08580882383
Current loss after 500 iterations: 47008.06470439953
Current loss after 550 iterations: 41158.83468130469
Current loss after 600 iterations: 35981.34739953454
Current loss after 650 iterations: 31385.76155374214
Current loss after 700 iterations: 27302.94532524118
Current loss after 750 iterations: 23677.218526828085
Current loss after 800 iterations: 20461.937285293716
Current loss after 850 iterations: 17616.78554340633
Current loss after 900 iterations: 15106.07897130775
Current loss after 950 iterations: 12897.679899547724
Current loss after 1000 iterations: 10962.290260703332
Current loss after 1050 iterations: 9272.98412809502
Current loss after 1100 iterations: 7804.895149423114
Current loss after 1150 iterations: 6535.005113715568
Current loss after 1200 iterations: 5441.998713371786
Current loss after 1250 iterations: 4506.161272878186
Current loss after 1300 iterations: 3709.3032558332475
Current loss after 1350 iterations: 3034.7004579722848
Current loss after 1400 iterations: 2467.0421485271563
Current loss after 1450 iterations: 1992.3816194763508
Current loss after 1500 iterations: 1598.0855615495034
Current loss after 1550 iterations: 1272.779724729409
Current loss after 1600 iterations: 1006.2891966016309
Current loss after 1650 iterations: 789.5728737985675
Current loss after 1700 iterations: 614.6544329407017
Current loss after 1750 iterations: 474.55685746572834
Current loss after 1800 iterations: 363.24176465010794
Current loss after 1850 iterations: 275.53720126745895
Current loss after 1900 iterations: 207.04805971367972
Current loss after 1950 iterations: 154.0655400416435
Current loss after 2000 iterations: 113.48416980378501
Current loss after 2050 iterations: 82.72556007730252
Current loss after 2100 iterations: 59.66823649918496
Current loss after 2150 iterations: 42.58350060323642
Current loss after 2200 iterations: 30.077308343095584
Current loss after 2250 iterations: 21.038052240355587
Current loss after 2300 iterations: 14.590059660943801
Current loss after 2350 iterations: 10.052529706815138
Current loss after 2400 iterations: 6.903577440884438
Current loss after 2450 iterations: 4.749089044279981
Current loss after 2500 iterations: 3.2961024407445234
Current loss after 2550 iterations: 2.3304838042605582
Current loss after 2600 iterations: 1.6984615277477826
Current loss after 2650 iterations: 1.2912517692815566
Current loss after 2700 iterations: 1.0329379753413872
Current loss after 2750 iterations: 0.8714661234018667
Current loss after 2800 iterations: 0.7719399734563088
Current loss after 2850 iterations: 0.711448897799641
Current loss after 2900 iterations: 0.6752009717255759
Current loss after 2950 iterations: 0.6537894969037689
Current loss after 3000 iterations: 0.6413217952872436
Current loss after 3050 iterations: 0.6341626182863023
Current loss after 3100 iterations: 0.6301042695676299
Current loss after 3150 iterations: 0.6278274469818137
Current loss after 3200 iterations: 0.6265567041917336
Current loss after 3250 iterations: 0.6258439624362605
Current loss after 3300 iterations: 0.6254348108885727
Current loss after 3350 iterations: 0.6251873850310572
Current loss after 3400 iterations: 0.6250240332491129
Current loss after 3450 iterations: 0.6249030526946996
Current loss after 3500 iterations: 0.624802476533366
Current loss after 3550 iterations: 0.6247109569213214
Current loss after 3600 iterations: 0.6246227420503063
Current loss after 3650 iterations: 0.6245349665457077
Current loss after 3700 iterations: 0.6244462211962022
Current loss after 3750 iterations: 0.6243558146749706
Current loss after 3800 iterations: 0.6242634011207508
Current loss after 3850 iterations: 0.6241687966665521
Current loss after 3900 iterations: 0.624071891224429
Current loss after 3950 iterations: 0.6239726071163955
Current loss after 4000 iterations: 0.6238708801638934
Current loss after 4050 iterations: 0.6237666512655885
Current loss after 4100 iterations: 0.6236598627430592
Current loss after 4150 iterations: 0.6235504567946982
Current loss after 4200 iterations: 0.6234383748556598
Current loss after 4250 iterations: 0.6233235573360533
Current loss after 4300 iterations: 0.6232059435124488
Current loss after 4350 iterations: 0.6230854714797823
Current loss after 4400 iterations: 0.6229620781264884
Current loss after 4450 iterations: 0.6228356991184738
Current loss after 4500 iterations: 0.6227062688865883
Current loss after 4550 iterations: 0.6225737206156806
Current loss after 4600 iterations: 0.6224379862346241
Current loss after 4650 iterations: 0.622298996407113
Current loss after 4700 iterations: 0.6221566805232558
Current loss after 4750 iterations: 0.6220109666919613
Current loss after 4800 iterations: 0.6218617817342119
Current loss after 4850 iterations: 0.6217090511772592
Current loss after 4900 iterations: 0.6215526992498388
Current loss after 4950 iterations: 0.6213926488784575
Current loss after 5000 iterations: 0.6212288216848414
Training loss after 5001 iterations: 0.6212288216848414

Now we use the optimization result of the first run as the initial condition of the second optimization, and run it with L-BFGS, the limited-memory variant of BFGS. This looks like:

optprob2 = Optimization.OptimizationProblem(optf, res1.u)
res2 = Optimization.solve(optprob2, Optim.LBFGS(), callback = callback, maxiters = 1000)
println("Final training loss after $(length(losses)) iterations: $(losses[end])")

# Rename the best candidate
p_trained = res2.u
ComponentVector{Float64}(layer_1 = (weight = [0.1744960490842674 0.8512214833141024; 0.5727050228616578 -0.8557480953169814; … ; -0.3432193514714526 -1.2250882710295288; -0.40821670641714963 -0.5488044697452], bias = [-0.4022096466141816; -0.0020656548176875475; … ; -0.4044705507940684; -0.01669230214715999;;]), layer_2 = (weight = [0.8386650170733406 -1.0643415745166165 … 0.4126666603674607 0.04009036686626749; 0.14366827185847705 0.24027801000546492 … -0.15675553083465196 0.32434225575986186; … ; -0.1498593499841619 -0.7437779064260305 … -0.7235213408950851 0.08511926522436007; 1.257960896007518 0.0709747923637148 … 0.3934099035291271 0.2123872824502117], bias = [-0.6832104472616063; 0.3401376390509885; … ; -0.0809759989924102; 0.23589570026407014;;]), layer_3 = (weight = [-0.17091545156920163 -0.27964630526716 … 0.008663458328353841 -0.6760580529771697; 0.41657754791170587 0.05147028669114708 … 0.39746079202090323 0.02976456244672703; … ; -1.2945919799835717 -0.6153073011029941 … 0.1264245558690705 0.42059006257110326; -0.8556618667630123 -0.34419423729446874 … -0.05236646022020204 0.08333252972046083], bias = [-0.2906090961177651; 0.4318126972913906; … ; -0.37897589399451187; -0.3996705607289775;;]), layer_4 = (weight = [0.5569707998158673 0.1459364320376618 … -2.6460339561356094 -1.8467408220024184; -0.3794368023577222 -0.20755932313113748 … 0.8439231965882079 1.5777074084310143], bias = [-1.5495352724041693; 0.6273089145021222;;]))

and bingo, we have a trained UDE.

Visualizing the Trained UDE

How well did our neural network do? Let's take a look:

# Plot the losses
pl_losses = plot(1:5000, losses[1:5000], yaxis = :log10, xaxis = :log10,
                 xlabel = "Iterations", ylabel = "Loss", label = "ADAM", color = :blue)
plot!(5001:length(losses), losses[5001:end], yaxis = :log10, xaxis = :log10,
      xlabel = "Iterations", ylabel = "Loss", label = "BFGS", color = :red)
[Figure: training loss versus iteration on log-log axes, ADAM in blue and BFGS in red]

Next, we compare the original data to the output of the UDE predictor. Note that we can even create more samples from the underlying model by simply adjusting the time steps!

## Analysis of the trained network
# Plot the data and the approximation
ts = first(solution.t):(mean(diff(solution.t)) / 2):last(solution.t)
X̂ = predict(p_trained, Xₙ[:, 1], ts)
# Trained on noisy data vs real solution
pl_trajectory = plot(ts, transpose(X̂), xlabel = "t", ylabel = "x(t), y(t)", color = :red,
                     label = ["UDE Approximation" nothing])
scatter!(solution.t, transpose(Xₙ), color = :black, label = ["Measurements" nothing])
[Figure: UDE approximation (red lines) against the measurements (black markers)]
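
The finer grid `ts` used above is built by halving the mean spacing of the saved time points. A small stdlib-only sketch with a stand-in time vector (`t_demo` is a made-up name that mirrors `saveat = 0.25` on the span `(0.0, 5.0)`) shows the resulting sizes:

```julia
using Statistics

# Stand-in for solution.t: 21 points from 0 to 5 in steps of 0.25
t_demo = collect(0.0:0.25:5.0)

# Halve the mean spacing to get a grid that is twice as fine
ts_demo = first(t_demo):(mean(diff(t_demo)) / 2):last(t_demo)

n_coarse = length(t_demo)
n_fine = length(ts_demo)
```

This gives 21 coarse and 41 fine time points, which matches the sample counts reported by the data-driven problems later in the showcase.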

Let's see how well the unknown term has been approximated:

# Ideal unknown interactions of the predictor
Ȳ = [-p_[2] * (X̂[1, :] .* X̂[2, :])'; p_[3] * (X̂[1, :] .* X̂[2, :])']
# Neural network guess
Ŷ = U(X̂, p_trained, st)[1]

pl_reconstruction = plot(ts, transpose(Ŷ), xlabel = "t", ylabel = "U(x,y)", color = :red,
                         label = ["UDE Approximation" nothing])
plot!(ts, transpose(Ȳ), color = :black, label = ["True Interaction" nothing])
[Figure: learned term U(x,y) (red) against the true interaction (black)]

And have a nice look at all the information:

# Plot the error
pl_reconstruction_error = plot(ts, norm.(eachcol(Ȳ - Ŷ)), yaxis = :log, xlabel = "t",
                               ylabel = "L2-Error", label = nothing, color = :red)
pl_missing = plot(pl_reconstruction, pl_reconstruction_error, layout = (2, 1))

pl_overall = plot(pl_trajectory, pl_missing)
[Figure: trajectory fit, reconstructed missing term, and its L2 error over time]
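
The error curve above is computed with `norm.(eachcol(Ȳ - Ŷ))`, which takes the Euclidean norm of the two-component error at each time point. A short sketch with a made-up error matrix (`E_demo` is an illustrative name) shows the pattern:

```julia
using LinearAlgebra

# Hypothetical error matrix: rows are the two missing terms, columns are time points
E_demo = [3.0 0.0;
          4.0 1.0]

# One L2 norm per column, i.e. per time point
err_demo = norm.(eachcol(E_demo))
```

Here the first column has norm `sqrt(3^2 + 4^2) = 5` and the second has norm `1`.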

That looks pretty good. And if we are happy with deep learning, we can leave it at that: we have trained a neural network to capture our missing dynamics.

But...

Can we also make it print out the LaTeX for what the missing equations were? Find out more after the break!

Symbolic regression via sparse regression (SINDy based)

This part of the showcase is still a work in progress... shame on us. But be back in a jiffy and we'll have it done.

Okay, that was a quick break, and that's good because this next part is pretty cool. Let's use DataDrivenDiffEq.jl to transform our trained neural network from machine learning mumbo jumbo into predictions of missing mechanistic equations. To do this, we first generate a symbolic basis that represents the space of mechanistic functions we believe this neural network should map to. Let's choose a bunch of polynomial functions:

@variables u[1:2]
b = polynomial_basis(u, 4)
basis = Basis(b, u);

\[ \begin{aligned}
\varphi_1 &= 1 \\
\varphi_2 &= u_1 \\
\varphi_3 &= u_1^{2} \\
\varphi_4 &= u_1^{3} \\
\varphi_5 &= u_1^{4} \\
\varphi_6 &= u_2 \\
\varphi_7 &= u_1 u_2 \\
\varphi_8 &= u_1^{2} u_2 \\
\varphi_9 &= u_1^{3} u_2 \\
\varphi_{10} &= u_2^{2} \\
\varphi_{11} &= u_2^{2} u_1 \\
\varphi_{12} &= u_2^{2} u_1^{2} \\
\varphi_{13} &= u_2^{3} \\
\varphi_{14} &= u_2^{3} u_1 \\
\varphi_{15} &= u_2^{4}
\end{aligned} \]
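
The basis contains every monomial in $u_1$ and $u_2$ of total degree at most 4, which is why there are exactly 15 functions. The following sketch reproduces that list by enumerating exponent pairs in plain Julia; `exponents` and `eval_monomials` are names made up for this illustration and are not part of DataDrivenDiffEq.jl.

```julia
max_degree = 4

# (i, j) means u1^i * u2^j; the power of u2 varies slowest, as in the list above
exponents = [(i, j) for j in 0:max_degree for i in 0:(max_degree - j)]

# Evaluate all monomials at a given point u = [u1, u2]
eval_monomials(u) = [u[1]^i * u[2]^j for (i, j) in exponents]

n_basis = length(exponents)
φ_demo = eval_monomials([2.0, 3.0])
```

The count agrees with `binomial(max_degree + 2, 2) = 15`, and the seventh entry is the interaction term $u_1 u_2$ that the true missing physics is built from.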

Now let's define our DataDrivenProblems for the sparse regressions. To assess the capability of the sparse regression, we will look at 3 cases:

  • What if we trained no neural network and tried to automatically uncover the equations from the original noisy data? This is the approach known in the literature as sparse identification of nonlinear dynamics (SINDy). We will call this the full problem. This will assess whether incorporating prior information was helpful.
  • What if we ran the sparse regression on the ideal right-hand side missing derivative functions? This is the value computed for the plots above as Ȳ. This will tell us whether the symbolic discovery could work in ideal situations.
  • Do the symbolic regression directly on the function y = NN(x), i.e. the trained learned neural network. This is what we really want, and will tell us how to extend our known equations.

To define the full problem, we need a DataDrivenProblem that has the time series of the solution, the time points of the solution t, and the derivative at each time point. Since only the noisy measurements are available here, we pass Xₙ and t and let the derivative be estimated from an interpolation of the data:

full_problem = ContinuousDataDrivenProblem(Xₙ, t)
Continuous DataDrivenProblem{Float64} ##DDProblem#190789 in 2 dimensions and 21 samples

Now for the other two symbolic regressions, we are regressing input/outputs of the missing terms, and thus we directly define the datasets as the input/output mappings like:

ideal_problem = DirectDataDrivenProblem(X̂, Ȳ)
nn_problem = DirectDataDrivenProblem(X̂, Ŷ)
Direct DataDrivenProblem{Float64} ##DDProblem#190791 in 2 dimensions and 41 samples

Let's solve the data-driven problems using sparse regression. We will use the ADMM method, which requires we define a set of shrinking cutoff values λ, and we do this like:

λ = exp10.(-3:0.01:3)
opt = ADMM(λ)
DataDrivenSparse.ADMM{Vector{Float64}, Float64}([0.001, 0.0010232929922807535, 0.0010471285480508996, 0.001071519305237606, 0.0010964781961431851, 0.001122018454301963, 0.0011481536214968829, 0.001174897554939529, 0.001202264434617413, 0.0012302687708123812  …  812.8305161640995, 831.7637711026708, 851.1380382023768, 870.9635899560806, 891.2509381337459, 912.0108393559096, 933.2543007969915, 954.992586021436, 977.2372209558112, 1000.0], 1.0)
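
To give a feel for what a sparse regression with a cutoff does, here is a stdlib-only sketch of sequentially thresholded least squares. Note that this is a different and simpler algorithm than the ADMM solver used in this showcase, and the function `stlsq_sketch`, the sample points, and the four-term library are all assumptions made for illustration.

```julia
using LinearAlgebra

# Hypothetical sample points and the "missing term" -0.9*x*y evaluated there
xs = [0.5, 1.0, 1.5, 2.0, 2.5, 3.0]
ys = [2.0, 1.5, 1.0, 0.7, 0.5, 0.4]
target = -0.9 .* xs .* ys

# Candidate library with columns 1, x, y, x*y
Θ = hcat(ones(length(xs)), xs, ys, xs .* ys)

# Least squares, then repeatedly zero out small coefficients and refit the rest
function stlsq_sketch(Θ, y; λ = 0.1, iters = 10)
    ξ = Θ \ y
    for _ in 1:iters
        small = abs.(ξ) .< λ
        ξ[small] .= 0.0
        big = .!small
        any(big) || break
        ξ[big] = Θ[:, big] \ y
    end
    return ξ
end

ξ = stlsq_sketch(Θ, target)
```

Only the coefficient of the `x*y` column survives the cutoff. The cutoff `λ` plays the same role as the threshold values swept over in the showcase: too small and spurious terms remain, too large and the true term is removed.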

This is one of many methods for sparse regression; consult the DataDrivenDiffEq.jl documentation for more information on the algorithm choices. Taking this, let's solve each of the sparse regressions:

options = DataDrivenCommonOptions(maxiters = 10_000,
                                  normalize = DataNormalization(ZScoreTransform),
                                  selector = bic, digits = 1,
                                  data_processing = DataProcessing(split = 0.9,
                                                                   batchsize = 30,
                                                                   shuffle = true,
                                                                   rng = StableRNG(1111)))

full_res = solve(full_problem, basis, opt, options = options)
full_eqs = get_basis(full_res)
println(full_res)
┌ Warning: Number of observations less than batch-size, decreasing the batch-size to 19
@ MLUtils ~/.cache/julia-buildkite-plugin/depots/0183cc98-c3b4-4959-aaaa-6c0d5f351407/packages/MLUtils/n3C0h/src/batchview.jl:95
┌ Warning: Number of observations less than batch-size, decreasing the batch-size to 19
@ MLUtils ~/.cache/julia-buildkite-plugin/depots/0183cc98-c3b4-4959-aaaa-6c0d5f351407/packages/MLUtils/n3C0h/src/batchview.jl:95
"DataDrivenSolution{Float64}" with 2 equations and 6 parameters.
Returncode: Success
Residual sum of squares: 197.8996265474311

# Re-create the options so that the shuffling RNG restarts from the same seed
options = DataDrivenCommonOptions(maxiters = 10_000,
                                  normalize = DataNormalization(ZScoreTransform),
                                  selector = bic, digits = 1,
                                  data_processing = DataProcessing(split = 0.9,
                                                                   batchsize = 30,
                                                                   shuffle = true,
                                                                   rng = StableRNG(1111)))

ideal_res = solve(ideal_problem, basis, opt, options = options)
ideal_eqs = get_basis(ideal_res)
println(ideal_res)
"DataDrivenSolution{Float64}" with 2 equations and 2 parameters.
Returncode: Success
Residual sum of squares: 8.965497669096854

# Re-create the options so that the shuffling RNG restarts from the same seed
options = DataDrivenCommonOptions(maxiters = 10_000,
                                  normalize = DataNormalization(ZScoreTransform),
                                  selector = bic, digits = 1,
                                  data_processing = DataProcessing(split = 0.9,
                                                                   batchsize = 30,
                                                                   shuffle = true,
                                                                   rng = StableRNG(1111)))

nn_res = solve(nn_problem, basis, opt, options = options)
nn_eqs = get_basis(nn_res)
println(nn_res)
"DataDrivenSolution{Float64}" with 2 equations and 4 parameters.
Returncode: Success
Residual sum of squares: 12.326551401713267

Note that we passed the identical options into each of the solve calls to get the same data for each call.
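
The reason the options are rebuilt before each solve is that a freshly seeded RNG reproduces the same shuffle, whereas a shared RNG keeps advancing its state. The sketch below shows this with the standard library; `MersenneTwister` stands in for `StableRNG` here, and the `idx_` names are made up for the example.

```julia
using Random

# Two freshly seeded RNGs produce the same permutation
idx_a = shuffle(MersenneTwister(1111), 1:10)
idx_b = shuffle(MersenneTwister(1111), 1:10)

# A single shared RNG advances its state between calls
shared_rng = MersenneTwister(1111)
idx_c = shuffle(shared_rng, 1:10)   # same as idx_a
idx_d = shuffle(shared_rng, 1:10)   # drawn from the advanced state

same_fresh = idx_a == idx_b
```

Constructing `StableRNG(1111)` anew inside each options object follows the same idea, so each of the three regressions sees the same split and batching.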

We already saw that the full problem has failed to identify the correct equations of motion. To have a closer look, we can inspect the corresponding equations:

for eqs in (full_eqs, ideal_eqs, nn_eqs)
    println(eqs)
    println(get_parameter_map(eqs))
    println()
end
Model ##Basis#190792 with 2 equations
States : u[1] u[2]
Parameters : 6
Independent variable: t
Equations
Differential(t)(u[1]) = p₂*u[2] + p₁*(u[1]^2) + p₃*u[1]*(u[2]^3) + p₄*(u[2]^4)
Differential(t)(u[2]) = p₆*u[2] + p₅*(u[1]^2)

Pair{SymbolicUtils.BasicSymbolic{Real}, Float64}[p₁ => 0.4, p₂ => -0.3, p₃ => 0.1, p₄ => -0.1, p₅ => 0.1, p₆ => -1.0]

Model ##Basis#190796 with 2 equations
States : u[1] u[2]
Parameters : p₁ p₂
Independent variable: t
Equations
φ₁ = p₁*u[1]*u[2]
φ₂ = p₂*u[1]*u[2]

Pair{SymbolicUtils.BasicSymbolic{Real}, Float64}[p₁ => -0.8, p₂ => 0.7]

Model ##Basis#190800 with 2 equations
States : u[1] u[2]
Parameters : p₁ p₂ p₃ p₄
Independent variable: t
Equations
φ₁ = p₁*u[1] + p₂*u[2] + p₃*(u[2]^4)
φ₂ = p₄*u[2]

Pair{SymbolicUtils.BasicSymbolic{Real}, Float64}[p₁ => 0.5, p₂ => -4.9, p₃ => 0.2, p₄ => 2.1]

Next, we want to predict with our model. To do so, we embed the basis into a function like before:

# Define the recovered, hybrid model
function recovered_dynamics!(du, u, p, t)
    û = nn_eqs(u, p) # Recovered equations
    du[1] = p_[1] * u[1] + û[1]
    du[2] = -p_[4] * u[2] + û[2]
end

estimation_prob = ODEProblem(recovered_dynamics!, u0, tspan, get_parameter_values(nn_eqs))
estimate = solve(estimation_prob, Tsit5(), saveat = solution.t)

# Plot
plot(solution)
plot!(estimate)
[Figure: true solution overlaid with the solution of the recovered model]

We are still a bit off, so we fine-tune the parameters by simply minimizing the residuals between the UDE predictor and our recovered parametrized equations:

function parameter_loss(p)
    Y = reduce(hcat, map(Base.Fix2(nn_eqs, p), eachcol(X̂)))
    sum(abs2, Ŷ .- Y)
end

optf = Optimization.OptimizationFunction((x, p) -> parameter_loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, get_parameter_values(nn_eqs))
parameter_res = Optimization.solve(optprob, Optim.LBFGS(), maxiters = 1000)
u: 4-element Vector{Float64}:
  0.656707317258741
 -5.119552308960921
  0.25211326016765234
  1.7168381488137645
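
The line `reduce(hcat, map(Base.Fix2(nn_eqs, p), eachcol(X̂)))` in `parameter_loss` fixes the parameter argument, applies the resulting one-argument function to every column, and glues the outputs back into a matrix. Here is a package-free sketch of that pattern, where `g_demo` is a hypothetical stand-in for the recovered basis and the numbers are made up:

```julia
# Hypothetical two-argument function of (state, parameters)
g_demo(u, p) = [p[1] * u[1] * u[2], p[2] * u[1] * u[2]]

# Fix the second argument so that only the state remains free
g_fixed = Base.Fix2(g_demo, [-0.9, 0.8])

# Columns are states at different time points
X_cols = [1.0 2.0 3.0;
          2.0 1.0 0.5]

# Apply column by column and concatenate the results horizontally
Y_cols = reduce(hcat, map(g_fixed, eachcol(X_cols)))
```

The result has one column per input column, so it can be compared elementwise against the neural network output as done in `parameter_loss`.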

Simulation

# Look at long term prediction
t_long = (0.0, 50.0)
estimation_prob = ODEProblem(recovered_dynamics!, u0, t_long, parameter_res.u)
estimate_long = solve(estimation_prob, Tsit5(), saveat = 0.1) # Using higher tolerances here results in exit of julia
plot(estimate_long)
[Figure: long-term solution of the recovered model]
true_prob = ODEProblem(lotka!, u0, t_long, p_)
true_solution_long = solve(true_prob, Tsit5(), saveat = estimate_long.t)
plot!(true_solution_long)
[Figure: long-term solution of the recovered model overlaid with the true solution]

Post Processing and Plots

c1 = 3 # Index into the default color palette
c2 = :orange
c3 = :blue
c4 = :purple

p1 = plot(t, abs.(Array(solution) .- estimate)' .+ eps(Float32),
          lw = 3, yaxis = :log, title = "Timeseries of UODE Error",
          color = [3 :orange], xlabel = "t",
          label = ["x(t)" "y(t)"],
          titlefont = "Helvetica", legendfont = "Helvetica",
          legend = :topright)

# Plot the neural network fit of the second missing term against the true term
p2 = plot3d(X̂[1, :], X̂[2, :], Ŷ[2, :], lw = 3,
            title = "Neural Network Fit of U2(t)", color = c1,
            label = "Neural Network", xaxis = "x", yaxis = "y",
            titlefont = "Helvetica", legendfont = "Helvetica",
            legend = :bottomright)
plot!(X̂[1, :], X̂[2, :], Ȳ[2, :], lw = 3, label = "True Missing Term", color = c2)

p3 = scatter(solution, color = [c1 c2], label = ["x data" "y data"],
             title = "Extrapolated Fit From Short Training Data",
             titlefont = "Helvetica", legendfont = "Helvetica",
             markersize = 5)

plot!(p3, true_solution_long, color = [c1 c2], linestyle = :dot, lw = 5,
      label = ["True x(t)" "True y(t)"])
plot!(p3, estimate_long, color = [c3 c4], lw = 1,
      label = ["Estimated x(t)" "Estimated y(t)"])
# Mark the end of the training data, which spans tspan = (0.0, 5.0)
plot!(p3, [4.99, 5.01], [0.0, 10.0], lw = 1, color = :black, label = nothing)
annotate!([(2.5, 13, text("Training \nData", 10, :center, :top, :black, "Helvetica"))])
l = @layout [grid(1, 2)
             grid(1, 1)]
plot(p1, p2, p3, layout = l)
[Figure: combined panel of the error timeseries, the neural network fit of U2, and the extrapolated fit]