Benchmarks
Note on benchmarking and getting the best performance out of the SciML stack's adjoints
From our recent papers, it's clear that EnzymeVJP is the fastest vjp choice, especially when the program is written as fully non-allocating, mutating functions. Thus all benchmarking, especially with PDEs, should be done with mutating, non-allocating code where possible. Neural network libraries don't make effective use of mutation, with the exception of SimpleChains.jl, so we recommend first building a neural ODE / universal ODE with ZygoteVJP and Lux, then checking correctness after moving the implementation over to SimpleChains and, if possible, EnzymeVJP. This can be an order of magnitude improvement (or more) over the Zygote + Lux setups used in all of the previous benchmarks, and thus it's highly recommended in scenarios that require performance.
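As a rough illustration of the mutating-function setup this refers to, here is a minimal sketch (not part of the benchmarks below) of selecting EnzymeVJP on a plain in-place ODE; the Lotka-Volterra right-hand side, parameter values, and loss are illustrative assumptions:

using OrdinaryDiffEq, SciMLSensitivity, Zygote

# In-place, non-allocating right-hand side: the setting where EnzymeVJP is fastest
function lotka!(du, u, p, t)
    du[1] = p[1] * u[1] - p[2] * u[1] * u[2]
    du[2] = -p[3] * u[2] + p[4] * u[1] * u[2]
    nothing
end

u0 = [1.0, 1.0]
p = [1.5, 1.0, 3.0, 1.0]
prob = ODEProblem(lotka!, u0, (0.0, 10.0), p)

# EnzymeVJP is selected through the autojacvec option of the adjoint sensealg
loss(p) = sum(abs2,
    Array(solve(prob, Tsit5(); p = p, saveat = 0.1,
        sensealg = InterpolatingAdjoint(autojacvec = EnzymeVJP()))))

Zygote.gradient(loss, p)

The same autojacvec = EnzymeVJP() choice applies to the other adjoint types (BacksolveAdjoint, QuadratureAdjoint) benchmarked below.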
Vs Torchdiffeq on 1 million and fewer ODEs
A raw ODE solver benchmark showcases >30x performance advantage for DifferentialEquations.jl for ODEs ranging in size from 3 to nearly 1 million.
Vs Torchdiffeq on neural ODE training
A training benchmark using the spiral ODE from the original neural ODE paper demonstrates a 100x performance advantage for DiffEqFlux in training neural ODEs.
Vs torchsde on small SDEs
Using the code from torchsde's README, we demonstrated a >70,000x performance advantage over torchsde. Further benchmarking is planned, but running larger problems with torchsde was found to be computationally infeasible at this time.
A bunch of adjoint choices on neural ODEs
Quick summary:
- BacksolveAdjoint can be the fastest (but use with caution!); about 25% faster
- Using ZygoteVJP is faster than other vjp choices for larger neural networks
- ReverseDiffVJP(compile = true) works well for small Lux neural networks
using OrdinaryDiffEq, DiffEqFlux, Lux, SciMLSensitivity, Zygote, BenchmarkTools, Random,
      ComponentArrays

u0 = Float32[2.0; 0.0]
datasize = 30
tspan = (0.0f0, 1.5f0)
tsteps = range(tspan[1], tspan[2], length = datasize)

# Generate training data from the true ODE (the spiral ODE from the neural ODE paper)
function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u .^ 3)'true_A)'
end
prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps))

dudt2 = Chain(x -> x .^ 3, Dense(2, 50, tanh), Dense(50, 2))
Random.seed!(100)

# Time one gradient of the loss for each adjoint configuration
for sensealg in (InterpolatingAdjoint(autojacvec = ZygoteVJP()),
    InterpolatingAdjoint(autojacvec = ReverseDiffVJP(true)),
    BacksolveAdjoint(autojacvec = ReverseDiffVJP(true)),
    BacksolveAdjoint(autojacvec = ZygoteVJP()),
    BacksolveAdjoint(autojacvec = ReverseDiffVJP(false)),
    BacksolveAdjoint(autojacvec = TrackerVJP()),
    QuadratureAdjoint(autojacvec = ReverseDiffVJP(true)),
    TrackerAdjoint())
    prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(); saveat = tsteps,
        sensealg = sensealg)
    ps, st = Lux.setup(Random.default_rng(), prob_neuralode)
    ps = ComponentArray(ps)
    loss_neuralode = function (u0, p, st)
        pred = Array(first(prob_neuralode(u0, p, st)))
        loss = sum(abs2, ode_data .- pred)
        return loss
    end
    t = @belapsed Zygote.gradient($loss_neuralode, $u0, $ps, $st)
    println("$(sensealg) took $(t)s")
end
# InterpolatingAdjoint{0, true, Val{:central}, ZygoteVJP}(ZygoteVJP(false), false, false) took 0.029134224s
# InterpolatingAdjoint{0, true, Val{:central}, ReverseDiffVJP{true}}(ReverseDiffVJP{true}(), false, false) took 0.001657377s
# BacksolveAdjoint{0, true, Val{:central}, ReverseDiffVJP{true}}(ReverseDiffVJP{true}(), true, false) took 0.002477057s
# BacksolveAdjoint{0, true, Val{:central}, ZygoteVJP}(ZygoteVJP(false), true, false) took 0.031533335s
# BacksolveAdjoint{0, true, Val{:central}, ReverseDiffVJP{false}}(ReverseDiffVJP{false}(), true, false) took 0.004605386s
# BacksolveAdjoint{0, true, Val{:central}, TrackerVJP}(TrackerVJP(false), true, false) took 0.044568018s
# QuadratureAdjoint{0, true, Val{:central}, ReverseDiffVJP{true}}(ReverseDiffVJP{true}(), 1.0e-6, 0.001) took 0.002489559s
# TrackerAdjoint() took 0.003759097s
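For reference, here is a sketch of plugging the fastest configuration above, InterpolatingAdjoint(autojacvec = ReverseDiffVJP(true)), into an actual training run; the Optimization.jl setup, Adam learning rate, and iteration count are illustrative assumptions rather than part of the benchmark:

using Optimization, OptimizationOptimisers

prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(); saveat = tsteps,
    sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP(true)))
ps, st = Lux.setup(Random.default_rng(), prob_neuralode)
ps = ComponentArray(ps)

# Optimization.jl objectives take (parameters, hyperparameters); st and ode_data are closed over
loss(p, _) = sum(abs2, ode_data .- Array(first(prob_neuralode(u0, p, st))))

optf = OptimizationFunction(loss, Optimization.AutoZygote())
optprob = OptimizationProblem(optf, ps)
res = solve(optprob, Adam(0.05); maxiters = 300)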