Neural Stochastic Differential Equations With Method of Moments

With neural stochastic differential equations, there is once again a helper form neural_dmsde which can be used for the multiplicative noise case (consult the layers API documentation, or this full example using the layer function).

However, since there are far too many possible combinations for the API to support, often you will want to define neural differential equations for non-ODE systems from scratch. To get good performance for these systems, it is generally best to use TrackerAdjoint with non-mutating (out-of-place) forms. For example, the following defines a neural SDE with neural networks for both the drift and diffusion terms:

# Out-of-place neural SDE: the drift and the diffusion each come from their own
# network. `p` and `t` are ignored here; the problem is built with `nothing` as its
# parameters.
function dudt(u, p, t)
    return model(u)
end

function g(u, p, t)
    return model2(u)
end

prob = SDEProblem(dudt, g, x, tspan, nothing)

where model and model2 are different neural networks. The same approach applies to a neural delay differential equation, whose out-of-place formulation is f(u,h,p,t). For example, to define a neural delay differential equation that uses the history value p.tau time units in the past, we can write:

# Out-of-place neural DDE right-hand side (no `!`: it returns a new vector rather
# than mutating). History functions are queried as `h(p, t)`, so the lagged state is
# `h(p, t - p.tau)`. `p` must therefore carry the delay, e.g. `p = (; tau = 0.5)`
# (illustrative value); passing `nothing` as the parameters would make `p.tau` throw.
dudt_(u, h, p, t) = model([u; h(p, t - p.tau)])
prob = DDEProblem(dudt_, u0, h, tspan, p)

First, let's build training data from the same example as the neural ODE:

using Plots, Statistics, ComponentArrays, Optimization, OptimizationOptimisers, DiffEqFlux,
      StochasticDiffEq, SciMLBase.EnsembleAnalysis, Random

# Initial state, number of observation times, and the Float32 integration window;
# `tsteps` holds `datasize` evenly spaced times spanning `tspan`.
u0 = Float32[2.0, 0.0]
datasize = 30
tspan = (0.0f0, 1.0f0)
tsteps = range(first(tspan), last(tspan); length = datasize)
0.0f0:0.03448276f0:1.0f0
# In-place drift of the reference SDE: writes A' * u.^3 into `du` (the transpose of
# the row vector (u.^3)' * A) and returns `du`. `p` and `t` are unused.
function trueSDEfunc(du, u, p, t)
    A = [-0.1 2.0; -2.0 -0.1]
    du .= A' * (u .^ 3)
    return du
end

# Per-component multiplicative noise coefficients of the reference SDE.
mp = Float32[0.2, 0.2]

# In-place diffusion: each state component's noise amplitude is `mp[i] * u[i]`.
# Writes into `du` and returns it; `p` and `t` are unused.
function true_noise_func(du, u, p, t)
    @. du = mp * u
    return du
end

# Reference problem: in-place drift `trueSDEfunc` and in-place diagonal noise
# `true_noise_func`, starting from `u0` over `tspan`; no parameters are passed.
prob_truesde = SDEProblem(trueSDEfunc, true_noise_func, u0, tspan)
SDEProblem with uType Vector{Float32} and tType Float32. In-place: true
Non-trivial mass matrix: false
timespan: (0.0f0, 1.0f0)
u0: 2-element Vector{Float32}:
 2.0
 0.0

For our dataset, we will use DifferentialEquations.jl's parallel ensemble interface to generate data from the average of 10,000 runs of the SDE:

# Reference statistics: simulate 10,000 paths of the true SDE, then reduce them to
# the pointwise mean and variance at each training time in `tsteps`.
# `safetycopy = false` skips per-trajectory problem copies (presumably safe here,
# since the SDE functions only write into `du`).
ensemble_prob = EnsembleProblem(prob_truesde; safetycopy = false)
ensemble_sol = solve(ensemble_prob, SOSRI(); trajectories = 10000)
ensemble_sum = EnsembleSummary(ensemble_sol)

# Convert the (mean, variance) pair into plain arrays for the loss.
sde_data, sde_data_vars = map(Array, timeseries_point_meanvar(ensemble_sol, tsteps))
(Float32[2.0 1.915187 … 0.0628465 0.13601488; 0.0 0.52857983 … -0.6696221 -0.62210685], Float32[0.0 0.07733994 … 1.343315 1.310389; 0.0 0.05674475 … 0.9710136 1.0094317])

Now we build a neural SDE. For simplicity, we will use the NeuralDSDE layer, which defines a neural SDE with diagonal noise:

# Drift network: cubes the state (mirroring the `u .^ 3` in the true drift) before a
# small MLP.
drift_dudt = Chain(x -> x .^ 3, Dense(2, 50, tanh), Dense(50, 2))
# Diffusion network: a single affine layer giving the diagonal noise amplitudes.
diffusion_dudt = Dense(2, 2)

