Universal Differential Equations for Neural Feedback Control
You can also mix a known differential equation with a neural differential equation, so that the mechanistic parameters and the neural network are estimated simultaneously!
We will assume that we know the dynamics of the second equation (linear dynamics), and our goal is to find a neural network that depends on the current state of the dynamical system and controls the second equation so that its output stays close to 1.
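In equations, reading the dynamics off the code below, the coupled system is

$$\frac{du_1}{dt} = \mathrm{NN}_\theta(u), \qquad \frac{du_2}{dt} = \alpha u_2 + \beta u_1,$$

where $u_1$ is the control produced by the neural network, $u_2$ is the system output, and $\alpha$, $\beta$ are the linear-dynamics parameters estimated alongside the network weights.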
import Lux
import Optimization as OPT
import OptimizationPolyalgorithms as OPA
import ComponentArrays as CA
import SciMLSensitivity as SMS
import Zygote
import OrdinaryDiffEq as ODE
import Plots
import Random
rng = Random.default_rng()
u0 = [1.1]
tspan = (0.0, 25.0)
tsteps = 0.0:1.0:25.0
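# The universal part: a neural network mapping the 2-dimensional state to the control's rate of change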
model_univ = Lux.Chain(Lux.Dense(2, 16, tanh), Lux.Dense(16, 16, tanh), Lux.Dense(16, 1))
ps, st = Lux.setup(rng, model_univ)
p_model = CA.ComponentArray(ps)
# Parameters of the second equation (linear dynamics)
p_system = [0.5, -0.5]
p_all = CA.ComponentArray(; p_model, p_system)
θ = CA.ComponentArray(; u0, p_all)
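For orientation, the nested parameters are addressable by name (a quick illustration; the comments are our own annotations):

θ.u0             # trainable initial condition, starts at [1.1]
θ.p_all.p_system # linear dynamics parameters, starts at [0.5, -0.5]
θ.p_all.p_model  # neural network weights and biases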
function dudt_univ!(du, u, p, t)
    # Destructure the parameters
    model_weights = p.p_model
    α, β = p.p_system
    # The neural network outputs a control taken by the system
    # The system then produces an output
    model_control, system_output = u
    # Dynamics of the control and system
    dmodel_control = first(model_univ(u, model_weights, st))[1]
    dsystem_output = α * system_output + β * model_control
    # Update in place
    du[1] = dmodel_control
    du[2] = dsystem_output
end
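Before building the ODEProblem, a minimal sanity check (our own addition, not part of the original tutorial) confirms that the in-place right-hand side runs and fills du:

du_test = zeros(2)
dudt_univ!(du_test, [0.0, u0[1]], p_all, 0.0)
du_test  # [network control rate, α * output + β * control]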
prob_univ = ODE.ODEProblem(dudt_univ!, [0.0, u0[1]], tspan, p_all)
sol_univ = ODE.solve(prob_univ, ODE.Tsit5(), abstol = 1e-8, reltol = 1e-6)
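It can be helpful to plot this untrained trajectory before optimizing (a small addition; the labels are our own annotations):

Plots.plot(sol_univ, label = ["control" "output"])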
function predict_univ(θ)
    return Array(ODE.solve(prob_univ, ODE.Tsit5(), u0 = [0.0, θ.u0[1]], p = θ.p_all,
        sensealg = SMS.InterpolatingAdjoint(autojacvec = SMS.ReverseDiffVJP(true)),
        saveat = tsteps))
end
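# Loss: squared deviation of the system output (second state) from the target value 1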
loss_univ(θ) = sum(abs2, predict_univ(θ)[2, :] .- 1)
l = loss_univ(θ)
4.3677094960602313e11
list_plots = []
iter = 0
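# Callback: prints the loss each iteration and, when makeplot = true, plots and stores the current prediction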
cb = function (state, l; makeplot = false)
    global list_plots, iter
    if iter == 0
        list_plots = []
    end
    iter += 1
    println(l)
    if makeplot
        plt = Plots.plot(predict_univ(state.u)', ylim = (0, 6))
        push!(list_plots, plt)
        display(plt)
    end
    return false
end
adtype = OPT.AutoZygote()
optf = OPT.OptimizationFunction((x, p) -> loss_univ(x), adtype)
optprob = OPT.OptimizationProblem(optf, θ)
result_univ = OPT.solve(optprob, OPA.PolyOpt(), callback = cb)
retcode: Success
u: ComponentVector{Float64}(u0 = [1.0000000008240493], p_all = (p_model = (layer_1 = (…), layer_2 = (…), layer_3 = (…)), p_system = [-1.2758249735338222e-8, -0.6201105801456387]))
cb(result_univ, result_univ.minimum; makeplot = true)
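Calling the callback with makeplot = true displays the final trajectory; the trained controller should hold the system output near the target of 1.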