Neural Ordinary Differential Equations

A neural ODE is an ODE whose derivative function is defined by a neural network. For example, with the multilayer perceptron Lux.Chain(Lux.Dense(2, 50, tanh), Lux.Dense(50, 2)), we can define the differential equation u' = NN(u). This is done simply via the NeuralODE struct. Let's take a look at an example.
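
As a quick sketch of the idea (the time span, initial condition, and solver here are chosen purely for illustration; the full, runnable tutorial code follows below):

using ComponentArrays, Lux, DiffEqFlux, OrdinaryDiffEq, Random

nn = Lux.Chain(Lux.Dense(2, 50, tanh), Lux.Dense(50, 2))  # NN(u)
p, st = Lux.setup(Xoshiro(0), nn)                          # parameters and (empty) state
node = NeuralODE(nn, (0.0f0, 1.0f0), Tsit5())              # the ODE u' = NN(u)
sol, _ = node(Float32[2.0; 0.0], ComponentArray(p), st)    # solve from an initial condition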

Copy-Pasteable Code

Before getting to the explanation, here's some code to start with. A full explanation of the definition and training process follows:

using ComponentArrays, Lux, DiffEqFlux, OrdinaryDiffEq, Optimization, OptimizationOptimJL,
      OptimizationOptimisers, Random, Plots

rng = Xoshiro(0)
u0 = Float32[2.0; 0.0]
datasize = 30
tspan = (0.0f0, 1.5f0)
tsteps = range(tspan[1], tspan[2]; length = datasize)

function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u .^ 3)'true_A)'
end

prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(solve(prob_trueode, Tsit5(); saveat = tsteps))

dudt2 = Chain(x -> x .^ 3, Dense(2, 50, tanh), Dense(50, 2))
p, st = Lux.setup(rng, dudt2)
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(); saveat = tsteps)

function predict_neuralode(p)
    Array(prob_neuralode(u0, p, st)[1])
end

function loss_neuralode(p)
    pred = predict_neuralode(p)
    loss = sum(abs2, ode_data .- pred)
    return loss
end

# Do not plot by default for the documentation
# Users should change doplot=true to see the plots from the callback
function callback(state, l; doplot = false)
    println(l)
    # plot current prediction against data
    if doplot
        pred = predict_neuralode(state.u)
        plt = scatter(tsteps, ode_data[1, :]; label = "data")
        scatter!(plt, tsteps, pred[1, :]; label = "prediction")
        display(plot(plt))
    end
    return false
end

pinit = ComponentArray(p)
callback((; u = pinit), loss_neuralode(pinit); doplot = true)

# use Optimization.jl to solve the problem
adtype = Optimization.AutoZygote()

optf = Optimization.OptimizationFunction((x, p) -> loss_neuralode(x), adtype)
optprob = Optimization.OptimizationProblem(optf, pinit)

result_neuralode = Optimization.solve(
    optprob, OptimizationOptimisers.Adam(0.05); callback = callback, maxiters = 300)

optprob2 = remake(optprob; u0 = result_neuralode.u)

result_neuralode2 = Optimization.solve(
    optprob2, Optim.BFGS(; initial_stepnorm = 0.01); callback, allow_f_increases = false)

callback((; u = result_neuralode2.u), loss_neuralode(result_neuralode2.u); doplot = true)
false

Plot: Neural ODE

Explanation

Let's get a time series array from a spiral ODE to train against.

using ComponentArrays, Lux, DiffEqFlux, OrdinaryDiffEq, Optimization, OptimizationOptimJL,
      OptimizationOptimisers, Random, Plots

rng = Xoshiro(0)
u0 = Float32[2.0; 0.0]
datasize = 30
tspan = (0.0f0, 1.5f0)
tsteps = range(tspan[1], tspan[2]; length = datasize)

function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u .^ 3)'true_A)'
end

prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(solve(prob_trueode, Tsit5(); saveat = tsteps))
2×30 Matrix{Float32}:
 2.0  1.9465    1.74178  1.23837  0.577127  …  1.40688   1.37023   1.29214
 0.0  0.798832  1.46473  1.80877  1.86465      0.451377  0.728699  0.972102

Now let's define a neural network with the NeuralODE layer. First, we define the network that will serve as the derivative function. Here we're going to use Lux.Chain, which is a suitable neural network structure for NeuralODE because Lux handles the parameters and state separately from the model:

dudt2 = Chain(x -> x .^ 3, Dense(2, 50, tanh), Dense(50, 2))
p, st = Lux.setup(rng, dudt2)
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(); saveat = tsteps)
NeuralODE(
    model = Chain(
        layer_1 = WrappedFunction(#1),
        layer_2 = Dense(2 => 50, tanh),  # 150 parameters
        layer_3 = Dense(50 => 2),       # 102 parameters
    ),
)         # Total: 252 parameters,
          #        plus 0 states.

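Since Lux handles parameters and state explicitly, the network itself is called as model(x, p, st) and returns the output together with the (here empty) state. As a quick sanity-check sketch, we can evaluate the derivative network directly at the initial condition:

y, _ = dudt2(u0, p, st)  # apply the Chain directly; returns (output, new_state)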

In our model, we incorporated the x -> x .^ 3 assumption directly into the network. By incorporating structure into our equations, we can reduce the required size and training time for the neural network, but a good guess needs to be known!
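
For comparison, a purely black-box derivative network would simply drop the x -> x .^ 3 term. The sketch below (variable names are hypothetical) defines such a network, which would have to learn the cubic nonlinearity on its own:

dudt_blackbox = Chain(Dense(2, 50, tanh), Dense(50, 2))  # no structural prior
p_bb, st_bb = Lux.setup(rng, dudt_blackbox)
prob_blackbox = NeuralODE(dudt_blackbox, tspan, Tsit5(); saveat = tsteps)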

From here, we build a loss function around the NeuralODE prob_neuralode defined above. The NeuralODE is called with the parameters as its second argument, which we will use to change the neural network iteratively in our training loop. We will use the L2 loss of the network's output against the time series data:

function predict_neuralode(p)
    Array(prob_neuralode(u0, p, st)[1])
end

function loss_neuralode(p)
    pred = predict_neuralode(p)
    loss = sum(abs2, ode_data .- pred)
    return loss
end
loss_neuralode (generic function with 1 method)

We define a callback function. In this example, we default to doplot=false because otherwise it would produce a plot at every step and overflow the documentation, but for your own use set doplot=true to see a live animation of the training process!

# Callback function to observe training
callback = function (state, l; doplot = false)
    println(l)
    # plot current prediction against data
    if doplot
        pred = predict_neuralode(state.u)
        plt = scatter(tsteps, ode_data[1, :]; label = "data")
        scatter!(plt, tsteps, pred[1, :]; label = "prediction")
        display(plot(plt))
    end
    return false
end

pinit = ComponentArray(p)
callback((; u = pinit), loss_neuralode(pinit))
false

We then train the neural network to learn the ODE.

Here we showcase starting the optimization with Adam to quickly find a minimum, and then honing in on that minimum by using BFGS. By using the two together, we can fit the neural ODE in 9 seconds! (Note: this timing was measured with the plotting commented out.) You can easily adapt the procedure below to set up custom optimization problems. For more information on the usage of Optimization.jl, please consult this documentation.

The x and p variables in the optimization function are different from x and p above. The optimization function runs over the space of parameters of the original problem, so x_optimization == p_original.

# Train using the Adam optimizer
adtype = Optimization.AutoZygote()

optf = Optimization.OptimizationFunction((x, p) -> loss_neuralode(x), adtype)
optprob = Optimization.OptimizationProblem(optf, pinit)

result_neuralode = Optimization.solve(
    optprob, OptimizationOptimisers.Adam(0.05); callback = callback, maxiters = 300)
retcode: Default
u: ComponentVector{Float32}(layer_1 = Float32[], layer_2 = (weight = Float32[-2.025178 1.4889137; -0.48845485 -0.22329545; … ; -0.20798619 0.0036160722; 0.54904604 -0.86559695], bias = Float32[-0.6270363, -0.820645, -0.5007736, 1.4128623, -0.39423355, 0.83272225, 0.82724094, 0.0027166214, 0.9395437, 0.7458477  …  0.23887838, -0.95385593, -1.5580357, 0.54830873, -0.86555207, 0.7341561, -0.086343355, 0.85364974, 0.28163964, 0.4055838]), layer_3 = (weight = Float32[-0.25172383 0.7503794 … -0.07587621 0.61020863; -0.038686562 -0.071954705 … -0.11840795 -0.03603409], bias = Float32[-0.76250345, 0.19005477]))

We then complete the training using a different optimizer, starting from where Adam stopped. We set allow_f_increases = false to make the optimization automatically halt when near the minimum.

# Retrain using the BFGS optimizer
optprob2 = remake(optprob; u0 = result_neuralode.u)

result_neuralode2 = Optimization.solve(optprob2, Optim.BFGS(; initial_stepnorm = 0.01);
    callback = callback, allow_f_increases = false)
retcode: Success
u: ComponentVector{Float32}(layer_1 = Float32[], layer_2 = (weight = Float32[-2.020806 1.4950883; -0.46105754 -0.26112327; … ; -0.20876782 -0.01732303; 0.5120535 -0.8888397], bias = Float32[-0.6287585, -0.8229726, -0.5061863, 1.434119, -0.43867564, 0.8668595, 0.86666393, -0.0001465012, 0.98262066, 0.7479757  …  0.2578901, -0.9798032, -1.6046683, 0.54930216, -0.9088959, 0.7315262, -0.06679955, 0.878536, 0.2872703, 0.42273775]), layer_3 = (weight = Float32[-0.24912125 0.7675786 … -0.072181456 0.6391928; -0.014506301 -0.06052106 … -0.18200254 -0.048533753], bias = Float32[-0.73121375, 0.34071696]))

And then we use the callback with doplot=true to see the final plot:

callback((; u = result_neuralode2.u), loss_neuralode(result_neuralode2.u); doplot = true)
Plot: final prediction against the data