Solving Optimal Control Problems with Universal Differential Equations
Here we will solve a classic optimal control problem with a universal differential equation. Let
\[x^{\prime\prime} = u^3(t)\]
where we want to optimize our controller $u(t)$ such that the following is minimized:
\[L(\theta) = \sum_i \Vert 4 - x(t_i) \Vert + 2 \Vert x^\prime(t_i) \Vert + \Vert u(t_i) \Vert\]
where the $t_i$ are sampled on $(0, 8)$ at intervals of 0.01. To do this, we rewrite the ODE in first-order form:
\[\begin{aligned} x^\prime &= v \\ v^\prime &= u^3(t) \end{aligned}\]
and thus
\[L(\theta) = \sum_i \Vert 4 - x(t_i) \Vert + 2 \Vert v(t_i) \Vert + \Vert u(t_i) \Vert\]
is our loss function on the first order system. We thus choose a neural network form for $u$ and optimize the equation with respect to this loss. Note that we will first reduce the control cost (the last term) by 10x in order to bump the network out of a local minimum.
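Concretely, the warm-start objective is just the loss above with the control term weighted by $1/10$ (this restatement mirrors the `/ 10` on the control term in the loss code below):
\[L_{\text{init}}(\theta) = \sum_i \Vert 4 - x(t_i) \Vert + 2 \Vert v(t_i) \Vert + \frac{1}{10} \Vert u(t_i) \Vert\]
In code, this looks like: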
using Lux, ComponentArrays, OrdinaryDiffEq, Optimization, OptimizationNLopt,
OptimizationOptimisers, SciMLSensitivity, Zygote, Plots, Statistics, Random
rng = Random.default_rng()
tspan = (0.0f0, 8.0f0)
# Neural network controller u(t): a small MLP mapping time to the control value
ann = Chain(Dense(1, 32, tanh), Dense(32, 32, tanh), Dense(32, 1))
ps, st = Lux.setup(rng, ann)
p = ComponentArray(ps)
# Flatten the parameters to a plain vector θ and keep the axes so the
# ComponentArray can be rebuilt inside the ODE and loss functions
θ, ax = getdata(p), getaxes(p)
function dxdt_(dx, x, p, t)
    # Rebuild the ComponentArray parameters from the flat vector p
    ps = ComponentArray(p, ax)
    x1, x2 = x
    dx[1] = x2
    dx[2] = first(ann([t], ps, st))[1]^3
end
x0 = [-4.0f0, 0.0f0]
ts = Float32.(collect(0.0:0.01:tspan[2]))
prob = ODEProblem(dxdt_, x0, tspan, θ)
# Sanity check: solve the ODE once with the untrained parameters
solve(prob, Vern9(), abstol = 1e-10, reltol = 1e-10)
function predict_adjoint(θ)
    Array(solve(prob, Vern9(), p = θ, saveat = ts,
        sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP(true))))
end
function loss_adjoint(θ)
    x = predict_adjoint(θ)
    ps = ComponentArray(θ, ax)
    # Control cost (last term) is reduced by 10x to help escape the local minimum
    mean(abs2, 4.0 .- x[1, :]) + 2mean(abs2, x[2, :]) +
        mean(abs2, [first(first(ann([t], ps, st))) for t in ts]) / 10
end
l = loss_adjoint(θ)
cb = function (θ, l; doplot = true)
    println(l)
    ps = ComponentArray(θ, ax)
    if doplot
        p = plot(solve(remake(prob, p = θ), Tsit5(), saveat = 0.01), ylim = (-6, 6), lw = 3)
        plot!(p, ts, [first(first(ann([t], ps, st))) for t in ts], label = "u(t)", lw = 3)
        display(p)
    end
    return false
end
# Display the ODE with the current parameter values.
cb(θ, l)
# Setup and run the optimization
loss1 = loss_adjoint(θ)
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_adjoint(x), adtype)
optprob = Optimization.OptimizationProblem(optf, θ)
res1 = Optimization.solve(optprob, Adam(0.01), callback = cb, maxiters = 100)
optprob2 = Optimization.OptimizationProblem(optf, res1.u)
res2 = Optimization.solve(optprob2, NLopt.LD_LBFGS(), callback = cb, maxiters = 100)
(Output: retcode: MaxIters; u: 1153-element Vector{Float64}, values omitted.)
Now that the system is in a better behaved part of parameter space, we return to the original loss function to finish the optimization:
function loss_adjoint(θ)
    x = predict_adjoint(θ)
    ps = ComponentArray(θ, ax)
    # The control cost is now back at its full weight
    mean(abs2, 4.0 .- x[1, :]) + 2mean(abs2, x[2, :]) +
        mean(abs2, [first(first(ann([t], ps, st))) for t in ts])
end
optf3 = Optimization.OptimizationFunction((x, p) -> loss_adjoint(x), adtype)
optprob3 = Optimization.OptimizationProblem(optf3, res2.u)
res3 = Optimization.solve(optprob3, NLopt.LD_LBFGS(), maxiters = 100)
(Output: retcode: MaxIters; u: 1153-element Vector{Float64}, values omitted.)
Now let's see what we received:
l = loss_adjoint(res3.u)
cb(res3.u, l)
p = plot(solve(remake(prob, p = res3.u), Tsit5(), saveat = 0.01), ylim = (-6, 6), lw = 3)
plot!(p, ts, [first(first(ann([t], ComponentArray(res3.u, ax), st))) for t in ts], label = "u(t)", lw = 3)
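If you want to use the optimized controller outside of the plotting code, it can be wrapped as an ordinary function of time. This is a minimal sketch reusing the `ann`, `res3.u`, `ax`, and `st` defined above; `u_trained` is just an illustrative name, not part of the tutorial's API:
# Hypothetical helper: evaluate the trained controller u(t) at an arbitrary time t
u_trained(t) = first(ann([Float32(t)], ComponentArray(res3.u, ax), st))[1]
u_trained(4.0f0)   # control value at t = 4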