Deep Equilibrium Models

Standard Models

DeepEquilibriumNetworks.DeepEquilibriumNetwork (Type)
DeepEquilibriumNetwork(model, solver; jacobian_regularization::Bool=false,
                       sensealg=DeepEquilibriumAdjoint(0.1f0, 0.1f0, 10), kwargs...)

Deep Equilibrium Network as proposed in baideep2019.
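
Conceptually, the network's output is an equilibrium of the wrapped model, i.e. a fixed point $z^* = f(z^*, x)$. The plain-Julia sketch below illustrates that idea with naive fixed-point iteration on a hypothetical contractive map f. It does not use this package: the package solves for the equilibrium with the supplied solver, not with this loop, and W, U, f and fixed_point are illustrative names only.

```julia
using LinearAlgebra: norm

# Hypothetical stand-in for the wrapped model: f(z, x) = tanh.(W * z .+ U * x).
# tanh is 1-Lipschitz and the spectral norm of W is below 1, so f is a
# contraction in z and a unique fixed point exists for every x.
W = [0.3 -0.1; 0.2 0.25]
U = [1.0 0.5; -0.5 1.0]
f(z, x) = tanh.(W * z .+ U * x)

# Naive fixed-point iteration: repeat z = f(z, x) until the update is below tol.
function fixed_point(f, x; z0 = zero(x), tol = 1e-8, maxiters = 1_000)
    z = z0
    for k in 1:maxiters
        znew = f(z, x)
        norm(znew - z) < tol && return znew, k
        z = znew
    end
    return z, maxiters
end

x = [0.5, -1.0]
zstar, iters = fixed_point(f, x)
println("z* = ", zstar, " after ", iters, " iterations")
println("residual = ", norm(f(zstar, x) - zstar))
```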

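Arguments

  • model: Neural Network. It needs to take a tuple of 2 arrays (as the Lux.Parallel layer in the example does).
  • solver: Solver for the optimization problem (See: ContinuousDEQSolver & DiscreteDEQSolver).
  • jacobian_regularization: Whether to compute a Jacobian regularization term.
  • sensealg: See DeepEquilibriumAdjoint.
  • kwargs: Additional parameters that are directly passed to SciMLBase.solve.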

Example

import DeepEquilibriumNetworks as DEQs
import Lux
import Random
import OrdinaryDiffEq

model = DEQs.DeepEquilibriumNetwork(Lux.Parallel(+, Lux.Dense(2, 2; bias=false),
                                                 Lux.Dense(2, 2; bias=false)),
                                    DEQs.ContinuousDEQSolver(OrdinaryDiffEq.VCABM3();
                                                             abstol=0.01f0, reltol=0.01f0))

rng = Random.default_rng()
ps, st = Lux.setup(rng, model)

model(rand(rng, Float32, 2, 1), ps, st)

See also: SkipDeepEquilibriumNetwork, MultiScaleDeepEquilibriumNetwork, MultiScaleSkipDeepEquilibriumNetwork.

DeepEquilibriumNetworks.SkipDeepEquilibriumNetwork (Type)
SkipDeepEquilibriumNetwork(model, shortcut, solver; jacobian_regularization::Bool=false,
                           sensealg=DeepEquilibriumAdjoint(0.1f0, 0.1f0, 10), kwargs...)

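Skip Deep Equilibrium Network as proposed in pal2022mixing. The shortcut network supplies an explicit prediction that is used as the initial guess for the equilibrium solve.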
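
The skip variant changes where the equilibrium solve starts. The plain-Julia sketch below works under stated assumptions: a hypothetical contractive map f stands in for model, and a hypothetical hand-written linearized guess stands in for a trained shortcut. Neither is this package's API. Both the zero-initialized and the shortcut-initialized iterations converge to the same fixed point; only the number of iterations depends on how close the initial guess is.

```julia
using LinearAlgebra: norm, I

# Hypothetical contractive map standing in for the DEQ `model` (not the package API).
W = [0.3 -0.1; 0.2 0.25]
U = [1.0 0.5; -0.5 1.0]
f(z, x) = tanh.(W * z .+ U * x)

# Hypothetical stand-in for a trained `shortcut`: solve the linearized
# equation (I - W) y = U x, then squash the result with tanh.
shortcut(x) = tanh.((I - W) \ (U * x))

# Fixed-point iteration from a given initial guess z0; returns (z, iterations).
function iterate_from(z0, x; tol = 1e-8, maxiters = 1_000)
    z = z0
    for k in 1:maxiters
        znew = f(z, x)
        norm(znew - z) < tol && return znew, k
        z = znew
    end
    return z, maxiters
end

x = [0.5, -1.0]
z_cold, k_cold = iterate_from(zeros(2), x)     # start from zeros
z_warm, k_warm = iterate_from(shortcut(x), x)  # SkipDEQ-style warm start
println("cold start: ", k_cold, " iterations; warm start: ", k_warm, " iterations")
println("difference between solutions: ", norm(z_cold - z_warm))
```

The printed iteration counts are specific to this toy map and guess.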

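Arguments

  • model: Neural Network. It needs to take a tuple of 2 arrays (as the Lux.Parallel layer in the example does).
  • shortcut: Shortcut for the network (pass nothing for SkipDEQV2).
  • solver: Solver for the optimization problem (See: ContinuousDEQSolver & DiscreteDEQSolver).
  • jacobian_regularization: Whether to compute a Jacobian regularization term.
  • sensealg: See DeepEquilibriumAdjoint.
  • kwargs: Additional parameters that are directly passed to SciMLBase.solve.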

Example

import DeepEquilibriumNetworks as DEQs
import Lux
import Random
import OrdinaryDiffEq

## SkipDEQ
model = DEQs.SkipDeepEquilibriumNetwork(Lux.Parallel(+, Lux.Dense(2, 2; bias=false),
                                                     Lux.Dense(2, 2; bias=false)),
                                        Lux.Dense(2, 2),
                                        DEQs.ContinuousDEQSolver(OrdinaryDiffEq.VCABM3();
                                                                 abstol=0.01f0,
                                                                 reltol=0.01f0))

rng = Random.default_rng()
ps, st = Lux.setup(rng, model)

model(rand(rng, Float32, 2, 1), ps, st)

## SkipDEQV2
model = DEQs.SkipDeepEquilibriumNetwork(Lux.Parallel(+, Lux.Dense(2, 2; bias=false),
                                                     Lux.Dense(2, 2; bias=false)), nothing,
                                        DEQs.ContinuousDEQSolver(OrdinaryDiffEq.VCABM3();
                                                                 abstol=0.01f0,
                                                                 reltol=0.01f0))

rng = Random.default_rng()
ps, st = Lux.setup(rng, model)

model(rand(rng, Float32, 2, 1), ps, st)

See also: DeepEquilibriumNetwork, MultiScaleDeepEquilibriumNetwork, MultiScaleSkipDeepEquilibriumNetwork

MultiScale Models

DeepEquilibriumNetworks.MultiScaleDeepEquilibriumNetwork (Type)
MultiScaleDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix,
                                 post_fuse_layer::Union{Nothing,Tuple}, solver, scales;
                                 sensealg=DeepEquilibriumAdjoint(0.1f0, 0.1f0, 10),
                                 kwargs...)

Multiscale Deep Equilibrium Network as proposed in baimultiscale2020.

Arguments

  • main_layers: Tuple of Neural Networks. The first network needs to take a tuple of 2 arrays; the other ones take a single input.
  • mapping_layers: Matrix of Neural Networks. The $(i, j)^{th}$ network takes the output of the $i^{th}$ main_layer and passes it to the $j^{th}$ main_layer.
  • post_fuse_layer: Tuple of Neural Networks, or nothing. Each scale is passed through its corresponding layer.
  • solver: Solver for the optimization problem (See: ContinuousDEQSolver & DiscreteDEQSolver).
  • scales: Output scales, given as a tuple with the size of each scale excluding the batch dimension (e.g. ((4,), (3,), (2,), (1,)) in the example).
  • sensealg: See DeepEquilibriumAdjoint.
  • kwargs: Additional parameters that are directly passed to SciMLBase.solve.
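
The roles of main_layers and mapping_layers can be illustrated with a single fusion step. The sketch below is a simplified plain-Julia reading of the argument descriptions, with hypothetical fixed-weight maps standing in for the Lux layers. It assumes that scale j receives the sum of mapping_layers[i, j] applied to the main-layer output of every scale i; how the input x is injected and how the package composes these layers internally is not shown and may differ.

```julia
# Per-scale state sizes, matching scales ((4,), (3,), (2,), (1,)) in the example.
sizes = (4, 3, 2, 1)
n = length(sizes)

# Hypothetical stand-ins for main_layers: one map per scale.
# (The first main layer in the package also receives the input x; omitted here.)
main_maps = [z -> tanh.(0.5 .* z) for _ in 1:n]

# Hypothetical stand-ins for mapping_layers: entry [i, j] sends scale i to scale j.
# Diagonal entries act like Lux.NoOpLayer(), i.e. the identity.
mapping_maps = [i == j ? identity : (z -> tanh.(fill(0.1, sizes[j], sizes[i]) * z))
                for i in 1:n, j in 1:n]

# One fusion step (assumed to sum contributions): scale j receives the sum
# over i of mapping_maps[i, j] applied to the main-layer output of scale i.
function fuse(zs)
    ys = [main_maps[i](zs[i]) for i in 1:n]
    return [sum(mapping_maps[i, j](ys[i]) for i in 1:n) for j in 1:n]
end

zs = [zeros(s) for s in sizes]
zs[1] = ones(4)        # hypothetical nonzero state at the first scale
out = fuse(zs)
println(length.(out))  # → [4, 3, 2, 1]
```

Every output keeps the size of its scale, which is why the off-diagonal entries of the example's mapping_layers are Dense layers between the corresponding sizes.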

Example

import DeepEquilibriumNetworks as DEQs
import Lux
import OrdinaryDiffEq
import Random

main_layers = (Lux.Parallel(+, Lux.Dense(4, 4, tanh), Lux.Dense(4, 4, tanh)),
               Lux.Dense(3, 3, tanh), Lux.Dense(2, 2, tanh), Lux.Dense(1, 1, tanh))

mapping_layers = [Lux.NoOpLayer() Lux.Dense(4, 3, tanh) Lux.Dense(4, 2, tanh) Lux.Dense(4, 1, tanh);
                  Lux.Dense(3, 4, tanh) Lux.NoOpLayer() Lux.Dense(3, 2, tanh) Lux.Dense(3, 1, tanh);
                  Lux.Dense(2, 4, tanh) Lux.Dense(2, 3, tanh) Lux.NoOpLayer() Lux.Dense(2, 1, tanh);
                  Lux.Dense(1, 4, tanh) Lux.Dense(1, 3, tanh) Lux.Dense(1, 2, tanh) Lux.NoOpLayer()]

solver = DEQs.ContinuousDEQSolver(OrdinaryDiffEq.VCABM3(); abstol=0.01f0, reltol=0.01f0)

model = DEQs.MultiScaleDeepEquilibriumNetwork(main_layers, mapping_layers, nothing, solver,
                                              ((4,), (3,), (2,), (1,)))

rng = Random.default_rng()
ps, st = Lux.setup(rng, model)
x = rand(rng, Float32, 4, 1)

model(x, ps, st)

See also: DeepEquilibriumNetwork, SkipDeepEquilibriumNetwork, MultiScaleSkipDeepEquilibriumNetwork

DeepEquilibriumNetworks.MultiScaleSkipDeepEquilibriumNetwork (Type)
MultiScaleSkipDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix,
                                     post_fuse_layer::Union{Nothing,Tuple},
                                     shortcut_layers::Union{Nothing,Tuple}, solver,
                                     scales;
                                     sensealg=DeepEquilibriumAdjoint(0.1f0, 0.1f0, 10),
                                     kwargs...)

Multiscale Deep Equilibrium Network as proposed in baimultiscale2020 combined with Skip Deep Equilibrium Network as proposed in pal2022mixing.

Arguments

  • main_layers: Tuple of Neural Networks. The first network needs to take a tuple of 2 arrays; the other ones take a single input.
  • mapping_layers: Matrix of Neural Networks. The $(i, j)^{th}$ network takes the output of the $i^{th}$ main_layer and passes it to the $j^{th}$ main_layer.
  • post_fuse_layer: Tuple of Neural Networks, or nothing. Each scale is passed through its corresponding layer.
  • shortcut_layers: Tuple of shortcut networks, one per scale, each taking the input (pass nothing for the MSkipDEQV2 variant).
  • solver: Solver for the optimization problem (See: ContinuousDEQSolver & DiscreteDEQSolver).
  • scales: Output scales, given as a tuple with the size of each scale excluding the batch dimension (e.g. ((4,), (3,), (2,), (1,)) in the example).
  • sensealg: See DeepEquilibriumAdjoint.
  • kwargs: Additional parameters that are directly passed to SciMLBase.solve.
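
In the multiscale skip variant, the shortcut layers map the raw input to one initial guess per scale, which is why every shortcut in the example takes 4 inputs and produces 4, 3, 2 and 1 outputs. A plain-Julia sketch of that shape contract follows, with hypothetical fixed-weight maps standing in for the Lux.Dense layers; shortcut_maps and initial_guesses are illustrative names, not package API.

```julia
# Per-scale sizes from the example's scales ((4,), (3,), (2,), (1,)).
sizes = (4, 3, 2, 1)

# Hypothetical stand-ins for shortcut_layers: the k-th map takes the raw input
# x (length 4) to a vector of length sizes[k], like Lux.Dense(4, sizes[k], tanh).
shortcut_maps = [x -> tanh.(fill(0.2, s, 4) * x) for s in sizes]

# One initial guess per scale, all computed from the same input x.
initial_guesses(x) = [g(x) for g in shortcut_maps]

z0 = initial_guesses([1.0, -0.5, 0.25, 0.0])
println(length.(z0))  # → [4, 3, 2, 1]
```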

Example

import DeepEquilibriumNetworks as DEQs
import Lux
import OrdinaryDiffEq
import Random

# MSkipDEQ
main_layers = (Lux.Parallel(+, Lux.Dense(4, 4, tanh), Lux.Dense(4, 4, tanh)),
               Lux.Dense(3, 3, tanh), Lux.Dense(2, 2, tanh), Lux.Dense(1, 1, tanh))

mapping_layers = [Lux.NoOpLayer() Lux.Dense(4, 3, tanh) Lux.Dense(4, 2, tanh) Lux.Dense(4, 1, tanh);
                  Lux.Dense(3, 4, tanh) Lux.NoOpLayer() Lux.Dense(3, 2, tanh) Lux.Dense(3, 1, tanh);
                  Lux.Dense(2, 4, tanh) Lux.Dense(2, 3, tanh) Lux.NoOpLayer() Lux.Dense(2, 1, tanh);
                  Lux.Dense(1, 4, tanh) Lux.Dense(1, 3, tanh) Lux.Dense(1, 2, tanh) Lux.NoOpLayer()]

solver = DEQs.ContinuousDEQSolver(OrdinaryDiffEq.VCABM3(); abstol=0.01f0, reltol=0.01f0)

shortcut_layers = (Lux.Dense(4, 4, tanh), Lux.Dense(4, 3, tanh), Lux.Dense(4, 2, tanh),
                   Lux.Dense(4, 1, tanh))

model = DEQs.MultiScaleSkipDeepEquilibriumNetwork(main_layers, mapping_layers, nothing,
                                                  shortcut_layers, solver,
                                                  ((4,), (3,), (2,), (1,)))

rng = Random.default_rng()
ps, st = Lux.setup(rng, model)
x = rand(rng, Float32, 4, 2)

model(x, ps, st)

# MSkipDEQV2
model = DEQs.MultiScaleSkipDeepEquilibriumNetwork(main_layers, mapping_layers, nothing,
                                                  nothing, solver, ((4,), (3,), (2,), (1,)))

rng = Random.default_rng()
ps, st = Lux.setup(rng, model)
x = rand(rng, Float32, 4, 2)

model(x, ps, st)

See also: DeepEquilibriumNetwork, SkipDeepEquilibriumNetwork, MultiScaleDeepEquilibriumNetwork