Neural ODEs on GPUs
Note that the differential equation solvers will run on the GPU if the initial condition is a GPU array. Thus, for example, we can define a neural ODE manually that runs on the GPU (if no GPU is available, the calculation falls back to the CPU).
For a detailed discussion of how GPUs need to be set up, refer to the Lux Docs.
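To confirm which device was selected, here is a minimal sketch (the names dev and x are illustrative, not part of the tutorial): gpu_device() returns a CUDA device when LuxCUDA finds a functional GPU and otherwise falls back to the CPU device.

using Lux, LuxCUDA

dev = gpu_device()           # CUDADevice if a functional GPU was found, CPUDevice otherwise
@show dev
x = rand(Float32, 2) |> dev  # piping an array through the device moves it there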
using OrdinaryDiffEq, Lux, LuxCUDA, SciMLSensitivity, ComponentArrays, Random
rng = Xoshiro(0)
const cdev = cpu_device()
const gdev = gpu_device()
model = Chain(Dense(2, 50, tanh), Dense(50, 2))
ps, st = Lux.setup(rng, model)
ps = ps |> ComponentArray |> gdev
st = st |> gdev
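# st is captured in the closure below; the ODE solver only ever sees u and p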
dudt(u, p, t) = model(u, p, st)[1]
# Simulation interval and intermediate points
tspan = (0.0f0, 10.0f0)
tsteps = 0.0f0:1.0f-1:10.0f0
u0 = Float32[2.0; 0.0] |> gdev
prob_gpu = ODEProblem(dudt, u0, tspan, ps)
# Runs on a GPU
sol_gpu = solve(prob_gpu, Tsit5(); saveat = tsteps)

retcode: Success
Interpolation: 1st order linear
t: 101-element Vector{Float32}:
  0.0
  0.1
  0.2
  0.3
  0.4
  0.5
  0.6
  0.7
  0.8
  0.9
  ⋮
  9.2
  9.3
  9.4
  9.5
  9.6
  9.7
  9.8
  9.9
 10.0
u: 101-element Vector{Vector{Float32}}:
 [2.0, 0.0]
 [1.7873144, 0.031452533]
 [1.5787977, 0.06454542]
 [1.3766682, 0.09886328]
 [1.1836987, 0.13350393]
 [1.0028385, 0.16698295]
 [0.83674806, 0.19752903]
 [0.6875931, 0.22379515]
 [0.556868, 0.245554]
 [0.44503433, 0.2636885]
 ⋮
 [1.8024267, 12.045058]
 [1.8131173, 12.216458]
 [1.8229246, 12.387675]
 [1.8318273, 12.558733]
 [1.8398049, 12.7296505]
 [1.8468374, 12.900447]
 [1.8529061, 13.07115]
 [1.8579924, 13.241777]
 [1.8620787, 13.412361]

Or we could directly use the neural ODE layer function, like:
using DiffEqFlux: NeuralODE
prob_neuralode_gpu = NeuralODE(model, tspan, Tsit5(); saveat = tsteps)

NeuralODE(
    model = Chain(
        layer_1 = Dense(2 => 50, tanh),  # 150 parameters
        layer_2 = Dense(50 => 2),        # 102 parameters
    ),
)         # Total: 252 parameters,
          #        plus 0 states.

If one is using a Lux.Chain, then the computation takes place on the GPU when f(x, p, st) is called with x, p, and st all on the GPU. This commonly looks like:
dudt2 = Chain(x -> x .^ 3, Dense(2, 50, tanh), Dense(50, 2))
u0 = Float32[2.0; 0.0] |> gdev
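# Lux.setup returns (ps, st); piping the tuple through gdev moves both to the GPU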
p, st = Lux.setup(rng, dudt2) |> gdev
dudt2_(u, p, t) = first(dudt2(u, p, st))
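# first(...) drops the returned layer state, keeping only the network output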
# Simulation interval and intermediate points
tspan = (0.0f0, 10.0f0)
tsteps = 0.0f0:1.0f-1:10.0f0
prob_gpu = ODEProblem(dudt2_, u0, tspan, p)
# Runs on a GPU
sol_gpu = solve(prob_gpu, Tsit5(); saveat = tsteps)

retcode: Success
Interpolation: 1st order linear
t: 101-element Vector{Float32}:
  0.0
  0.1
  0.2
  0.3
  0.4
  0.5
  0.6
  0.7
  0.8
  0.9
  ⋮
  9.2
  9.3
  9.4
  9.5
  9.6
  9.7
  9.8
  9.9
 10.0
u: 101-element Vector{Vector{Float32}}:
 [2.0, 0.0]
 [2.0003986, 0.08488123]
 [2.0007863, 0.16975328]
 [2.0011175, 0.25462347]
 [2.001314, 0.33951008]
 [2.0012627, 0.4244526]
 [2.0008123, 0.5095264]
 [1.9997551, 0.59488827]
 [1.9977883, 0.6808464]
 [1.9945598, 0.76779586]
 ⋮
 [-0.4119342, 0.3561046]
 [-0.4301134, 0.38099936]
 [-0.44993347, 0.4038766]
 [-0.4713037, 0.4245298]
 [-0.49403465, 0.44280446]
 [-0.51783746, 0.45859838]
 [-0.5423251, 0.47186184]
 [-0.5670111, 0.48259723]
 [-0.59131217, 0.49086034]

or via the NeuralODE struct:
prob_neuralode_gpu = NeuralODE(dudt2, tspan, Tsit5(); saveat = tsteps)
prob_neuralode_gpu(u0, p, st)

(ODESolution with retcode: Success and the same 101 saved states as above, (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()))

Neural ODE Example
Here is the full neural ODE example. Note that we use the gpu_device function so that the same code works on both CPUs and GPUs, depending on whether LuxCUDA detects a functional GPU.
using Lux, Optimization, OptimizationOptimisers, Zygote, OrdinaryDiffEq, Plots, LuxCUDA,
    SciMLSensitivity, Random, ComponentArrays
import DiffEqFlux: NeuralODE
const cdev = cpu_device()
const gdev = gpu_device()
CUDA.allowscalar(false) # make sure no slow scalar indexing occurs on GPU arrays
# rng for Lux.setup
rng = Xoshiro(0)
# Generate Data
u0 = Float32[2.0; 0.0]
datasize = 30
tspan = (0.0f0, 1.5f0)
tsteps = range(tspan[1], tspan[2]; length = datasize)
function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u .^ 3)'true_A)'
end
prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
# Make the data into a GPU-based array if the user has a GPU
ode_data = solve(prob_trueode, Tsit5(); saveat = tsteps)
ode_data = Array(ode_data) |> gdev
dudt2 = Chain(x -> x .^ 3, Dense(2, 50, tanh), Dense(50, 2))
u0 = Float32[2.0; 0.0] |> gdev
p, st = Lux.setup(rng, dudt2)
p = p |> ComponentArray |> gdev
st = st |> gdev
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(); saveat = tsteps)
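# prob_neuralode(u0, p, st) returns (sol, st); keep the solution and hcat its saved states into a matrix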
predict_neuralode(p) = reduce(hcat, first(prob_neuralode(u0, p, st)).u)
function loss_neuralode(p)
    pred = predict_neuralode(p)
    loss = sum(abs2, ode_data .- pred)
    return loss
end
# Callback function to observe training
list_plots = []
iter = 0
callback = function (state, l; doplot = false)
    p = state.u
    global list_plots, iter
    pred = predict_neuralode(p)
    if iter == 0
        list_plots = []
    end
    iter += 1
    display(l)
    # plot current prediction against data
    plt = scatter(tsteps, Array(ode_data[1, :]); label = "data")
    scatter!(plt, tsteps, Array(pred[1, :]); label = "prediction")
    push!(list_plots, plt)
    if doplot
        display(plot(plt))
    end
    return false
end
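# Use Zygote reverse-mode AD; SciMLSensitivity supplies the adjoint rules for differentiating through the ODE solve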
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_neuralode(x), adtype)
optprob = Optimization.OptimizationProblem(optf, p)
result_neuralode = Optimization.solve(
    optprob, OptimizationOptimisers.Adam(0.05); callback, maxiters = 300)

retcode: Default
u: ComponentVector{Float32}(layer_1 = Float32[], layer_2 = (weight = Float32[-1.9956415 1.505737; -0.47191903 -0.25449324; … ; -0.20570713 0.01822312; 0.56123346 -0.85311043], bias = Float32[-0.6303439, -0.77096194, -0.5568976, 1.4353292, -0.37929276, 0.84450424, 0.8785953, -0.02970068, 0.96417844, 0.7853037 … 0.24361013, -1.0012139, -1.4132881, 0.5324092, -0.8246897, 0.66072077, -0.10391649, 1.0552049, 0.24918991, 0.4061778]), layer_3 = (weight = Float32[-0.2874329 0.71571547 … -0.03869169 0.61091924; -0.043230742 -0.052570976 … -0.12665555 -0.027335625], bias = Float32[-0.7654097, 0.17761244]))
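Once training finishes, the optimized parameters live on the GPU. A minimal follow-up sketch (the names p_trained, data_cpu, and plt are illustrative, not part of the tutorial): move predictions back through cdev before plotting or saving them.

p_trained = result_neuralode.u                # optimized parameters (still on the GPU)
pred = predict_neuralode(p_trained) |> cdev   # bring the prediction back to the CPU
data_cpu = ode_data |> cdev
plt = scatter(tsteps, data_cpu[1, :]; label = "data")
scatter!(plt, tsteps, pred[1, :]; label = "trained prediction")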