Neural ODEs on GPUs
Note that the differential equation solvers will run on the GPU if the initial condition is a GPU array. Thus, for example, we can define a neural ODE manually that runs on the GPU (if no GPU is available, the calculation defaults back to the CPU).
For a detailed discussion of how GPUs need to be set up, refer to the Lux docs.
using OrdinaryDiffEq, Lux, CUDA, SciMLSensitivity, ComponentArrays, Random
# Device handles: `gdev` moves data to the GPU when one is available (otherwise the
# computation stays on the CPU), `cdev` moves data to the CPU.
const gdev = gpu_device()
const cdev = cpu_device()
rng = Xoshiro(0)

# Two-layer MLP used as the right-hand side of the ODE
model = Chain(Dense(2 => 50, tanh), Dense(50 => 2))
ps, st = Lux.setup(rng, model)
# Flatten the parameters into a ComponentArray, then place parameters and states on the device
ps = gdev(ComponentArray(ps))
st = gdev(st)

# A Lux layer call returns `(output, state)`; the ODE right-hand side only needs the output
dudt(u, p, t) = first(model(u, p, st))

# Simulation interval and save points
tspan = (0.0f0, 10.0f0)
tsteps = 0.0f0:1.0f-1:10.0f0
# The solver runs on whichever device holds the initial condition
u0 = gdev(Float32[2.0, 0.0])
prob_gpu = ODEProblem(dudt, u0, tspan, ps)
# Runs on a GPU
sol_gpu = solve(prob_gpu, Tsit5(); saveat = tsteps)
retcode: Success
Interpolation: 1st order linear
t: 101-element Vector{Float32}:
0.0
0.1
0.2
0.3
0.4
0.5
0.6
0.7
0.8
0.9
⋮
9.2
9.3
9.4
9.5
9.6
9.7
9.8
9.9
10.0
u: 101-element Vector{Vector{Float32}}:
[2.0, 0.0]
[1.7873144, 0.031452525]
[1.5787979, 0.06454544]
[1.3766685, 0.09886332]
[1.1836994, 0.13350452]
[1.002838, 0.16698374]
[0.8367471, 0.19752812]
[0.6875936, 0.22379687]
[0.5568645, 0.24556386]
[0.44504493, 0.26368698]
⋮
[1.8026863, 12.045193]
[1.8133812, 12.216578]
[1.8231944, 12.387779]
[1.8321064, 12.558832]
[1.8400973, 12.7297535]
[1.847144, 12.90055]
[1.8532271, 13.071252]
[1.858328, 13.241879]
[1.8624291, 13.412461]
Alternatively, we can use the `NeuralODE` layer directly:
using DiffEqFlux: NeuralODE
prob_neuralode_gpu = NeuralODE(model, tspan, Tsit5(); saveat = tsteps)
NeuralODE(
model = Chain(
layer_1 = Dense(2 => 50, tanh), # 150 parameters
layer_2 = Dense(50 => 2), # 102 parameters
),
) # Total: 252 parameters,
# plus 0 states.
If one is using `Lux.Chain`, the computation `f(x, p, st)` takes place on the GPU when `x`, `p`, and `st` are all on the GPU. This commonly looks like:
# Cube the state, then pass it through a small MLP
dudt2 = Chain(x -> x .^ 3, Dense(2 => 50, tanh), Dense(50 => 2))
u0 = gdev(Float32[2.0, 0.0])
# Lux.setup returns (parameters, states); move both to the device in a single call
p, st = gdev(Lux.setup(rng, dudt2))
# Discard the returned layer state; only the network output is the derivative
dudt2_(u, p, t) = first(dudt2(u, p, st))
# Simulation interval and save points
tspan = (0.0f0, 10.0f0)
tsteps = 0.0f0:1.0f-1:10.0f0
prob_gpu = ODEProblem(dudt2_, u0, tspan, p)
# Runs on a GPU
sol_gpu = solve(prob_gpu, Tsit5(); saveat = tsteps)
retcode: Success
Interpolation: 1st order linear
t: 101-element Vector{Float32}:
0.0
0.1
0.2
0.3
0.4
0.5
0.6
0.7
0.8
0.9
⋮
9.2
9.3
9.4
9.5
9.6
9.7
9.8
9.9
10.0
u: 101-element Vector{Vector{Float32}}:
[2.0, 0.0]
[2.0003986, 0.0848813]
[2.0007863, 0.16975331]
[2.0011175, 0.2546234]
[2.001314, 0.33950984]
[2.0012617, 0.42445397]
[2.0008118, 0.50952744]
[1.9997566, 0.5948856]
[1.9978081, 0.6808158]
[1.9945081, 0.7678875]
⋮
[-0.39173004, 0.3561513]
[-0.41075367, 0.38182774]
[-0.431586, 0.40559503]
[-0.45411325, 0.4272154]
[-0.47812077, 0.44647375]
[-0.5033074, 0.4632105]
[-0.5292859, 0.47732174]
[-0.55558145, 0.4887584]
[-0.5816351, 0.49752787]
Or via the `NeuralODE` struct:
# Wrap the same Lux chain in a NeuralODE layer; calling it as (u0, p, st) solves the ODE
prob_neuralode_gpu = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps)
prob_neuralode_gpu(u0, p, st)(SciMLBase.ODESolution{Float32, 2, Vector{Vector{Float32}}, Nothing, Nothing, Vector{Float32}, Vector{Vector{Vector{Float32}}}, Nothing, SciMLBase.ODEProblem{Vector{Float32}, Tuple{Float32, Float32}, false, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, layer_3::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}}, SciMLBase.ODEFunction{false, SciMLBase.FullSpecialize, DiffEqFlux.var"#dudt#dudt##0"{StatefulLuxLayer{Val{true}, Chain{@NamedTuple{layer_1::WrappedFunction{Main.var"#2#3"}, layer_2::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, Nothing, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Nothing, @NamedTuple{}}, SciMLBase.StandardODEProblem}, OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, OrdinaryDiffEqCore.InterpolationData{SciMLBase.ODEFunction{false, SciMLBase.FullSpecialize, DiffEqFlux.var"#dudt#dudt##0"{StatefulLuxLayer{Val{true}, Chain{@NamedTuple{layer_1::WrappedFunction{Main.var"#2#3"}, layer_2::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, Nothing, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, 
Nothing, Nothing}, Vector{Vector{Float32}}, Vector{Float32}, Vector{Vector{Vector{Float32}}}, Nothing, OrdinaryDiffEqTsit5.Tsit5ConstantCache, Nothing}, SciMLBase.DEStats, Nothing, Nothing, Nothing, Nothing}(Vector{Float32}[[2.0, 0.0], [2.0003986, 0.0848813], [2.0007863, 0.16975331], [2.0011175, 0.2546234], [2.001314, 0.33950984], [2.0012617, 0.42445397], [2.0008118, 0.50952744], [1.9997566, 0.5948856], [1.9978081, 0.6808158], [1.9945081, 0.7678875] … [-0.37445134, 0.32878417], [-0.39173004, 0.3561513], [-0.41075367, 0.38182774], [-0.431586, 0.40559503], [-0.45411325, 0.4272154], [-0.47812077, 0.44647375], [-0.5033074, 0.4632105], [-0.5292859, 0.47732174], [-0.55558145, 0.4887584], [-0.5816351, 0.49752787]], nothing, nothing, Float32[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9 … 9.1, 9.2, 9.3, 9.4, 9.5, 9.6, 9.7, 9.8, 9.9, 10.0], Vector{Vector{Float32}}[[[2.0, 0.0]]], nothing, SciMLBase.ODEProblem{Vector{Float32}, Tuple{Float32, Float32}, false, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, layer_3::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}}, SciMLBase.ODEFunction{false, SciMLBase.FullSpecialize, DiffEqFlux.var"#dudt#dudt##0"{StatefulLuxLayer{Val{true}, Chain{@NamedTuple{layer_1::WrappedFunction{Main.var"#2#3"}, layer_2::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, Nothing, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Nothing, @NamedTuple{}}, SciMLBase.StandardODEProblem}(SciMLBase.ODEFunction{false, SciMLBase.FullSpecialize, 
DiffEqFlux.var"#dudt#dudt##0"{StatefulLuxLayer{Val{true}, Chain{@NamedTuple{layer_1::WrappedFunction{Main.var"#2#3"}, layer_2::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, Nothing, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}(DiffEqFlux.var"#dudt#dudt##0"{StatefulLuxLayer{Val{true}, Chain{@NamedTuple{layer_1::WrappedFunction{Main.var"#2#3"}, layer_2::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, Nothing, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}}}(StatefulLuxLayer{Val{true}, Chain{@NamedTuple{layer_1::WrappedFunction{Main.var"#2#3"}, layer_2::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, Nothing, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}}(Chain{@NamedTuple{layer_1::WrappedFunction{Main.var"#2#3"}, layer_2::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = WrappedFunction(#2), layer_2 = Dense(2 => 50, tanh), layer_3 = Dense(50 => 2)), nothing), nothing, (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()), nothing, Val{true}())), LinearAlgebra.UniformScaling{Bool}(true), nothing, DiffEqFlux.basic_tgrad, nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, SciMLBase.DEFAULT_OBSERVED, nothing, nothing, nothing, 
nothing), Float32[2.0, 0.0], (0.0f0, 10.0f0), (layer_1 = NamedTuple(), layer_2 = (weight = Float32[-1.6285712 -0.7106906; 1.8252167 -0.15980828; … ; -0.46829274 0.3371172; -1.2275605 -1.2200289], bias = Float32[0.10678334, 0.04854786, -0.3036146, -0.67875606, 0.19386894, 0.5310734, -0.07256524, -0.18512556, -0.071528256, 0.37226313 … -0.17194617, 0.21797575, -0.29705614, 0.19832478, 0.4268401, 0.43911377, 0.46748847, 0.19240205, 0.38869342, 0.29241425]), layer_3 = (weight = Float32[0.0004401929 0.031856854 … -0.056129575 0.2373046; -0.07581123 -0.13199675 … 0.20389898 -0.18144646], bias = Float32[-0.101620436, -0.1026898])), Base.Pairs{Symbol, Union{}, Nothing, @NamedTuple{}}(), SciMLBase.StandardODEProblem()), OrdinaryDiffEqTsit5.Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}(OrdinaryDiffEqCore.trivial_limiter!, OrdinaryDiffEqCore.trivial_limiter!, static(false)), OrdinaryDiffEqCore.InterpolationData{SciMLBase.ODEFunction{false, SciMLBase.FullSpecialize, DiffEqFlux.var"#dudt#dudt##0"{StatefulLuxLayer{Val{true}, Chain{@NamedTuple{layer_1::WrappedFunction{Main.var"#2#3"}, layer_2::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, Nothing, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}, Vector{Vector{Float32}}, Vector{Float32}, Vector{Vector{Vector{Float32}}}, Nothing, OrdinaryDiffEqTsit5.Tsit5ConstantCache, Nothing}(SciMLBase.ODEFunction{false, SciMLBase.FullSpecialize, DiffEqFlux.var"#dudt#dudt##0"{StatefulLuxLayer{Val{true}, Chain{@NamedTuple{layer_1::WrappedFunction{Main.var"#2#3"}, layer_2::Dense{typeof(tanh), Int64, 
Int64, Nothing, Nothing, Static.True}, layer_3::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, Nothing, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}(DiffEqFlux.var"#dudt#dudt##0"{StatefulLuxLayer{Val{true}, Chain{@NamedTuple{layer_1::WrappedFunction{Main.var"#2#3"}, layer_2::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, Nothing, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}}}(StatefulLuxLayer{Val{true}, Chain{@NamedTuple{layer_1::WrappedFunction{Main.var"#2#3"}, layer_2::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, Nothing, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}, layer_3::@NamedTuple{}}}(Chain{@NamedTuple{layer_1::WrappedFunction{Main.var"#2#3"}, layer_2::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = WrappedFunction(#2), layer_2 = Dense(2 => 50, tanh), layer_3 = Dense(50 => 2)), nothing), nothing, (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()), nothing, Val{true}())), LinearAlgebra.UniformScaling{Bool}(true), nothing, DiffEqFlux.basic_tgrad, nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, SciMLBase.DEFAULT_OBSERVED, nothing, nothing, nothing, nothing), Vector{Float32}[[2.0, 0.0], [2.0003986, 0.0848813], [2.0007863, 0.16975331], [2.0011175, 0.2546234], [2.001314, 0.33950984], [2.0012617, 
0.42445397], [2.0008118, 0.50952744], [1.9997566, 0.5948856], [1.9978081, 0.6808158], [1.9945081, 0.7678875] … [-0.37445134, 0.32878417], [-0.39173004, 0.3561513], [-0.41075367, 0.38182774], [-0.431586, 0.40559503], [-0.45411325, 0.4272154], [-0.47812077, 0.44647375], [-0.5033074, 0.4632105], [-0.5292859, 0.47732174], [-0.55558145, 0.4887584], [-0.5816351, 0.49752787]], Float32[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9 … 9.1, 9.2, 9.3, 9.4, 9.5, 9.6, 9.7, 9.8, 9.9, 10.0], Vector{Vector{Float32}}[[[2.0, 0.0]]], nothing, false, OrdinaryDiffEqTsit5.Tsit5ConstantCache(), nothing, false), false, 0, SciMLBase.DEStats(201, 0, 0, 0, 0, 0, 0, 0, 0, 0, 31, 2, 0.0), nothing, SciMLBase.ReturnCode.Success, nothing, nothing, nothing), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()))
Neural ODE Example
Here is the full neural ODE example. Note that we use the `gpu_device` function so that the same code runs on both CPUs and GPUs; which one is used depends on whether CUDA is loaded and a GPU is available.
using Lux, Optimization, OptimizationOptimisers, Zygote, OrdinaryDiffEq, Plots, CUDA,
SciMLSensitivity, Random, ComponentArrays
import DiffEqFlux: NeuralODE
const cdev = cpu_device()
const gdev = gpu_device()
# Error on scalar indexing of GPU arrays instead of silently running slow element-wise code
CUDA.allowscalar(false)
# RNG used by Lux.setup to initialize the network
rng = Xoshiro(0)

# Generate Data
datasize = 30
tspan = (0.0f0, 1.5f0)
u0 = Float32[2.0, 0.0]
# `datasize` evenly spaced save points covering the whole time span
tsteps = range(first(tspan), last(tspan); length = datasize)
"""
    trueODEfunc(du, u, p, t)

In-place right-hand side of the reference ("true") ODE used to generate training
data: writes `A' * (u .^ 3)` into `du` for a fixed 2×2 matrix `A`. `p` and `t` are
unused. Returns `du`.
"""
function trueODEfunc(du, u, p, t)
    A = [-0.1 2.0; -2.0 -0.1]
    # Same as ((u .^ 3)' * A)': push the cubed state through the adjoint of A
    du .= A' * (u .^ 3)
end
prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
# Solve the reference ODE, stack the saved states into a matrix and,
# if the user has a GPU, move that data onto it
ode_data = gdev(Array(solve(prob_trueode, Tsit5(); saveat = tsteps)))

dudt2 = Chain(x -> x .^ 3, Dense(2 => 50, tanh), Dense(50 => 2))
u0 = gdev(Float32[2.0, 0.0])
p, st = Lux.setup(rng, dudt2)
# ComponentArray flattens the nested parameter NamedTuple into one vector the optimizer can update
p = gdev(ComponentArray(p))
st = gdev(st)
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(); saveat = tsteps)
# Solve the neural ODE with parameters `p` and stack the saved states column-wise
function predict_neuralode(p)
    sol = first(prob_neuralode(u0, p, st))
    return reduce(hcat, sol.u)
end
# Sum-of-squares error between the training data and the neural ODE prediction
loss_neuralode(p) = sum(abs2, ode_data .- predict_neuralode(p))
# Callback function to observe training: prints the loss, records a plot of the
# current fit each iteration, and optionally displays it.
list_plots = []
iter = 0
callback = function (state, l; doplot = false)
    global list_plots, iter
    pred = predict_neuralode(state.u)
    # Start a fresh plot history on the very first call
    if iter == 0
        list_plots = []
    end
    iter += 1
    display(l)
    # Compare the first state component of the data with the current prediction
    plt = scatter(tsteps, Array(ode_data[1, :]); label = "data")
    scatter!(plt, tsteps, Array(pred[1, :]); label = "prediction")
    push!(list_plots, plt)
    doplot && display(plot(plt))
    # Returning false lets the optimizer continue
    return false
end
adtype = Optimization.AutoZygote()
# The second (hyperparameter) argument is unused; only the network parameters are optimized
optf = Optimization.OptimizationFunction((x, _) -> loss_neuralode(x), adtype)
# The initial network parameters `p` are the optimizer's starting point
optprob = Optimization.OptimizationProblem(optf, p)
result_neuralode = Optimization.solve(
    optprob, OptimizationOptimisers.Adam(0.05); callback, maxiters = 300)
retcode: Default
u: ComponentVector{Float32}(layer_1 = Float32[], layer_2 = (weight = Float32[-2.0201523 1.5047237; -0.48799333 -0.22094278; … ; -0.20681679 0.01847002; 0.5569128 -0.8740125], bias = Float32[-0.62723947, -0.81310475, -0.51000154, 1.3997053, -0.38143492, 0.816305, 0.8593733, -0.0018838106, 0.91819584, 0.7497174 … 0.25200874, -0.9336435, -1.5443151, 0.5430574, -0.84973997, 0.7290552, -0.08725739, 0.6089473, 0.2810185, 0.40009674]), layer_3 = (weight = Float32[-0.2583687 0.7560768 … -0.08102027 0.6142381; -0.036943514 -0.06070116 … -0.11565053 -0.03477804], bias = Float32[-0.7600118, 0.18344054]))