Deep Equilibrium Models

(Bai et al., 2019) introduced Discrete Deep Equilibrium Models, which drive a discrete dynamical system to its steady state. (Pal et al., 2022) extended this framework to continuous dynamical systems, which converge to the steady state in a more stable fashion. For a detailed discussion, refer to (Pal et al., 2022).

To construct a continuous DEQ, any ODE solver compatible with the DifferentialEquations.jl API can be passed as the solver. To construct a discrete DEQ, any root finding algorithm compatible with the NonlinearSolve.jl API can be passed as the solver.

Choosing a Solver

Root Finding Algorithms

Root finding algorithms give fast convergence when they succeed, but these methods also tend to be unstable. If you must use a root finding algorithm, we recommend:

  1. NewtonRaphson or TrustRegion for small models
  2. LimitedMemoryBroyden for large Deep Learning applications (with well-conditioned Jacobians)
  3. NewtonRaphson(; linsolve = KrylovJL_GMRES()) for cases when Broyden methods fail

Note that Krylov Methods rely on efficient VJPs which are not available for all Lux models. If you think this is causing a performance regression, please open an issue in Lux.jl.
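As a minimal sketch of passing a root finding solver (mirroring the constructor and call pattern from the DeepEquilibriumNetwork example further below, and assuming DeepEquilibriumNetworks.jl, Lux.jl, and NonlinearSolve.jl are loaded):

```julia
using DeepEquilibriumNetworks, Lux, NonlinearSolve, Random

# Small model: NewtonRaphson is a good fit per the recommendations above.
# Swap in LimitedMemoryBroyden() for larger, well-conditioned models.
model = DeepEquilibriumNetwork(
    Parallel(+, Dense(2 => 2; use_bias=false), Dense(2 => 2; use_bias=false)),
    NewtonRaphson())

rng = Xoshiro(0)
ps, st = Lux.setup(rng, model)
z_star, st_ = model(ones(Float32, 2, 1), ps, st)  # z_star has size (2, 1)
```

The solver is simply the second positional argument; everything else is the usual Lux setup/call pattern.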

ODE Solvers

ODE solvers give slower convergence, but are more stable; we generally recommend them over root finding algorithms. If you use an implicit ODE solver, remember to pair it with a Krylov linear solver (see the OrdinaryDiffEq.jl documentation for these). For most cases, we recommend:

  1. VCAB3() for high tolerance problems
  2. Tsit5() for high tolerance problems where VCAB3() fails
  3. In all other cases, follow the recommendation given in OrdinaryDiffEq.jl documentation
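Switching from a root finder to an ODE solver only changes the solver argument. A hedged sketch, assuming DeepEquilibriumNetworks.jl, Lux.jl, and OrdinaryDiffEq.jl are loaded:

```julia
using DeepEquilibriumNetworks, Lux, OrdinaryDiffEq, Random

# VCAB3() is the first recommendation above; swap in Tsit5() if it fails.
model = DeepEquilibriumNetwork(
    Parallel(+, Dense(2 => 2; use_bias=false), Dense(2 => 2; use_bias=false)),
    VCAB3(); verbose=false)

rng = Xoshiro(0)
ps, st = Lux.setup(rng, model)
z_star, st_ = model(ones(Float32, 2, 1), ps, st)
```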

Sensitivity Analysis

  1. For MultiScaleNeuralODE, we default to GaussAdjoint(; autojacvec = ZygoteVJP()). A faster alternative is BacksolveAdjoint(; autojacvec = ZygoteVJP()), but it comes with stability concerns; follow the recommendations in the SciMLSensitivity.jl documentation.
  2. For Steady State Problems, we default to SteadyStateAdjoint(; linsolve = SimpleGMRES(; blocksize, linsolve_kwargs = (; maxiters=10, abstol=1e-3, reltol=1e-3))). This default performs poorly on small models; for those, it is recommended to pass sensealg = SteadyStateAdjoint() or sensealg = SteadyStateAdjoint(; linsolve = LUFactorization()).
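A sketch of overriding the adjoint for a small model, following the recommendation above. This assumes that sensealg is accepted among the keyword arguments forwarded to SciMLBase.solve (as the DeepEquilibriumNetwork docstring describes for kwargs):

```julia
using DeepEquilibriumNetworks, Lux, NonlinearSolve, SciMLSensitivity

# Small model: replace the default SimpleGMRES-based SteadyStateAdjoint
# with a plain (dense) SteadyStateAdjoint.
model = DeepEquilibriumNetwork(
    Parallel(+, Dense(2 => 2; use_bias=false), Dense(2 => 2; use_bias=false)),
    NewtonRaphson(); sensealg=SteadyStateAdjoint())
```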

Standard Models

DeepEquilibriumNetworks.DeepEquilibriumNetworkType
DeepEquilibriumNetwork(model, solver; init = missing, jacobian_regularization=nothing,
    problem_type::Type=SteadyStateProblem{false}, kwargs...)

Deep Equilibrium Network as proposed in (Bai et al., 2019) and (Pal et al., 2022).

Arguments

  • model: Neural Network.
  • solver: Solver for the rootfinding problem. ODE Solvers and Nonlinear Solvers are both supported.

Keyword Arguments

  • init: Initial condition for the rootfinding problem. If nothing, the initial condition is set to zero(x). If missing, the initial condition is set to WrappedFunction(zero). Otherwise, the initial condition is set to init(x, ps, st).
  • jacobian_regularization: Must be one of nothing, AutoForwardDiff, AutoFiniteDiff or AutoZygote.
  • problem_type: Provides a way to simulate a Vanilla Neural ODE by setting the problem_type to ODEProblem. By default, the problem type is set to SteadyStateProblem.
  • kwargs: Additional Parameters that are directly passed to SciMLBase.solve.

Example

julia> model = DeepEquilibriumNetwork(
           Parallel(+, Dense(2, 2; use_bias=false), Dense(2, 2; use_bias=false)),
           VCABM3(); verbose=false);

julia> rng = Xoshiro(0);

julia> ps, st = Lux.setup(rng, model);

julia> size(first(model(ones(Float32, 2, 1), ps, st)))
(2, 1)

See also: SkipDeepEquilibriumNetwork, MultiScaleDeepEquilibriumNetwork, MultiScaleSkipDeepEquilibriumNetwork.

source

MultiScale Models

DeepEquilibriumNetworks.MultiScaleDeepEquilibriumNetworkFunction
MultiScaleDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix,
    post_fuse_layer::Union{Nothing, Tuple}, solver,
    scales::NTuple{N, NTuple{L, Int64}}; kwargs...)

Multi Scale Deep Equilibrium Network as proposed in (Bai et al., 2020).

Arguments

  • main_layers: Tuple of Neural Networks. Each Neural Network is applied to the corresponding scale.
  • mapping_layers: Matrix of Neural Networks. Each Neural Network is applied to the corresponding scale and the corresponding layer.
  • post_fuse_layer: Neural Network applied to the fused output of the main layers.
  • solver: Solver for the rootfinding problem. ODE Solvers and Nonlinear Solvers are both supported.
  • scales: Scales of the Multi Scale DEQ. Each scale is a tuple of integers. The length of the tuple is the number of layers in the corresponding main layer.

For keyword arguments, see DeepEquilibriumNetwork.

Example

julia> main_layers = (
           Parallel(+, Dense(4 => 4, tanh; use_bias=false), Dense(4 => 4, tanh; use_bias=false)),
           Dense(3 => 3, tanh), Dense(2 => 2, tanh), Dense(1 => 1, tanh));

julia> mapping_layers = [NoOpLayer() Dense(4 => 3, tanh) Dense(4 => 2, tanh) Dense(4 => 1, tanh);
                         Dense(3 => 4, tanh) NoOpLayer() Dense(3 => 2, tanh) Dense(3 => 1, tanh);
                         Dense(2 => 4, tanh) Dense(2 => 3, tanh) NoOpLayer() Dense(2 => 1, tanh);
                         Dense(1 => 4, tanh) Dense(1 => 3, tanh) Dense(1 => 2, tanh) NoOpLayer()];

julia> model = MultiScaleDeepEquilibriumNetwork(
           main_layers, mapping_layers, nothing, NewtonRaphson(), ((4,), (3,), (2,), (1,)));

julia> rng = Xoshiro(0);

julia> ps, st = Lux.setup(rng, model);

julia> x = rand(rng, Float32, 4, 12);

julia> size.(first(model(x, ps, st)))
((4, 12), (3, 12), (2, 12), (1, 12))
source

DeepEquilibriumNetworks.MultiScaleSkipDeepEquilibriumNetworkFunction
MultiScaleSkipDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix,
    post_fuse_layer::Union{Nothing, Tuple}, [init = nothing,] solver,
    scales::NTuple{N, NTuple{L, Int64}}; kwargs...)

Skip Multi Scale Deep Equilibrium Network as proposed in (Pal et al., 2022). This is an alias that creates a MultiScaleDeepEquilibriumNetwork with the init kwarg set to the passed value.

If init is not passed, it creates a MultiScale Regularized Deep Equilibrium Network.
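A hedged sketch built on the MultiScaleDeepEquilibriumNetwork example above (reusing its main_layers and mapping_layers). Here the positional init tuple maps the 4-feature input to an initial guess at each scale; the layer sizes are illustrative assumptions:

```julia
# One init layer per scale: input has 4 features, scales are (4,), (3,), (2,), (1,).
init_layers = (Dense(4 => 4, tanh), Dense(4 => 3, tanh),
               Dense(4 => 2, tanh), Dense(4 => 1, tanh))

model = MultiScaleSkipDeepEquilibriumNetwork(
    main_layers, mapping_layers, nothing, init_layers, NewtonRaphson(),
    ((4,), (3,), (2,), (1,)))
```

Omitting the init_layers argument instead yields the MultiScale Regularized variant.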

source

Solution

DeepEquilibriumNetworks.DeepEquilibriumSolutionType
DeepEquilibriumSolution(z_star, u0, residual, jacobian_loss, nfe, original)

Stores the solution of a DeepEquilibriumNetwork and its variants.

Fields

  • z_star: Steady state, or the value reached when maxiters is hit
  • u0: Initial Condition
  • residual: Difference between $z^*$ and $f(z^*, x)$
  • jacobian_loss: Jacobian Stabilization Loss (see individual networks to see how it can be computed)
  • nfe: Number of Function Evaluations
  • original: Original Internal Solution

source