Neural ODEs on GPUs

Note that the differential equation solvers run on the GPU whenever the initial condition is a GPU array; if no GPU is available, the calculation falls back to the CPU. Thus, for example, we can define a neural ODE by hand that runs on the GPU.
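As a minimal sketch of that mechanism (not part of the original tutorial; it assumes CUDA.jl is installed and a CUDA-capable GPU is present), even a plain ODE is solved on the GPU once its initial condition is a GPU array:

using OrdinaryDiffEq, CUDA

u0_gpu = cu(Float32[1.0, 2.0])   # initial condition lives on the GPU
decay(u, p, t) = -0.5f0 .* u     # broadcasted right-hand side works on GPU arrays
sol = solve(ODEProblem(decay, u0_gpu, (0.0f0, 1.0f0)), Tsit5())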

For a detailed discussion of how GPUs need to be set up, refer to the Lux Docs.

using OrdinaryDiffEq, Lux, LuxCUDA, SciMLSensitivity, ComponentArrays, Random
rng = Xoshiro(0)

const cdev = cpu_device()
const gdev = gpu_device()

model = Chain(Dense(2, 50, tanh), Dense(50, 2))
ps, st = Lux.setup(rng, model)
ps = ps |> ComponentArray |> gdev  # flatten the parameters and move them to the GPU
st = st |> gdev                    # move the layer states to the GPU
dudt(u, p, t) = model(u, p, st)[1] # Lux models return (output, state); keep only the output

# Simulation interval and intermediary points
tspan = (0.0f0, 10.0f0)
tsteps = 0.0f0:1.0f-1:10.0f0

u0 = Float32[2.0; 0.0] |> gdev
prob_gpu = ODEProblem(dudt, u0, tspan, ps)

# Runs on a GPU
sol_gpu = solve(prob_gpu, Tsit5(); saveat = tsteps)
retcode: Success
Interpolation: 1st order linear
t: 101-element Vector{Float32}:
  0.0
  0.1
  0.2
  0.3
  0.4
  0.5
  0.6
  0.7
  0.8
  0.9
  ⋮
  9.2
  9.3
  9.4
  9.5
  9.6
  9.7
  9.8
  9.9
 10.0
u: 101-element Vector{Vector{Float32}}:
 [2.0, 0.0]
 [1.7873144, 0.031452533]
 [1.5787977, 0.06454542]
 [1.3766682, 0.09886328]
 [1.1836987, 0.13350393]
 [1.0028385, 0.16698295]
 [0.83674806, 0.19752903]
 [0.6875931, 0.22379515]
 [0.556868, 0.245554]
 [0.44503433, 0.2636885]
 ⋮
 [1.8024267, 12.045058]
 [1.8131173, 12.216458]
 [1.8229246, 12.387675]
 [1.8318273, 12.558733]
 [1.8398049, 12.7296505]
 [1.8468374, 12.900447]
 [1.8529061, 13.07115]
 [1.8579924, 13.241777]
 [1.8620787, 13.412361]
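
The saved states in sol_gpu.u are GPU arrays. As a small follow-up sketch (typical post-processing, not shown in the original), they can be brought back to the CPU with the cdev device defined above:

u_cpu = sol_gpu.u .|> cdev   # each saved state transferred back to the CPU
traj = reduce(hcat, u_cpu)   # 2 × 101 Matrix{Float32} for plotting or analysis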

Alternatively, we can use the NeuralODE layer function directly:

using DiffEqFlux: NeuralODE
prob_neuralode_gpu = NeuralODE(model, tspan, Tsit5(); saveat = tsteps)
NeuralODE(
    model = Chain(
        layer_1 = Dense(2 => 50, tanh),           # 150 parameters
        layer_2 = Dense(50 => 2),                 # 102 parameters
    ),
)         # Total: 252 parameters,
          #        plus 0 states.
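
The resulting layer is called like any other Lux layer with parameters and state; reusing the ps and st from the Lux.setup call above (a usage sketch under that assumption), it returns the ODE solution together with the layer state:

sol_node, st_node = prob_neuralode_gpu(u0, ps, st)  # solves on the GPU since u0, ps, and st are there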

If one is using a Lux.Chain, then the computation f(x, p, st) takes place on the GPU whenever x, p, and st are all on the GPU. This commonly looks like:

dudt2 = Chain(x -> x .^ 3, Dense(2, 50, tanh), Dense(50, 2))

u0 = Float32[2.0; 0.0] |> gdev
p, st = Lux.setup(rng, dudt2) |> gdev     # parameters and states both moved to the GPU

dudt2_(u, p, t) = first(dudt2(u, p, st))  # take the model output, discard the returned state

# Simulation interval and intermediary points
tspan = (0.0f0, 10.0f0)
tsteps = 0.0f0:1.0f-1:10.0f0

prob_gpu = ODEProblem(dudt2_, u0, tspan, p)

# Runs on a GPU
sol_gpu = solve(prob_gpu, Tsit5(); saveat = tsteps)
retcode: Success
Interpolation: 1st order linear
t: 101-element Vector{Float32}:
  0.0
  0.1
  0.2
  0.3
  0.4
  0.5
  0.6
  0.7
  0.8
  0.9
  ⋮
  9.2
  9.3
  9.4
  9.5
  9.6
  9.7
  9.8
  9.9
 10.0
u: 101-element Vector{Vector{Float32}}:
 [2.0, 0.0]
 [2.0003986, 0.08488123]
 [2.0007863, 0.16975328]
 [2.0011175, 0.25462347]
 [2.001314, 0.33951008]
 [2.0012627, 0.4244526]
 [2.0008123, 0.5095264]
 [1.9997551, 0.59488827]
 [1.9977883, 0.6808464]
 [1.9945598, 0.76779586]
 ⋮
 [-0.4119342, 0.3561046]
 [-0.4301134, 0.38099936]
 [-0.44993347, 0.4038766]
 [-0.4713037, 0.4245298]
 [-0.49403465, 0.44280446]
 [-0.51783746, 0.45859838]
 [-0.5423251, 0.47186184]
 [-0.5670111, 0.48259723]
 [-0.59131217, 0.49086034]

Or via the NeuralODE struct:

prob_neuralode_gpu = NeuralODE(dudt2, tspan, Tsit5(); saveat = tsteps)
prob_neuralode_gpu(u0, p, st)
(retcode: Success
 t: 0.0f0:0.1f0:10.0f0, 101 points
 u: 101 saved states, matching the trajectory shown above,
 (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()))
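
The call returns a tuple of the ODESolution and the layer state. A short sketch (assumed usage, mirroring the predict function in the full example below) of extracting the trajectory as a matrix:

sol_node, _ = prob_neuralode_gpu(u0, p, st)
pred = reduce(hcat, sol_node.u)   # 2 × 101 GPU matrix of the saved states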

Neural ODE Example

Below is the full neural ODE example. Note that we use the gpu_device function so that the same code works on CPUs and GPUs, depending on whether LuxCUDA is loaded and a GPU is functional.
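As a small sketch of that fallback (the behavior of MLDataDevices, which Lux uses for device selection; not part of the original example):

dev = gpu_device()               # a CUDA device when LuxCUDA is loaded and a GPU is functional,
                                 # otherwise it falls back to the CPU device
x = rand(Float32, 2, 16) |> dev  # the same code runs unchanged on either device

With the device selected this way, the full example follows: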

using Lux, Optimization, OptimizationOptimisers, Zygote, OrdinaryDiffEq, Plots, LuxCUDA,
      SciMLSensitivity, Random, ComponentArrays
import DiffEqFlux: NeuralODE
const cdev = cpu_device()
const gdev = gpu_device()

CUDA.allowscalar(false) # Error instead of silently falling back to slow scalar GPU operations

# RNG for Lux.setup
rng = Xoshiro(0)
# Generate Data
u0 = Float32[2.0; 0.0]
datasize = 30
tspan = (0.0f0, 1.5f0)
tsteps = range(tspan[1], tspan[2]; length = datasize)
function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u .^ 3)'true_A)'
end
prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
# Make the data into a GPU-based array if the user has a GPU
ode_data = solve(prob_trueode, Tsit5(); saveat = tsteps)
ode_data = Array(ode_data) |> gdev

dudt2 = Chain(x -> x .^ 3, Dense(2, 50, tanh), Dense(50, 2))
u0 = Float32[2.0; 0.0] |> gdev
p, st = Lux.setup(rng, dudt2)
p = p |> ComponentArray |> gdev
st = st |> gdev

prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(); saveat = tsteps)

predict_neuralode(p) = reduce(hcat, first(prob_neuralode(u0, p, st)).u)
function loss_neuralode(p)
    pred = predict_neuralode(p)
    loss = sum(abs2, ode_data .- pred)
    return loss
end
# Callback function to observe training
list_plots = []
iter = 0
callback = function (state, l; doplot = false)
    p = state.u
    global list_plots, iter
    pred = predict_neuralode(p)
    if iter == 0
        list_plots = []
    end
    iter += 1
    display(l)
    # plot current prediction against data
    plt = scatter(tsteps, Array(ode_data[1, :]); label = "data")
    scatter!(plt, tsteps, Array(pred[1, :]); label = "prediction")
    push!(list_plots, plt)
    if doplot
        display(plot(plt))
    end
    return false
end

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_neuralode(x), adtype)
optprob = Optimization.OptimizationProblem(optf, p)
result_neuralode = Optimization.solve(
    optprob, OptimizationOptimisers.Adam(0.05); callback, maxiters = 300)
retcode: Default
u: ComponentVector{Float32}(layer_1 = Float32[], layer_2 = (weight = Float32[-1.9956415 1.505737; -0.47191903 -0.25449324; … ; -0.20570713 0.01822312; 0.56123346 -0.85311043], bias = Float32[-0.6303439, -0.77096194, -0.5568976, 1.4353292, -0.37929276, 0.84450424, 0.8785953, -0.02970068, 0.96417844, 0.7853037  …  0.24361013, -1.0012139, -1.4132881, 0.5324092, -0.8246897, 0.66072077, -0.10391649, 1.0552049, 0.24918991, 0.4061778]), layer_3 = (weight = Float32[-0.2874329 0.71571547 … -0.03869169 0.61091924; -0.043230742 -0.052570976 … -0.12665555 -0.027335625], bias = Float32[-0.7654097, 0.17761244]))
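
As a follow-up sketch (not part of the original output), the trained parameters in result_neuralode.u can be passed back through the predict function to compare the final fit against the data, reusing the plotting pattern from the callback:

p_trained = result_neuralode.u |> gdev     # keep the parameters on the same device as u0
final_pred = predict_neuralode(p_trained)  # 2 × 30 matrix of predictions
plt = scatter(tsteps, Array(ode_data[1, :]); label = "data")
scatter!(plt, tsteps, Array(final_pred[1, :]); label = "prediction")
display(plt)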