Universal Differential Equations for Neural Feedback Control

You can also mix a known differential equation and a neural differential equation, so that the parameters and the neural network are estimated simultaneously!

We will assume that we know the dynamics of the second equation (linear dynamics). Our goal is to find a neural network that depends on the current state of the dynamical system and controls the second equation so that its output stays close to 1.

using Lux, Optimization, OptimizationPolyalgorithms, ComponentArrays,
      SciMLSensitivity, Zygote, OrdinaryDiffEq, Plots, Random

# Random number generator used to initialise the network parameters
rng = Random.default_rng()

# Initial system output, integration window, and the times at which solutions are saved
u0 = [1.1]
tspan = (0.0, 25.0)
tsteps = first(tspan):1.0:last(tspan)

# Controller network: takes the 2-component state and produces a single output
model_univ = Chain(Dense(2, 16, tanh), Dense(16, 16, tanh), Dense(16, 1))
ps, st = Lux.setup(rng, model_univ)
p_model = ComponentArray(ps)

# Parameters of the second equation (linear dynamics)
p_system = [0.5, -0.5]

# Bundle network weights with the system parameters, then add the initial condition
# so that everything can be optimised as one flat vector
p_all = ComponentArray(p_model = p_model, p_system = p_system)
θ = ComponentArray(u0 = u0, p_all = p_all)

"""
    dudt_univ!(du, u, p, t)

In-place right-hand side of the combined controller/system ODE.

`u[1]` is the control produced by the neural network and `u[2]` is the output of
the linear system. `p` must expose `p_model` (network weights) and `p_system`
(the two coefficients `α, β` of the linear equation). `t` is not used.

The derivatives are written into `du`:

  - `du[1]`: first component of the network evaluated on the full state `u`
  - `du[2]`: `α * u[2] + β * u[1]`
"""
function dudt_univ!(du, u, p, t)
    # Coefficients of the known linear dynamics
    α, β = p.p_system

    # State layout: (control, system output)
    model_control, system_output = u

    # The network sees the whole state and returns the rate of change of the control
    network_output = first(model_univ(u, p.p_model, st))

    # Write both derivatives back in place
    du[1] = network_output[1]
    du[2] = α * system_output + β * model_control
end

# State is [control, system output]; the control starts at zero and the
# system output starts at the single entry of `u0`
prob_univ = ODEProblem(dudt_univ!, [0.0, first(u0)], tspan, p_all)

# Solve once with the initial (untrained) parameters
sol_univ = solve(prob_univ, Tsit5(); abstol = 1e-8, reltol = 1e-6)

"""
    predict_univ(θ)

Solve `prob_univ` with the initial system output `θ.u0[1]` and parameters `θ.p_all`,
saving at `tsteps`, and return the solution converted to an `Array`.

The control component always starts at `0.0`. Sensitivities are computed with
`InterpolatingAdjoint` using a `ReverseDiffVJP(true)` vector-Jacobian product.
"""
function predict_univ(θ)
    initial_state = [0.0, θ.u0[1]]
    adjoint_method = InterpolatingAdjoint(autojacvec = ReverseDiffVJP(true))

    trajectory = solve(prob_univ, Tsit5(); u0 = initial_state, p = θ.p_all,
        sensealg = adjoint_method, saveat = tsteps)
    return Array(trajectory)
end

"""
    loss_univ(θ)

Sum of squared deviations from 1 of the second row of `predict_univ(θ)`
(the system output at every saved time).
"""
function loss_univ(θ)
    system_output = predict_univ(θ)[2, :]
    return sum(abs2, system_output .- 1)
end
# Loss evaluated at the initial parameter vector θ, before any optimisation
l = loss_univ(θ)
3.038171482368016e9
# Plots collected by the callback, and the number of times it has been called
list_plots = []
iter = 0

# Optimisation callback: prints the current loss `l` and, when `makeplot = true`,
# plots the trajectory predicted from `state.u`, stores it in `list_plots` and
# displays it. Always returns `false`, so it never asks the optimizer to stop.
cb = function (state, l; makeplot = false)
    global list_plots, iter

    # First call: start from a fresh plot list
    iter == 0 && (list_plots = [])
    iter += 1

    println(l)

    # Nothing else to do unless a plot was requested
    makeplot || return false

    # Transpose so that each state component becomes one plotted series
    fig = plot(predict_univ(state.u)', ylim = (0, 6))
    push!(list_plots, fig)
    display(fig)
    return false
end
#1 (generic function with 1 method)
# Differentiate the loss with Zygote
adtype = Optimization.AutoZygote()

# The optimizer passes (decision variables, extra parameters); the loss only
# needs the decision variables
optf = Optimization.OptimizationFunction((θ, _) -> loss_univ(θ), adtype)

# Start from the initial parameter vector θ and report progress through `cb`
optprob = Optimization.OptimizationProblem(optf, θ)
result_univ = Optimization.solve(optprob, PolyOpt(); callback = cb)
retcode: Success
u: ComponentVector{Float64}(u0 = [0.9999999958958983], p_all = (p_model = (layer_1 = (weight = [-0.6268315729879184 0.09737459130269115; -1.1327056095312253 1.718324615497645; … ; 1.9595110057276175 -0.7377482970520157; -0.8972949705265623 1.354379680507579], bias = [-0.45195410906274736, 0.46493466899546154, 0.0032193119391927152, 0.4313717026195735, -0.21462449294412067, 0.061987912796726094, -0.2780215395353777, -0.42721028335307587, -0.6255181870681181, 0.08542038593263945, 0.7633498127637038, -0.10921201913851084, 0.04360308835655556, -0.6135688262182826, 0.5777822853443565, -0.021270687507752926]), layer_2 = (weight = [0.6237999812263837 0.19820704145214257 … 0.08392489676132138 -0.6830969585558623; -0.10064382493204786 0.13767788441753812 … 0.21388948594276963 0.5160522460726028; … ; 0.6159178262386031 0.3241603831761014 … -0.2126878056100052 -0.6274049696002996; 0.6215549010875392 -0.687230131775131 … -0.6383151237825954 -0.539961896326495], bias = [-0.051397661463423845, 0.06634863041075242, 0.12422707557268331, 0.09930384799038716, 0.07901432556671387, 0.0568041863430608, 0.21989164301371075, -0.05827512262801289, -0.25561354771477784, 0.1617954587925212, -0.22694261953468667, 0.18057477534474972, 0.006770250771220369, -0.07966762659639402, -0.06116534045189478, -0.17890771108911385]), layer_3 = (weight = [-0.3565077065615576 -0.08180504355052164 … -0.054823988259299454 0.04552066817259864], bias = [-0.20904641279735298])), p_system = [1.3777563339015033e-8, -0.6111654084716439]))
# Plot the optimised trajectory. The solution object exposes `.u` just like the
# callback state, so it can be passed to `cb` directly. `objective` is the final
# loss value; it replaces the deprecated `minimum` alias.
cb(result_univ, result_univ.objective; makeplot = true)
false