Neural Second Order Ordinary Differential Equation

The neural ODE focuses on finding a neural network such that:

\[u^\prime = NN(u)\]

However, often in physics-based modeling, the key object is not the velocity but the acceleration: knowing the acceleration tells you the force field and thus the generating process for the dynamical system. Thus what we want to do is find the force, i.e.:

\[u^{\prime\prime} = NN(u)\]

(Strictly speaking, the neural network here models the force. To obtain the acceleration, its output should be divided by the mass, i.e. \(u^{\prime\prime} = NN(u)/m\).)

An example of training a neural network on a second order ODE is as follows:

import SciMLSensitivity as SMS
import OrdinaryDiffEq as ODE
import Lux
import Optimization as OPT
import OptimizationOptimisers as OPO
import RecursiveArrayTools
import Random
import ComponentArrays as CA

# Initial state of the 2-component system: position `u0` and velocity `du0`
# (both Float32 to match the network's parameter eltype), integrated over t ∈ [0, 1]
# and sampled at 20 evenly spaced save times.
u0 = Float32[0, 2]
du0 = zeros(Float32, 2)
tspan = (0f0, 1f0)
t = range(first(tspan), last(tspan); length = 20)

# Force network NN(u): 2 inputs -> 50 tanh hidden units -> 2 outputs.
model = Lux.Chain(Lux.Dense(2 => 50, tanh), Lux.Dense(50 => 2))

# Initialise parameters/state from the default RNG and flatten the parameters into a
# ComponentArray so the optimizer can treat them as a single vector.
ps, st = Lux.setup(Random.default_rng(), model)
ps = ps |> CA.ComponentArray

# Bundle the layer with its state so it can be called as `model(u, p)`.
model = Lux.StatefulLuxLayer{true}(model, ps, st)

# Out-of-place right-hand side of the second-order system u'' = NN(u): the network
# maps the current position `u` (with parameters `p`) to an acceleration.
# The velocity `du` and time `t` are accepted to satisfy the solver's signature
# but are not used.
function ff(du, u, p, t)
    return model(u, p)
end

# `{false}` selects the out-of-place form; initial velocity comes before position.
prob = ODE.SecondOrderODEProblem{false}(ff, du0, u0, tspan, ps)

"""
    predict(p)

Solve `prob` with `Tsit5` using network parameters `p`, saving at the times in `t`,
and return the trajectory as a dense `Array` (one column per save time).
"""
function predict(p)
    sol = ODE.solve(prob, ODE.Tsit5(); p = p, saveat = t)
    return Array(sol)
end

# Target trajectory as a 2×20 Float32 matrix, one column per save time. Row 1 rises
# linearly 0.05 -> 1.0 and row 2 falls linearly 1.95 -> 1.0 (the first element of
# each 21-point range is dropped to leave 20 columns).
correct_pos = permutedims(Float32[collect(0:0.05:1)[2:end] collect(2:-0.05:1)[2:end]])

"""
    loss_n_ode(p)

Return the sum of squared errors between the predicted positions and `correct_pos`.

A `SecondOrderODEProblem` stores its state as `ArrayPartition(du, u)`, so every
column of `predict(p)` is laid out as `[du; u]`: velocities first, positions last.
The positions are therefore the trailing `size(correct_pos, 1)` rows. The leading
rows are velocities, which are pinned to `du0` at t = 0 and must not be compared
against position targets.
"""
function loss_n_ode(p)
    pred = predict(p)
    npos = size(correct_pos, 1)
    # Select the position block (last `npos` rows) of the [du; u] state.
    positions = pred[(end - npos + 1):end, :]
    return sum(abs2, correct_pos .- positions)
end

# Loss at the initial (untrained) parameters, for reference.
l1 = loss_n_ode(ps)

# Optimization callback: print the current loss `l`, and signal early termination
# (by returning `true`) once it drops below 0.01. The optimizer `state` is unused.
callback = (state, l) -> begin
    println(l)
    return l < 0.01
end

# Reverse-mode AD with Zygote for gradients through the ODE solve
# (SciMLSensitivity is imported above, presumably to provide the solve adjoints).
adtype = OPT.AutoZygote()

# The objective depends only on the network weights `x`; the second (hyper-parameter)
# argument is ignored.
optf = OPT.OptimizationFunction((x, _) -> loss_n_ode(x), adtype)
optprob = OPT.OptimizationProblem(optf, ps)

# Train with Adam (learning rate 0.01) for at most 1000 iterations; `callback` may
# stop training early.
res = OPT.solve(optprob, OPO.Adam(0.01); callback = callback, maxiters = 1000)
retcode: Default
u: ComponentVector{Float32}(layer_1 = (weight = Float32[0.24058375 -0.69518834; 1.2867389 -0.8374674; … ; -1.1122851 2.313341; -0.8960205 -2.4888058], bias = Float32[-0.91555095, -0.0022994024, -0.89587516, 1.5222236, -0.8090256, -0.6156897, 5.2697897, -2.8805118, 0.8574152, -0.0561467  …  0.40166214, -0.30738875, -1.2172376, -0.08559756, -0.7241604, 0.98692954, 1.9969997, -0.4760418, 0.42486984, -0.1576097]), layer_2 = (weight = Float32[-0.03962546 -0.32980138 … 0.2624971 -0.34845215; -0.6062042 -0.24213393 … 0.3440051 -0.6374107], bias = Float32[0.11345971, 0.4518554]))