Batched Reductions for Lowering Peak Memory Requirements

Just as in the regular form of the DifferentialEquations.jl ensemble interface, a reduction function can be given to reduce between batches. This example solves the ODEs 20 at a time, takes each trajectory's final value, and folds every batch into a running sum. Because only that running sum is kept from earlier batches, rather than every trajectory's solution, peak memory is bounded by the batch size instead of the total trajectory count. The trajectory count can therefore be raised beyond what would fit in memory at once.

using OrdinaryDiffEq, DiffEqGPU, CUDA

seed = 100
using Random;
Random.seed!(seed);
ra = rand(100)

function f!(du, u, p, t)
    du[1] = 1.01 * u[1]
end

prob = ODEProblem(f!, [0.5], (0.0, 1.0))

# Keep only the final state of each trajectory; `false` means no rerun is requested.
function output_func(sol, i)
    last(sol), false
end

# Scale the initial condition of trajectory `i` by its precomputed random factor.
function prob_func(prob, i, repeat)
    remake(prob, u0 = ra[i] * prob.u0)
end

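# Add this batch's final values to the running total `u`; `false` means
# the ensemble has not converged, so the remaining batches still run.
function reduction(u, batch, I)
    u .+ sum(batch), false
end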

prob2 = EnsembleProblem(prob, prob_func = prob_func, output_func = output_func,
    reduction = reduction, u_init = Vector{eltype(prob.u0)}([0.0]))
sim4 = solve(prob2, Tsit5(), EnsembleGPUArray(CUDA.CUDABackend()), trajectories = 100,
    batch_size = 20)
EnsembleSolution Solution of length 1 with uType:
Float64
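
To make the batching semantics concrete, here is a minimal CPU-only sketch that mimics the loop above without a GPU or an ODE solver. It is an illustration under stated assumptions, not how DiffEqGPU is implemented: `final_value` and `batched_sum` are hypothetical helpers, and the solver is replaced by the exact solution u(1) = u0 * exp(1.01) of the linear ODE. Only `reduction` has the same shape as in the example.

```julia
# Hypothetical stand-in for the batched ensemble loop (not DiffEqGPU internals).
using Random

Random.seed!(100)
ra = rand(100)

# Assumption: replace the solver with the exact final value of du/dt = 1.01u at t = 1.
final_value(u0) = [u0 * exp(1.01)]

# Same shape as the reduction in the example: fold one batch into the accumulator.
reduction(u, batch, I) = (u .+ sum(batch), false)

function batched_sum(ra; batch_size = 20)
    u = [0.0]  # plays the role of u_init
    for I in Iterators.partition(eachindex(ra), batch_size)
        # Only the results of the current batch are held in memory.
        batch = [final_value(ra[i] * 0.5) for i in I]
        u, converged = reduction(u, batch, I)
        converged && break  # a `true` flag would stop the remaining batches
    end
    return u
end

total = batched_sum(ra)
direct = sum(0.5 .* ra) * exp(1.01)  # unbatched reference value
println(total[1] ≈ direct)
```

The accumulator stays a one-element vector no matter how many trajectories are processed, which is why the batched result can be compared against the unbatched sum without ever storing all 100 final values at once.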