Neural Second Order Ordinary Differential Equation
The neural ODE focuses on finding a neural network such that:
\[u^\prime = NN(u)\]
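For comparison, this first order case maps directly onto a standard ODEProblem whose right-hand side is the network itself. A minimal sketch, reusing the model, initial condition, time span, and parameters that are defined in the full example below:

# Sketch of the first order neural ODE u' = NN(u), for contrast with the second order case
nn_rhs(u, p, t) = model(u, p)  # the network output is used directly as the derivative
prob_first_order = ODE.ODEProblem{false}(nn_rhs, u0, tspan, ps)
sol_first_order = ODE.solve(prob_first_order, ODE.Tsit5(), saveat = t)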
However, in physics-based modeling the key quantity is often not the velocity but the acceleration: knowing the acceleration tells you the force field, and thus the process generating the dynamical system. What we want to find is therefore the force, i.e.:
\[u^{\prime\prime} = NN(u)\]
(Note that for the output of the neural network to be an acceleration, it should be divided by the mass!)
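Concretely, if the network output is interpreted as a force, Newton's second law gives

\[m u^{\prime\prime} = NN(u) \quad\Longrightarrow\quad u^{\prime\prime} = \frac{1}{m} NN(u)\]

In the example below the mass is implicitly taken to be one, so the network output is used directly as the acceleration.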
An example of training a neural network on a second order ODE is as follows:
import SciMLSensitivity as SMS
import OrdinaryDiffEq as ODE
import Lux
import Optimization as OPT
import OptimizationOptimisers as OPO
import RecursiveArrayTools
import Random
import ComponentArrays as CA
# Initial position u0, initial velocity du0, time span, and the time points to save at
u0 = Float32[0.0; 2.0]
du0 = Float32[0.0; 0.0]
tspan = (0.0f0, 1.0f0)
t = range(tspan[1], tspan[2], length = 20)
# Neural network mapping the state to the acceleration, with its parameters as a ComponentArray
model = Lux.Chain(Lux.Dense(2, 50, tanh), Lux.Dense(50, 2))
ps, st = Lux.setup(Random.default_rng(), model)
ps = CA.ComponentArray(ps)
model = Lux.StatefulLuxLayer{true}(model, ps, st)
# u'' = NN(u): the network output defines the second derivative
ff(du, u, p, t) = model(u, p)
prob = ODE.SecondOrderODEProblem{false}(ff, du0, u0, tspan, ps)
# Solve the second order ODE with the given parameters and collect the solution as an array
function predict(p)
    Array(ODE.solve(prob, ODE.Tsit5(), p = p, saveat = t))
end
# Target trajectory to fit against
correct_pos = Float32.(transpose(hcat(collect(0:0.05:1)[2:end], collect(2:-0.05:1)[2:end])))
# Sum of squared errors between the target trajectory and the first two rows of the prediction
function loss_n_ode(p)
    pred = predict(p)
    sum(abs2, correct_pos .- pred[1:2, :])
end
l1 = loss_n_ode(ps)

# Print the loss at each iteration; returning true (loss below 0.01) halts the optimization
callback = function (state, l)
    println(l)
    l < 0.01
end
# Train the network parameters with Adam, differentiating through the ODE solve via Zygote
adtype = OPT.AutoZygote()
optf = OPT.OptimizationFunction((x, p) -> loss_n_ode(x), adtype)
optprob = OPT.OptimizationProblem(optf, ps)
res = OPT.solve(optprob, OPO.Adam(0.01); callback = callback, maxiters = 1000)
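The optimized parameters are returned in res.u. As a short usage sketch (assuming the loss converged), they can be plugged back into the solver to inspect the learned trajectory:

ps_trained = res.u  # trained network parameters
sol = ODE.solve(prob, ODE.Tsit5(), p = ps_trained, saveat = t)
pred = Array(sol)
println(sum(abs2, correct_pos .- pred[1:2, :]))  # final fitting error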