Solving Optimal Control Problems with Universal Differential Equations

Here we will solve a classic optimal control problem with a universal differential equation. Let

\[x^{\prime\prime} = u^3(t)\]

where we want to choose the controller $u(t)$ so that the following loss is minimized:

\[L(\theta) = \sum_i \Vert 4 - x(t_i) \Vert^2 + 2 \Vert x^\prime(t_i) \Vert^2 + \Vert u(t_i) \Vert^2\]

where the sample times $t_i$ run from 0 to 8 in steps of 0.01. To solve this, we rewrite the ODE in first-order form:

\[\begin{aligned} x^\prime &= v \\ v^\prime &= u^3(t) \end{aligned}\]

and thus

\[L(\theta) = \sum_i \Vert 4 - x(t_i) \Vert^2 + 2 \Vert v(t_i) \Vert^2 + \Vert u(t_i) \Vert^2\]

is our loss function on the first-order system. (In the code, each term is averaged over the $t_i$ rather than summed, which only rescales $L$.) We choose a neural network for $u$ and optimize its parameters $\theta$ with respect to this loss. Note that we first down-weight the control cost (the last term) by a factor of 10 in order to bump the network out of a local minimum. This looks like:

```julia
using Lux, ComponentArrays, OrdinaryDiffEq, Optimization, OptimizationNLopt,
    OptimizationOptimisers, SciMLSensitivity, Zygote, Plots, Statistics, Random

rng = Random.default_rng()
tspan = (0.0f0, 8.0f0)

# Neural-network controller u(t): scalar time in, scalar control out
ann = Chain(Dense(1, 32, tanh), Dense(32, 32, tanh), Dense(32, 1))
ps, st = Lux.setup(rng, ann)
p = ComponentArray(ps)

# Flat parameter vector θ plus the axes needed to rebuild the structured parameters
θ, ax = getdata(p), getaxes(p)

# First-order form of x'' = u(t)^3: x' = v, v' = u(t)^3 (state x = [x, v])
function dxdt_(dx, x, p, t)
    ps = ComponentArray(p, ax)
    dx[1] = x[2]
    dx[2] = first(ann([t], ps, st))[1]^3
    return nothing
end
x0 = [-4.0f0, 0.0f0]
ts = Float32.(collect(0.0:0.01:tspan[2]))
prob = ODEProblem(dxdt_, x0, tspan, θ)
# Solve once with the initial (untrained) controller
solve(prob, Vern9(), abstol = 1e-10, reltol = 1e-10)

# Solve at the sample times ts; gradients use an interpolating adjoint
# with ReverseDiff-based vector-Jacobian products
function predict_adjoint(θ)
    Array(solve(prob, Vern9(), p = θ, saveat = ts,
        sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP(true))))
end

# First-stage loss: the control cost (last term) is divided by 10
function loss_adjoint(θ)
    x = predict_adjoint(θ)
    ps = ComponentArray(θ, ax)
    mean(abs2, 4.0 .- x[1, :]) + 2mean(abs2, x[2, :]) +
    mean(abs2, [first(first(ann([t], ps, st))) for t in ts]) / 10
end

l = loss_adjoint(θ)

# Callback: print the loss and optionally plot the states and the control.
# Returning false tells the optimizer to keep going.
cb = function (θ, l; doplot = true)
    println(l)

    ps = ComponentArray(θ, ax)

    if doplot
        p = plot(solve(remake(prob, p = θ), Tsit5(), saveat = 0.01), ylim = (-6, 6), lw = 3)
        plot!(p, ts, [first(first(ann([t], ps, st))) for t in ts], label = "u(t)", lw = 3)
        display(p)
    end

    return false
end

# Display the ODE with the current parameter values.

cb(θ, l)

# Setup and run the optimization: Adam first, then L-BFGS started from the Adam result

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_adjoint(x), adtype)

optprob = Optimization.OptimizationProblem(optf, θ)
res1 = Optimization.solve(optprob, Adam(0.01), callback = cb, maxiters = 100)

optprob2 = Optimization.OptimizationProblem(optf, res1.u)
res2 = Optimization.solve(optprob2, NLopt.LD_LBFGS(), callback = cb, maxiters = 100)
```
```
retcode: MaxIters
u: 1153-element Vector{Float64}:
 -0.5497451078433414
  1.939771271114207
 -1.1428662806896397
 -0.46461705578251444
  0.018604964200195768
 -0.06483718018380129
  0.6692363442889305
  1.6587963476448941
  0.2543632975156579
  0.6822867847531855
  ⋮
 -0.5477243683061416
 -0.11166078392052617
 -0.14050595249818215
  0.24136733609155925
  0.34911939286429017
 -0.57003483623194
 -0.8014240960816431
 -0.6510411688832113
  1.6127527382782405
```
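The first-stage `loss_adjoint` above divides the control term by 10, and the second stage restores its full weight. To make that weighting explicit, here is a package-free sketch of the same three-term loss with the control weight as a parameter. It assumes the position, velocity and control have already been sampled at the times `ts`, as `predict_adjoint` and the `ann` comprehension do in the tutorial. `weighted_loss` and the keyword `λ` are hypothetical names introduced for illustration. `λ = 0.1` corresponds to the first stage and `λ = 1` to the original loss.

```julia
using Statistics: mean

# Hypothetical helper (not part of the tutorial or any library):
# the three-term loss with an explicit weight λ on the control cost.
# x, v: sampled position and velocity; u: sampled control values.
function weighted_loss(x::AbstractVector, v::AbstractVector, u::AbstractVector; λ = 1.0)
    tracking = mean(abs2, 4 .- x)   # drive x(t_i) toward 4
    velocity = 2 * mean(abs2, v)    # penalize velocity, weight 2
    effort   = λ * mean(abs2, u)    # penalize control effort, weight λ
    return tracking + velocity + effort
end

# Made-up samples, only to show the effect of λ
x = [3.0, 4.0, 5.0]
v = [1.0, 0.0, -1.0]
u = [2.0, 0.0, -2.0]

stage1 = weighted_loss(x, v, u; λ = 0.1)  # control cost down-weighted 10x
stage2 = weighted_loss(x, v, u; λ = 1.0)  # original loss
```

With these made-up samples, the tracking and velocity terms contribute 2/3 and 4/3, and the control term contributes 8/3 at full weight. Down-weighting it therefore lowers the loss from about 4.67 to about 2.27. A smaller λ simply makes large controls cheaper during the first stage.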

Now that the parameters are in a better-behaved region of parameter space, we return to the original loss function, with the control cost at full weight, to finish the optimization:

```julia
# Original loss: the control cost (last term) at full weight
function loss_adjoint(θ)
    x = predict_adjoint(θ)
    ps = ComponentArray(θ, ax)
    mean(abs2, 4.0 .- x[1, :]) + 2mean(abs2, x[2, :]) +
    mean(abs2, [first(first(ann([t], ps, st))) for t in ts])
end
optf3 = Optimization.OptimizationFunction((x, p) -> loss_adjoint(x), adtype)

# Warm-start from the first-stage result res2.u
optprob3 = Optimization.OptimizationProblem(optf3, res2.u)
res3 = Optimization.solve(optprob3, NLopt.LD_LBFGS(), maxiters = 100)
```
```
retcode: MaxIters
u: 1153-element Vector{Float64}:
 -0.44930910753303144
  1.9749811870864644
 -1.1949401324690636
 -0.4901100174331857
 -0.17388422749027446
 -0.126542524224296
  0.705274084158971
  1.693234508188875
  0.09566709355679578
  0.649159149894226
  ⋮
 -0.5844630181964663
 -0.1586750666391161
 -0.17832898246079737
  0.232165724076851
  0.48761806385585793
 -0.5930866916941891
 -0.8254407591743629
 -0.6584931473041558
  1.635285274244193
```
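The two-stage scheme is a simple form of continuation. You first optimize with a down-weighted penalty, then warm-start the original problem from that solution, just as `res2.u` seeds `optprob3` above. The sketch below shows only those mechanics. It uses plain gradient descent on a hypothetical one-parameter objective $f_\lambda(\theta) = (\theta - 4)^2 + \lambda \theta^2$, whose minimizer is $\theta^\ast = 4/(1 + \lambda)$. Unlike the tutorial's problem, this toy is convex, so there is no local minimum to escape. The helper names are illustrative only.

```julia
# Gradient of the hypothetical toy objective f_λ(θ) = (θ - 4)^2 + λ*θ^2
toy_loss_grad(θ, λ) = 2 * (θ - 4) + 2 * λ * θ

# Plain gradient descent with a fixed step size, started from θ0 (the warm start)
function gradient_descent(θ0, λ; lr = 0.1, iters = 200)
    θ = θ0
    for _ in 1:iters
        θ -= lr * toy_loss_grad(θ, λ)
    end
    return θ
end

θ0 = 0.0
θ_stage1 = gradient_descent(θ0, 0.1)        # stage 1: penalty down-weighted 10x
θ_stage2 = gradient_descent(θ_stage1, 1.0)  # stage 2: original weight, warm-started
```

Here the first stage converges to about 3.64 (that is, 4/1.1), and the warm-started second stage then moves to 2.0. In the tutorial, the same pattern hands the first-stage parameters to a fresh `OptimizationProblem` built from the full-weight loss.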

Now let's look at the result:

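```julia
l = loss_adjoint(res3.u)
cb(res3.u, l)  # print the final loss and plot the solution
# Plot the optimized states together with the learned control u(t)
p = plot(solve(remake(prob, p = res3.u), Tsit5(), saveat = 0.01), ylim = (-6, 6), lw = 3)
plot!(p, ts, [first(first(ann([t], ComponentArray(res3.u, ax), st))) for t in ts], label = "u(t)", lw = 3)
```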
[Plot output: the optimized states and the learned control u(t)]
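The tutorial evaluates a controller by solving the first-order system $x' = v$, $v' = u^3(t)$ with the adaptive `Vern9` solver. As a package-free illustration of that same first-order rewrite, the sketch below uses a classical fixed-step fourth-order Runge-Kutta (RK4) integrator instead. `rk4_rollout` is a hypothetical helper. A constant control $u(t) = 1$ stands in for the trained network because its exact solution is known: $v(t) = t$ and $x(t) = -4 + t^2/2$.

```julia
# Hypothetical helper: integrate the first-order system
#   x' = v,  v' = u(t)^3
# with classical RK4 on n fixed steps over [0, tend].
function rk4_rollout(u, x0, v0; tend = 8.0, n = 800)
    h = tend / n
    tgrid = [k * h for k in 0:n]
    xs = zeros(n + 1)
    vs = zeros(n + 1)
    xs[1], vs[1] = x0, v0
    rhs(t, x, v) = (v, u(t)^3)  # (x', v') of the first-order system
    for k in 1:n
        t, x, v = tgrid[k], xs[k], vs[k]
        k1 = rhs(t, x, v)
        k2 = rhs(t + h / 2, x + h / 2 * k1[1], v + h / 2 * k1[2])
        k3 = rhs(t + h / 2, x + h / 2 * k2[1], v + h / 2 * k2[2])
        k4 = rhs(t + h, x + h * k3[1], v + h * k3[2])
        xs[k + 1] = x + h / 6 * (k1[1] + 2 * k2[1] + 2 * k3[1] + k4[1])
        vs[k + 1] = v + h / 6 * (k1[2] + 2 * k2[2] + 2 * k3[2] + k4[2])
    end
    return tgrid, xs, vs
end

# Constant stand-in control u(t) = 1, same initial state as the tutorial (x = -4, v = 0)
tgrid, xs, vs = rk4_rollout(t -> 1.0, -4.0, 0.0)
xs[end], vs[end]  # exact values are x(8) = 28 and v(8) = 8, up to rounding
```

To roll out the trained controller instead, one could pass `t -> first(first(ann([t], ComponentArray(res3.u, ax), st)))`, which is the same expression the tutorial uses to sample $u(t)$. For a nonlinear network, a fixed step trades accuracy for simplicity; the adaptive solver avoids that trade-off by controlling the error.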