# Loose tolerances keep each solve cheap during training; output is saved at `tsteps`.
neuralsde = NeuralDSDE(drift_dudt, diffusion_dudt, tspan, SOSRI();
    saveat = tsteps, reltol = 1e-1, abstol = 1e-1)

# Initialize with a fixed RNG seed, then flatten the parameters into a ComponentArray.
ps, st = let setup_out = Lux.setup(Xoshiro(0), neuralsde)
    ComponentArray(setup_out[1]), setup_out[2]
end
ComponentVector{Float32}(drift = (layer_1 = Float32[], layer_2 = (weight = Float32[-1.8019577 1.5097173; -0.18273845 -0.4676411; … ; 0.37099916 -0.27108315; -0.34856588 -0.6062841], bias = Float32[-0.522484, -0.6805993, -0.21060704, 0.50937545, 0.33639288, 0.22010256, -0.12450862, 0.3884359, 0.5799375, 0.39842856  …  0.1040132, 0.009969078, -0.460674, 0.21031016, 0.5280858, 0.7054404, 0.0009628869, 0.40567473, 0.30830613, 0.17590544]), layer_3 = (weight = Float32[0.22905384 -0.23547108 … 0.033212326 0.13550478; 0.22466984 -0.14894177 … -0.19668292 0.10960526], bias = Float32[-0.026947042, -0.0370021])), diffusion = (weight = Float32[-0.7320429 -1.0360838; -0.86161304 -0.6590253], bias = Float32[-0.22268295, 0.123328425]))

Let's see what that looks like:

# A single trajectory from the untrained layer, started at the true initial condition.
prediction0 = first(neuralsde(u0, ps, st))

# Wrap each sub-network so it can be called as `model(u, p)` from plain SDE functions.
drift_model = StatefulLuxLayer{true}(drift_dudt, ps.drift, st.drift)
diffusion_model = StatefulLuxLayer{true}(diffusion_dudt, ps.diffusion, st.diffusion)

function drift_(u, p, t)
    return drift_model(u, p.drift)
end

function diffusion_(u, p, t)
    return diffusion_model(u, p.diffusion)
end

# The horizon extends past the data window (1.2 vs 1.0); output is still saved at
# `tsteps`.
prob_neuralsde = SDEProblem(drift_, diffusion_, u0, (0.0f0, 1.2f0), ps)

# 100 sample paths of the untrained model, summarized for plotting.
ensemble_nprob = EnsembleProblem(prob_neuralsde; safetycopy = false)
ensemble_nsol = solve(ensemble_nprob, SOSRI();
    trajectories = 100, saveat = tsteps)
ensemble_nsum = EnsembleSummary(ensemble_nsol)

