Neural Differential Equation Layer Functions

The following layers are helper functions for building neural differential equation architectures in the most efficient way currently known. As demonstrated in the tutorials, they are not required, because automatic differentiation works directly over solve. They do, however, cover common use cases and select the mode of AD known to be optimal for each equation type.

DiffEqFlux.NeuralODE (Type)
NeuralODE(model, tspan, alg = nothing, args...; kwargs...)

Constructs a continuous-time recurrent neural network, also known as a neural ordinary differential equation (neural ODE), with a fast gradient calculation via adjoints [1]. At a high level, this corresponds to solving the forward differential equation and then solving a second differential equation that propagates the derivatives of the loss backwards in time.

Arguments:

  • model: A Flux.Chain or Lux.AbstractLuxLayer neural network that defines the time derivative ẋ.
  • tspan: The timespan to be solved on.
  • alg: The algorithm used to solve the ODE. Defaults to nothing, i.e. the default algorithm from DifferentialEquations.jl.
  • sensealg: The choice of differentiation algorithm used in backpropagation. Defaults to an adjoint method. See the Local Sensitivity Analysis documentation for more details.
  • kwargs: Additional arguments splatted to the ODE solver. See the Common Solver Arguments documentation for more details.
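A minimal usage sketch, assuming the Lux-style interface of recent DiffEqFlux releases, in which the layer is initialized with Lux.setup and called as layer(u0, ps, st), returning the ODE solution together with the layer state. The network size, time span, the Tsit5 solver from OrdinaryDiffEq.jl and the saveat points are illustrative choices, not defaults of the layer.

```julia
using DiffEqFlux, Lux, OrdinaryDiffEq, Random

rng = Random.default_rng()

# Neural network defining the time derivative of a 2-dimensional state
model = Chain(Dense(2 => 16, tanh), Dense(16 => 2))

# Time span, solver and saveat points are illustrative assumptions
tspan = (0.0f0, 1.0f0)
node = NeuralODE(model, tspan, Tsit5(); saveat = [0.0f0, 0.5f0, 1.0f0])

# Initialize the parameters and states of the wrapped network
ps, st = Lux.setup(rng, node)

u0 = Float32[2.0, 0.0]
sol, st = node(u0, ps, st)   # forward solve: returns the solution and the layer state
traj = Array(sol)            # one column per saved time point
```

Because the extra keyword arguments are forwarded to the solver, saveat here controls which time points appear in the returned solution.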

References:

[1] Pontryagin, Lev Semenovich. Mathematical theory of optimal processes. CRC press, 1987.

DiffEqFlux.NeuralDSDE (Type)
NeuralDSDE(drift, diffusion, tspan, alg = nothing, args...; sensealg = TrackerAdjoint(),
    kwargs...)

Constructs a neural stochastic differential equation (neural SDE) with diagonal noise.

Arguments:

  • drift: A Flux.Chain or Lux.AbstractLuxLayer neural network that defines the drift function.
  • diffusion: A Flux.Chain or Lux.AbstractLuxLayer neural network that defines the diffusion function. Should output a vector of the same size as the input.
  • tspan: The timespan to be solved on.
  • alg: The algorithm used to solve the SDE. Defaults to nothing, i.e. the default algorithm from DifferentialEquations.jl.
  • sensealg: The choice of differentiation algorithm used in backpropagation. Defaults to TrackerAdjoint(), i.e. reverse-mode automatic differentiation via Tracker.jl.
  • kwargs: Additional arguments splatted to the SDE solver. See the Common Solver Arguments documentation for more details.
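A minimal sketch of a diagonal-noise neural SDE, assuming the same Lux-style call convention as NeuralODE (Lux.setup, then layer(u0, ps, st)). The network sizes, time span, saveat points and the choice of SOSRI from StochasticDiffEq.jl (an adaptive solver for diagonal noise) are illustrative assumptions. Each call draws a new random sample path.

```julia
using DiffEqFlux, Lux, StochasticDiffEq, Random

rng = Random.default_rng()

# Drift network for a 2-dimensional state
drift = Chain(Dense(2 => 16, tanh), Dense(16 => 2))
# With diagonal noise, the diffusion network returns a vector shaped like the state
diffusion = Chain(Dense(2 => 2))

tspan = (0.0f0, 1.0f0)
# Solver and saveat points are illustrative choices
sde = NeuralDSDE(drift, diffusion, tspan, SOSRI(); saveat = [0.0f0, 0.5f0, 1.0f0])

# Parameters contain one entry for the drift and one for the diffusion network
ps, st = Lux.setup(rng, sde)

u0 = Float32[1.0, 1.0]
sol, st = sde(u0, ps, st)   # one random sample path
traj = Array(sol)
```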
DiffEqFlux.NeuralSDE (Type)
NeuralSDE(drift, diffusion, tspan, nbrown, alg = nothing, args...;
    sensealg=TrackerAdjoint(), kwargs...)

Constructs a neural stochastic differential equation (neural SDE).

Arguments:

  • drift: A Flux.Chain or Lux.AbstractLuxLayer neural network that defines the drift function.
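  • diffusion: A Flux.Chain or Lux.AbstractLuxLayer neural network that defines the diffusion function. Should output a matrix of size size(x, 1) x nbrown, so that multiplying it by the nbrown-dimensional Brownian increment yields an increment shaped like x.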
  • tspan: The timespan to be solved on.
  • nbrown: The number of Brownian processes.
  • alg: The algorithm used to solve the SDE. Defaults to nothing, i.e. the default algorithm from DifferentialEquations.jl.
  • sensealg: The choice of differentiation algorithm used in backpropagation. Defaults to TrackerAdjoint(), i.e. reverse-mode automatic differentiation via Tracker.jl.
  • kwargs: Additional arguments splatted to the SDE solver. See the Common Solver Arguments documentation for more details.
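A minimal sketch with non-diagonal noise, assuming the Lux-style call convention used by the other layers. The diffusion network produces a flat vector that a wrapped function reshapes into a size(x, 1) x nbrown matrix. The value nbrown = 3, the network sizes, the saveat points and the LambaEM solver from StochasticDiffEq.jl (an adaptive method that supports non-diagonal noise) are illustrative assumptions.

```julia
using DiffEqFlux, Lux, StochasticDiffEq, Random

rng = Random.default_rng()
nbrown = 3   # number of independent Brownian processes (illustrative)

# Drift network for a 2-dimensional state
drift = Chain(Dense(2 => 16, tanh), Dense(16 => 2))

# Diffusion network: 2 * nbrown outputs reshaped into a 2 x nbrown noise matrix
diffusion = Chain(Dense(2 => 2 * nbrown), x -> reshape(x, 2, nbrown))

tspan = (0.0f0, 1.0f0)
# LambaEM handles general (non-diagonal) noise; solver and saveat are illustrative
sde = NeuralSDE(drift, diffusion, tspan, nbrown, LambaEM();
    saveat = [0.0f0, 0.5f0, 1.0f0])

ps, st = Lux.setup(rng, sde)

u0 = Float32[1.0, 1.0]
sol, st = sde(u0, ps, st)   # one random sample path driven by nbrown noise processes
traj = Array(sol)
```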
DiffEqFlux.NeuralCDDE (Type)
NeuralCDDE(model, tspan, hist, lags, alg = nothing, args...;
    sensealg = TrackerAdjoint(), kwargs...)

Constructs a neural delay differential equation (neural DDE) with constant delays.

Arguments:

  • model: A Flux.Chain or Lux.AbstractLuxLayer neural network that defines the derivative function. Should take an input of size [x; x(t - lag_1); ...; x(t - lag_n)] and produce an output shaped like x.
  • tspan: The timespan to be solved on.
  • hist: Defines the history function h(u, p, t) for values before the start of the integration. The u argument lets the history function return a value that matches the size of u.
  • lags: The constant delays lag_1, ..., lag_n at which lagged states are fed to the neural network.
  • alg: The algorithm used to solve the DDE. Defaults to nothing, i.e. the default algorithm from DifferentialEquations.jl.
  • sensealg: The choice of differentiation algorithm used in backpropagation. Defaults to using reverse-mode automatic differentiation via Tracker.jl.
  • kwargs: Additional arguments splatted to the DDE solver. See the Common Solver Arguments documentation for more details.
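The input-size requirement on model can be made concrete with a small, package-free sketch. Here state(t) is a hypothetical stand-in for the trajectory x(t) (it is neither the layer's history function nor part of the DiffEqFlux API), and the lags are arbitrary. The network input stacks the current state with one copy of the state per lag, so for a 2-dimensional state and two lags the first layer of model must accept 2 * (1 + 2) = 6 inputs.

```julia
# Hypothetical trajectory standing in for x(t); not part of DiffEqFlux
state(t) = [sin(t), cos(t)]

lags = (0.1, 0.2)   # illustrative constant lags
t = 1.0

# Stack [x(t); x(t - lag_1); x(t - lag_2)] as required for the network input
nn_input = vcat(state(t), (state(t - lag) for lag in lags)...)

length(nn_input)   # 6 = length(state(t)) * (1 + length(lags))
```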
DiffEqFlux.NeuralDAE (Type)
NeuralDAE(model, constraints_model, tspan, args...; differential_vars = nothing,
    sensealg = TrackerAdjoint(), kwargs...)

Constructs a neural differential-algebraic equation (neural DAE).

Arguments:

  • model: A Flux.Chain or Lux.AbstractLuxLayer neural network that defines the derivative function. Should take an input shaped like x and produce the residual of f(dx, x, t) for the differential variables only.
  • constraints_model: A function constraints_model(u,p,t) for the fixed constraints to impose on the algebraic equations.
  • tspan: The timespan to be solved on.
  • alg: The algorithm used to solve the DAE. Defaults to nothing, i.e. the default algorithm from DifferentialEquations.jl.
  • sensealg: The choice of differentiation algorithm used in backpropagation. Defaults to using reverse-mode automatic differentiation via Tracker.jl.
  • kwargs: Additional arguments splatted to the DAE solver. See the Common Solver Arguments documentation for more details.
DiffEqFlux.NeuralODEMM (Type)
NeuralODEMM(model, constraints_model, tspan, mass_matrix, alg = nothing, args...;
    sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP()), kwargs...)

Constructs a physically constrained continuous-time recurrent neural network, also known as a neural differential-algebraic equation (neural DAE), with a mass matrix and a fast gradient calculation via adjoints [1]. The mass matrix formulation is:

\[Mu' = f(u,p,t)\]

where the system is in semi-explicit form: M is singular, with zero rows corresponding to the constraint equations.
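Written out for the common case in which the differential variables u_d are ordered before the algebraic variables u_a (an ordering assumed here for illustration), the semi-explicit form splits into a differential part f_d, given by model, and an algebraic part g, given by constraints_model:

\[
\begin{pmatrix} I & 0 \\ 0 & 0 \end{pmatrix}
\begin{pmatrix} u_d' \\ u_a' \end{pmatrix}
=
\begin{pmatrix} f_d(u,p,t) \\ g(u,p,t) \end{pmatrix},
\qquad\text{i.e.}\qquad
u_d' = f_d(u,p,t), \quad 0 = g(u,p,t).
\]

The zero rows of M are exactly the rows in which no derivative appears, so those equations act as the constraints 0 = g(u,p,t).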

Arguments:

  • model: A Flux.Chain or Lux.AbstractLuxLayer neural network that defines f(u,p,t) for the differential equations.
  • constraints_model: A function constraints_model(u,p,t) for the fixed constraints to impose on the algebraic equations.
  • tspan: The timespan to be solved on.
  • mass_matrix: The mass matrix associated with the DAE.
  • alg: The algorithm used to solve the ODE. Defaults to nothing, i.e. the default algorithm from DifferentialEquations.jl. This method requires an implicit ODE solver compatible with singular mass matrices. Consult the DAE solvers documentation for more details.
  • sensealg: The choice of differentiation algorithm used in backpropagation. Defaults to an adjoint method. See the Local Sensitivity Analysis documentation for more details.
  • kwargs: Additional arguments splatted to the ODE solver. See the Common Solver Arguments documentation for more details.
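A minimal sketch, assuming the Lux-style call convention of the other layers and a model whose output (the two differential equations) is stacked above the output of constraints_model (the one algebraic equation). The conservation constraint u1 + u2 + u3 = 1, the network size, the saveat points and the Rodas5 solver from OrdinaryDiffEq.jl (a Rosenbrock method that accepts singular mass matrices) are illustrative choices. The initial condition is chosen to satisfy the constraint.

```julia
using DiffEqFlux, Lux, OrdinaryDiffEq, Random

rng = Random.default_rng()

# Network producing the right-hand side of the two differential equations
model = Chain(Dense(3 => 32, tanh), Dense(32 => 2))

# Algebraic constraint 0 = u1 + u2 + u3 - 1 (illustrative conservation law)
constraints_model(u, p, t) = [u[1] + u[2] + u[3] - 1]

# Singular mass matrix: the zero last row marks the algebraic equation
M = Float32[1 0 0; 0 1 0; 0 0 0]

tspan = (0.0f0, 1.0f0)
ndae = NeuralODEMM(model, constraints_model, tspan, M, Rodas5();
    saveat = [0.0f0, 0.5f0, 1.0f0])

ps, st = Lux.setup(rng, ndae)

u0 = Float32[1.0, 0.0, 0.0]   # consistent initial condition: components sum to 1
sol, st = ndae(u0, ps, st)
traj = Array(sol)
```

Since the constraint is enforced by the solver, the saved states should keep summing to one up to solver tolerance, whatever the (untrained) network outputs.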
DiffEqFlux.AugmentedNDELayer (Function)
AugmentedNDELayer(nde, adim::Int)

Constructs an Augmented Neural Differential Equation Layer [1].

Arguments:

  • nde: Any Neural Differential Equation Layer.
  • adim: The number of extra dimensions by which the initial condition is lifted (augmented) before it is passed to nde. The added dimensions are initialized to zero.
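A minimal sketch, assuming the Lux-style call convention of the other layers. Because the wrapped layer integrates the augmented state, its network must take and return 2 + adim values for a 2-dimensional input. The values of adim, the network size, the Tsit5 solver and the saveat points are illustrative choices.

```julia
using DiffEqFlux, Lux, OrdinaryDiffEq, Random

rng = Random.default_rng()
adim = 1   # number of augmented dimensions (illustrative)

# The wrapped network sees the augmented state: 2 + adim inputs and outputs
model = Chain(Dense(2 + adim => 16, tanh), Dense(16 => 2 + adim))
node = NeuralODE(model, (0.0f0, 1.0f0), Tsit5(); saveat = [0.0f0, 1.0f0])

# Zero-pad the input by adim dimensions, then solve the neural ODE
anode = AugmentedNDELayer(node, adim)
ps, st = Lux.setup(rng, anode)

sol, st = anode(Float32[1.0, -1.0], ps, st)
traj = Array(sol)   # first column is the zero-padded initial condition
```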

References:

[1] Dupont, Emilien, Arnaud Doucet, and Yee Whye Teh. "Augmented neural ODEs." In Proceedings of the 33rd International Conference on Neural Information Processing Systems, pp. 3140-3150. 2019.


Helper Layer Functions

DiffEqFlux.DimMover (Type)
DimMover(from, to)

Constructs a Dimension Mover Layer.

When used as the last layer of a Lux model (AbstractLuxLayer), it swaps the batch index and the time index of the neural DE's output. The result then follows Lux's conventional (data, channel, batch) ordering, with each time point treated as a channel.
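A package-free sketch of the index swap described above, using Base's permutedims on a hypothetical stacked solution array laid out as (state, batch, time). It illustrates only the reordering and is not the DimMover implementation. The array sizes are arbitrary.

```julia
# Hypothetical stacked neural DE output laid out as (state, batch, time)
nstate, nbatch, ntime = 2, 4, 5
out = reshape(collect(1.0:(nstate * nbatch * ntime)), nstate, nbatch, ntime)

# Swap the batch and time indices: (state, time, batch), i.e. Lux's
# (data, channel, batch) order with each time point as a channel
moved = permutedims(out, (1, 3, 2))

size(moved)   # (2, 5, 4)
```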
