Neural Second-Order Ordinary Differential Equation

The neural ODE focuses on finding a neural network such that:

\[u^\prime = NN(u)\]
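For concreteness, the network in this section's code example, `Chain(Dense(2, 50, tanh), Dense(50, 2))`, is a two-layer perceptron. Written out, with tanh acting elementwise and all weights and biases collected in the trainable parameter vector p, it is:

\[
NN(u; p) = W_2 \tanh(W_1 u + b_1) + b_2, \qquad W_1 \in \mathbb{R}^{50 \times 2},\ b_1 \in \mathbb{R}^{50},\ W_2 \in \mathbb{R}^{2 \times 50},\ b_2 \in \mathbb{R}^{2}
\]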

However, in physics-based modeling the key object is often not the velocity but the acceleration: by Newton's second law, knowing the acceleration tells you the force field, and thus the generating process of the dynamical system. What we want, then, is to learn the force, i.e.:

\[u^{\prime\prime} = NN(u)\]
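A second-order equation like this is usually solved by rewriting it as a first-order system in the stacked state (v, u), with v = u': then v' = NN(u) and u' = v. The sketch below does this by hand. It assumes a small, randomly initialized and untrained two-layer network (`nn`, a hypothetical stand-in for NN) and a fixed-step fourth-order Runge-Kutta integrator (`rk4`, also a hypothetical helper, not the adaptive Tsit5 method used later). The state is ordered velocity first, then position, which matches how `SecondOrderODEProblem` stores `ArrayPartition(du0, u0)`.

```julia
using Random

# Hypothetical stand-in for NN(u): same shape as Chain(Dense(2, 50, tanh), Dense(50, 2)),
# with random, untrained weights.
rng = Random.MersenneTwister(0)
W1, b1 = 0.1f0 .* randn(rng, Float32, 50, 2), zeros(Float32, 50)
W2, b2 = 0.1f0 .* randn(rng, Float32, 2, 50), zeros(Float32, 2)
nn(u) = W2 * tanh.(W1 * u .+ b1) .+ b2

# First-order right-hand side on the stacked state s = [v; u]
# (velocity in entries 1:2, position in entries 3:4).
function rhs(s)
    v, u = s[1:2], s[3:4]
    return vcat(nn(u), v)        # [v'; u'] = [NN(u); v]
end

# Classic fourth-order Runge-Kutta with fixed step h; returns one column per step.
function rk4(f, s0, h, nsteps)
    s = copy(s0)
    traj = [copy(s)]
    for _ in 1:nsteps
        k1 = f(s)
        k2 = f(s .+ (h / 2) .* k1)
        k3 = f(s .+ (h / 2) .* k2)
        k4 = f(s .+ h .* k3)
        s = s .+ (h / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
        push!(traj, copy(s))
    end
    return reduce(hcat, traj)    # 4 × (nsteps + 1) matrix
end

du0 = Float32[0.0, 0.0]          # initial velocity
u0 = Float32[0.0, 2.0]           # initial position
traj = rk4(rhs, vcat(du0, u0), 1.0f0 / 19, 19)   # 20 states on [0, 1]
positions = traj[3:4, :]         # rows 3:4 hold the position u(t)
```

Fixing the step size keeps the sketch short; an adaptive solver such as Tsit5 also controls the local error, which matters once the learned force becomes stiff or rapidly varying.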

(Strictly speaking, if the network represents the force, its output must be divided by the mass to give the acceleration!)
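Written out with Newton's second law, and with m denoting a constant mass (an assumption made here for illustration), the relation is the following. The code example in this section effectively takes m = 1, since its right-hand side returns the network output directly.

\[
m\,u^{\prime\prime} = F(u) \approx NN(u) \quad\Longrightarrow\quad u^{\prime\prime} = \frac{NN(u)}{m}
\]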

An example of training a neural network as the right-hand side of a second-order ODE is as follows:

using SciMLSensitivity
using OrdinaryDiffEq, Lux, Optimization, OptimizationOptimisers, RecursiveArrayTools,
      Random, ComponentArrays

# Initial position u(0), initial velocity u'(0), time span, and save times
u0 = Float32[0.0; 2.0]
du0 = Float32[0.0; 0.0]
tspan = (0.0f0, 1.0f0)
t = range(tspan[1], tspan[2], length = 20)

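# Two-layer MLP mapping the position u (2 entries) to an acceleration (2 entries)
model = Chain(Dense(2, 50, tanh), Dense(50, 2))
ps, st = Lux.setup(Random.default_rng(), model)
ps = ComponentArray(ps)   # flat parameter vector for the optimizer
# Wrap the model so it can be called as model(u, p) without threading the state
model = Lux.StatefulLuxLayer{true}(model, ps, st)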

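# Out-of-place second-order RHS: given velocity du and position u, return u''.
# The mass is taken as 1, so the network output is used directly as the acceleration.
ff(du, u, p, t) = model(u, p)
# The state is stored as ArrayPartition(du0, u0): velocity first, then position
prob = SecondOrderODEProblem{false}(ff, du0, u0, tspan, ps)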

# Solve with parameters p; returns a 4×20 matrix (rows 1:2 velocity, rows 3:4 position)
function predict(p)
    Array(solve(prob, Tsit5(), p = p, saveat = t))
end

# Synthetic target positions: 20 points moving linearly from (0.05, 1.95) to (1, 1)
correct_pos = Float32.(transpose(hcat(collect(0:0.05:1)[2:end], collect(2:-0.05:1)[2:end])))

# Sum of squared errors between the target and the predicted positions (rows 3:4)
function loss_n_ode(p)
    pred = predict(p)
    sum(abs2, correct_pos .- pred[3:4, :])
end

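# Loss at the initial (untrained) parameters
l1 = loss_n_ode(ps)

# Print the loss each iteration; returning true halts the optimization early
callback = function (state, l)
    println(l)
    l < 0.01
end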

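# Reverse-mode AD with Zygote; SciMLSensitivity supplies the adjoint rules for solve.
# The second argument of the objective (optimizer hyperparameters) is unused here.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_n_ode(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ps)

res = Optimization.solve(optprob, Adam(0.01); callback = callback, maxiters = 1000)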
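In the example, `Array(solve(...))` turns the saved states into a matrix with one column per save time. Because each state is an `ArrayPartition(du, u)`, as we understand the RecursiveArrayTools conversion, rows 1:2 hold the velocity and rows 3:4 the position. The following sketch mimics that layout with synthetic arrays (no solver is involved, and the trajectory u(t) = (t, 2 - t) with constant velocity (1, -1) is made up for illustration) to show which rows a position-only loss should compare.

```julia
nt = 20
ts = collect(range(0.0f0, 1.0f0, length = nt))

# Hypothetical synthetic trajectory: velocity v(t) = (1, -1), position u(t) = (t, 2 - t)
vel = vcat(ones(Float32, 1, nt), -ones(Float32, 1, nt))
pos = vcat(reshape(ts, 1, :), reshape(2 .- ts, 1, :))

# Same row ordering as ArrayPartition(du, u): 4 × nt matrix
pred = vcat(vel, pos)

target = pos                                    # pretend observations of the position
loss_pos = sum(abs2, target .- pred[3:4, :])    # compares positions with positions
loss_vel = sum(abs2, target .- pred[1:2, :])    # compares positions with velocities
```

Here `loss_pos` is zero because the target and the compared rows are identical, while `loss_vel` is positive because rows 1:2 contain velocities rather than positions.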