plt1 = plot(ensemble_nsum; title = "Neural SDE: Before Training")
scatter!(plt1, tsteps, sde_data'; lw = 3)

# First state component: reference mean vs. the single untrained trajectory.
scatter(tsteps, sde_data[1, :]; label = "data")
scatter!(tsteps, prediction0[1, :]; label = "prediction")
Example block output

Now, just as with the neural ODE, we define a loss function that calculates the mean and variance of n runs at each time point and measures their distance from the data values:

# Bundle the layer with `ps`/`st` so it can be called as `neuralsde_model(u, p)`.
neuralsde_model = StatefulLuxLayer{true}(neuralsde, ps, st)

# Solve the neural SDE from initial state(s) `u` (default: the single state `u0`)
# with parameters `p`, returning the solution values as a plain `Array`.
predict_neuralsde(p, u = u0) = Array(neuralsde_model(u, p))

# Method-of-moments loss: simulate `n` trajectories from `u0`, then sum the squared
# errors of their pointwise mean and variance against the reference data.
# Side effect: stores the latest moments in the globals `means`/`vars` for `callback`.
function loss_neuralsde(p; n = 100)
    batch_u0 = repeat(reshape(u0, :, 1), 1, n)
    paths = predict_neuralsde(p, batch_u0)

    # The n trajectories run along dimension 2 (the batch axis of `batch_u0`).
    mu = mean(paths; dims = 2)
    sigma2 = var(paths; dims = 2, mean = mu)[:, 1, :]
    mu = mu[:, 1, :]

    mean_err = sum(abs2, sde_data - mu)
    var_err = sum(abs2, sde_data_vars - sigma2)

    global means = mu
    global vars = sigma2
    return mean_err + var_err
end
loss_neuralsde (generic function with 1 method)
# Training bookkeeping read and updated by `callback` through `global` declarations.
list_plots = []
iter = 0

# Moments of 100 trajectories from the untrained parameters, presumably so the
# `means`/`vars` globals exist before training; `loss_neuralsde` overwrites them on
# every evaluation.
u = repeat(u0, 1, 100)
samples = predict_neuralsde(ps, u)
means = mean(samples; dims = 2)
vars = var(samples; dims = 2, mean = means)[:, 1, :]
means = dropdims(means; dims = 2)

# Optimization callback: prints the loss, records a snapshot plot of the first state
# component (data vs. the `means`/`vars` globals written by `loss_neuralsde`), and
# returns `false` so the optimizer keeps going.
callback = function (state, loss; doplot = false)
    global list_plots, iter, means, vars

    # Clear the snapshot history on the very first call.
    iszero(iter) && (list_plots = [])
    iter += 1

    display(loss)

    plt = Plots.scatter(tsteps, sde_data[1, :]; yerror = sde_data_vars[1, :],
        ylim = (-4.0, 8.0), label = "data")
    Plots.scatter!(plt, tsteps, means[1, :]; ribbon = vars[1, :], label = "prediction")
    push!(list_plots, plt)

    doplot && display(plt)
    return false
end
#4 (generic function with 1 method)

Now we train using this loss function. We can pre-train a little using a smaller n and then increase it once the model has had some time to adjust towards the right mean behavior:

opt = OptimizationOptimisers.Adam(0.025)

# Stage 1: cheap pre-training with moments estimated from only n = 10 trajectories.
# The second argument (optimization hyperparameters) is unused.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction(adtype) do x, _
    loss_neuralsde(x; n = 10)
end
optprob = Optimization.OptimizationProblem(optf, ps)
result1 = Optimization.solve(optprob, opt; callback = callback, maxiters = 100)
retcode: Default
u: ComponentVector{Float32}(drift = (layer_1 = Float32[], layer_2 = (weight = Float32[-1.6779859 1.4165629; -0.33156487 -0.5439335; … ; 0.050358985 -0.38419706; 0.55556107 -0.57838243], bias = Float32[-1.160309, -0.4074102, 0.47582808, 0.89017516, 0.83443093, -0.011645139, -0.5995196, 0.91472137, 0.22817892, -0.0015078115  …  -0.35750952, 0.7858873, -0.37112916, -0.022068001, 0.8362439, 0.90582407, 0.14017947, 0.022390246, 0.7121091, 0.061371006]), layer_3 = (weight = Float32[0.40636852 0.18310206 … -0.08077494 0.43516132; 0.10154045 -0.06260481 … -0.05778101 0.19457608], bias = Float32[-0.38893887, -0.25151238])), diffusion = (weight = Float32[-0.60984945 -0.9899276; -0.7552008 -0.51332664], bias = Float32[-0.07655256, 0.63993585]))

We resume the training with a larger n. (WARNING: this step takes a couple of orders of magnitude longer than the previous one.)

# Stage 2: warm-start from stage 1 with a smaller step size and n = 100 trajectories.
opt = OptimizationOptimisers.Adam(0.001)
optf2 = Optimization.OptimizationFunction(adtype) do x, _
    loss_neuralsde(x; n = 100)
end
optprob2 = Optimization.OptimizationProblem(optf2, result1.u)
result2 = Optimization.solve(optprob2, opt; callback = callback, maxiters = 100)
retcode: Default
u: ComponentVector{Float32}(drift = (layer_1 = Float32[], layer_2 = (weight = Float32[-1.6782936 1.4112031; -0.3338748 -0.5714686; … ; 0.051938478 -0.4012088; 0.60652053 -0.55832696], bias = Float32[-1.1821795, -0.397618, 0.4887639, 0.86925215, 0.8150587, -0.0017856952, -0.6054083, 0.91810524, 0.20122133, -0.06707374  …  -0.37322676, 0.78796846, -0.31392416, -0.025836881, 0.8139375, 0.9161467, 0.13629347, 0.02250975, 0.69989026, 0.06868548]), layer_3 = (weight = Float32[0.40444708 0.18717717 … -0.07280053 0.43602303; 0.08947639 -0.07023244 … -0.07662381 0.20345025], bias = Float32[-0.388024, -0.24672307])), diffusion = (weight = Float32[-0.6000333 -0.98117167; -0.75443006 -0.503202], bias = Float32[-0.070265725, 0.6706514]))

And now we plot the solution to an ensemble of the trained neural SDE:

# Estimate the trained model's moments from a large ensemble of n = 1000 paths.
n = 1000
u = repeat(reshape(u0, :, 1), 1, n)
# Pass the batched initial states explicitly. With the default `u = u0`, the solve
# would yield a single trajectory, and `dims = 2` would then average over time
# instead of over trajectories.
samples = predict_neuralsde(result2.u, u)
currmeans = mean(samples; dims = 2)
currvars = var(samples; dims = 2, mean = currmeans)[:, 1, :]
currmeans = currmeans[:, 1, :]

# Plot the freshly computed ensemble moments, not the `means`/`vars` globals left
# over from the last loss evaluation during training.
plt2 = Plots.scatter(tsteps, sde_data'; yerror = sde_data_vars', label = "data",
    title = "Neural SDE: After Training", xlabel = "Time")
plot!(plt2, tsteps, currmeans'; lw = 8, ribbon = currvars', label = "prediction")

plt = plot(plt1, plt2; layout = (2, 1))
Example block output

Try this with GPUs as well!