Delay Differential Equations

Other differential equation problem types from DifferentialEquations.jl are supported as well. For example, we can build a layer with a delay differential equation as follows:

using OrdinaryDiffEq, Optimization, SciMLSensitivity, OptimizationPolyalgorithms,
      DelayDiffEq

# Lotka-Volterra (LV) equations modified with a delay: the first component's
# rate is scaled by the first state component 0.1 time units in the past,
# looked up through `h`.
#
# Arguments: `du` (derivative, written in place), `u` (current state (x, y)),
# `h` (history/lookup function called as `h(p, t)`), `p` (parameters
# unpacked as α, β, δ, γ), `t` (current time). Returns `nothing`.
function delay_lotka_volterra!(du, u, h, p, t)
    x, y = u
    α, β, δ, γ = p
    # The 0.1 lag must agree with `constant_lags` given to the DDEProblem.
    du[1] = (α - β * y) * h(p, t - 0.1)[1]
    du[2] = (δ * x - γ) * y
    return nothing
end

# Starting guess for the four model parameters, which
# `delay_lotka_volterra!` unpacks as (α, β, δ, γ)
p = Float64[2.2, 1.0, 2.0, 0.4]

# History function: the state values the DDE solver uses for times before the
# start of the integration window (t = 0). It ignores `t` and returns the
# constant state [1, 1], matching `u0`. Only its first component is read by
# `delay_lotka_volterra!`. Building it with `eltype(p)` keeps the history in
# the same number type as the parameters, which matters when `p` carries AD
# number types during differentiation.
h(p, t) = ones(eltype(p), 2)

# Initial state (x, y) at t = 0; equal to the constant value returned by `h`
u0 = [1.0, 1.0]

# Delay differential equation problem on t ∈ [0, 10]. The single constant lag
# declares the same 0.1 delay that `delay_lotka_volterra!` reads through
# `h(p, t - 0.1)`. No parameters are stored here; they are passed at solve time.
prob_dde = DDEProblem(delay_lotka_volterra!, u0, h, (0.0, 10.0);
    constant_lags = [0.1])

"""
    predict_dde(p)

Solve `prob_dde` from the global initial state `u0` with parameters `p`, using
`MethodOfSteps(Tsit5())` and saving every 0.1 time units. The solution is
converted to an `Array`. `ReverseDiffAdjoint()` is requested as the
sensitivity algorithm so that gradients flow through the solve.
"""
function predict_dde(p)
    stepper = MethodOfSteps(Tsit5())
    sol = solve(prob_dde, stepper; u0 = u0, p = p, saveat = 0.1,
        sensealg = ReverseDiffAdjoint())
    return Array(sol)
end

"""
    loss_dde(p)

Sum of squared deviations from 1 over every entry of `predict_dde(p)`.
"""
function loss_dde(p)
    prediction = predict_dde(p)
    return sum(abs2, x - 1 for x in prediction)
end

using Plots
# Optimization callback: display the current loss `l` and, when
# `doplot = true`, plot the solution of `prob_dde` at the current parameters
# `state.u`. Returning `false` tells the optimizer not to halt.
callback = function (state, l; doplot = false)
    # `l` is the loss the optimizer already evaluated at `state.u`; reusing it
    # avoids re-solving the DDE just to print the loss.
    display(l)
    if doplot
        sol = solve(remake(prob_dde, p = state.u), MethodOfSteps(Tsit5()), saveat = 0.1)
        display(plot(sol, ylim = (0, 6)))
    end
    return false
end

# Differentiate the loss with Zygote. The objective's second argument
# (the optimizer's extra parameters) is not needed, so it is ignored.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((θ, _) -> loss_dde(θ), adtype)
# Start the optimization from the initial parameter guess `p`.
optprob = Optimization.OptimizationProblem(optf, p)
# Run the PolyOpt polyalgorithm with an iteration cap of 300.
result_dde = Optimization.solve(optprob, PolyOpt(); maxiters = 300, callback = callback)
retcode: Success
u: 4-element Vector{Float64}:
 1.5247375542690187
 1.5247375542644221
 0.7311657359742096
 0.7311657359760502

Notice that we chose `sensealg = ReverseDiffAdjoint()` so that ReverseDiff.jl's reverse-mode automatic differentiation is used to differentiate through the delay differential equation solve.

We define a callback that displays the loss at each training step and, when called with `doplot = true`, also plots the solution at the current parameters:

using Plots
# Optimization callback: display the current loss `l` and, when
# `doplot = true`, plot the solution of `prob_dde` at the current parameters
# `state.u`. Returning `false` tells the optimizer not to halt.
callback = function (state, l; doplot = false)
    # `l` is the loss the optimizer already evaluated at `state.u`; reusing it
    # avoids re-solving the DDE just to print the loss.
    display(l)
    if doplot
        sol = solve(remake(prob_dde, p = state.u), MethodOfSteps(Tsit5()), saveat = 0.1)
        display(plot(sol, ylim = (0, 6)))
    end
    return false
end
#8 (generic function with 1 method)

We use `Optimization.solve` with the `PolyOpt()` polyalgorithm to optimize the parameters of our loss function:

# Differentiate the loss with Zygote; the objective ignores its second
# (optimizer parameter) argument.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_dde(x), adtype)
optprob = Optimization.OptimizationProblem(optf, p)
# Same iteration cap as the complete listing, so both versions of this
# tutorial run the PolyOpt polyalgorithm with the same budget.
result_dde = Optimization.solve(optprob, PolyOpt(), maxiters = 300, callback = callback)
retcode: Success
u: 4-element Vector{Float64}:
 1.5247375542690187
 1.5247375542644221
 0.7311657359742096
 0.7311657359760502