Neural Second Order Ordinary Differential Equation
The neural ODE focuses on finding a neural network such that:
\[u^\prime = NN(u)\]
In physics-based modeling, however, the key object is often not the velocity but the acceleration: knowing the acceleration tells you the force field, and thus the generating process for the dynamical system. What we want to find is therefore the force, i.e.:
\[u^{\prime\prime} = NN(u)\]
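It can help to see the same model as a first-order system. Introducing the velocity v as an auxiliary variable (a symbol assumed here for illustration, not used elsewhere on this page), the standard reduction of the equation above is:
\[u^\prime = v, \qquad v^\prime = NN(u)\]
The full state therefore carries both the position and the velocity, while the network only supplies the right-hand side of the velocity equation.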
(Note that in order to be the acceleration, we should divide the output of the neural network by the mass!)
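To make the role of the mass concrete, here is a minimal sketch in plain Julia with no packages. It is not part of the original example: the function `force` is a hypothetical stand-in for a trained network (a linear spring), the values of `k` and `mass` are assumed, and the hand-written velocity Verlet integrator is swapped in only so the block runs on its own.

```julia
# Hypothetical stand-in for a trained network: a linear spring force F(u) = -k * u.
const k = 4.0      # assumed spring constant
const mass = 2.0   # assumed mass

force(u) = -k .* u

# The quantity that enters u'' is the force divided by the mass.
acceleration(u) = force(u) ./ mass

# Velocity Verlet integration of u'' = acceleration(u).
function velocity_verlet(u0, v0, dt, nsteps)
    u, v = copy(u0), copy(v0)
    a = acceleration(u)
    for _ in 1:nsteps
        u = u .+ dt .* v .+ 0.5 .* dt^2 .* a   # position update
        a_new = acceleration(u)                # acceleration at the new position
        v = v .+ 0.5 .* dt .* (a .+ a_new)     # velocity update with averaged acceleration
        a = a_new
    end
    return u, v
end

# Integrate from u(0) = 1, u'(0) = 0 up to t = 1.
u_end, v_end = velocity_verlet([1.0], [0.0], 0.001, 1000)
```

With these assumed values the exact solution is u(t) = cos(sqrt(k / mass) * t), so `u_end[1]` should land close to `cos(sqrt(2.0))`. Omitting the division by the mass would change the oscillation frequency, which is why the distinction between force and acceleration matters when interpreting the network output.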
An example of training a neural network on a second order ODE is as follows:
using SciMLSensitivity
using OrdinaryDiffEq, Lux, Optimization, OptimizationOptimisers, RecursiveArrayTools,
      Random, ComponentArrays

# Initial position, initial velocity, time span, and save points
u0 = Float32[0.0; 2.0]
du0 = Float32[0.0; 0.0]
tspan = (0.0f0, 1.0f0)
t = range(tspan[1], tspan[2], length = 20)

# Network mapping the 2D position to a 2D acceleration
model = Chain(Dense(2, 50, tanh), Dense(50, 2))
ps, st = Lux.setup(Random.default_rng(), model)
ps = ComponentArray(ps)
model = Lux.StatefulLuxLayer{true}(model, ps, st)

# Out-of-place second-order right-hand side: u'' = NN(u)
ff(du, u, p, t) = model(u, p)
prob = SecondOrderODEProblem{false}(ff, du0, u0, tspan, ps)

# Solve with parameters p and return the saved states as a matrix
function predict(p)
    Array(solve(prob, Tsit5(), p = p, saveat = t))
end

# Target trajectory, one column per save point
correct_pos = Float32.(transpose(hcat(collect(0:0.05:1)[2:end], collect(2:-0.05:1)[2:end])))

function loss_n_ode(p)
    pred = predict(p)
    sum(abs2, correct_pos .- pred[1:2, :])
end

l1 = loss_n_ode(ps)

# Print the loss each iteration; returning true stops the optimization early
callback = function (state, l)
    println(l)
    l < 0.01
end

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_n_ode(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ps)
res = Optimization.solve(optprob, Adam(0.01); callback = callback, maxiters = 1000)
retcode: Default
u: ComponentVector{Float32}(layer_1 = (weight = Float32[-0.62671167 0.6622869; 1.2354777 -0.90417004; … ; 0.8087391 -0.5034691; 3.1097383 0.62398845], bias = Float32[0.9771382, 0.19983982, -0.850714, -0.60511595, 0.3184396, -0.37336037, -0.6794703, -1.0875431, 1.3552002, 1.0313714 … -3.8514602, -2.9924679, 0.1390861, -0.46262506, 0.4204854, 9.019197, 6.0931153, 0.14336355, -0.7673417, -0.6548785]), layer_2 = (weight = Float32[-0.0827556 -0.2355678 … 0.1760262 -0.2185145; 0.4430437 -0.73740524 … -0.40070176 -0.54750234], bias = Float32[0.2458122, 0.37375915]))