DeepEquilibriumNetworks: (Fast) Deep Equilibrium Networks


DeepEquilibriumNetworks.jl is a framework built on top of DifferentialEquations.jl and Lux.jl that enables efficient training and inference of Deep Equilibrium Networks (infinitely deep neural networks).
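Conceptually, a Deep Equilibrium Network replaces a deep stack of layers with the fixed point z* = f(z*, x) of a single layer f. The following is a minimal plain-Julia sketch of that idea, assuming a hypothetical layer f(z, x) = tanh.(W*z .+ U*x) whose weight W is rescaled so that f is a contraction in z; it finds z* by naive fixed-point iteration. This only illustrates the concept and is not how DeepEquilibriumNetworks.jl solves the problem: the package relies on solvers built on DifferentialEquations.jl.

```julia
using LinearAlgebra
using Random

# Naive fixed-point iteration: repeat z <- f(z) until successive iterates agree.
function fixed_point(f, z0; abstol=1f-5, maxiters=1_000)
    z = z0
    for i in 1:maxiters
        z_next = f(z)
        if norm(z_next .- z) <= abstol
            return z_next, i
        end
        z = z_next
    end
    return z, maxiters
end

rng = Random.MersenneTwister(0)
W = randn(rng, Float32, 2, 2)
W .*= 0.5f0 / opnorm(W)   # spectral norm 0.5; since |tanh'| <= 1, f is a contraction
U = randn(rng, Float32, 2, 2)
x = rand(rng, Float32, 2)

# Hypothetical single "layer" of the equilibrium model (x is held fixed).
layer(z) = tanh.(W * z .+ U * x)

z_star, iters = fixed_point(layer, zeros(Float32, 2))
residual = norm(layer(z_star) .- z_star)
println("converged in $iters iterations, residual = $residual")
```

Because the map is a contraction, the iteration converges to a unique fixed point regardless of the starting guess; the infinitely deep network is therefore evaluated by solving an equation rather than by stacking layers.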

Installation

] add DeepEquilibriumNetworks

Quickstart

import DeepEquilibriumNetworks as DEQs
import Lux
import Random
import Zygote

seed = 0
rng = Random.default_rng()
Random.seed!(rng, seed)

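# A Dense layer followed by a DEQ whose implicit layer sums two bias-free Dense layers,
# solved with the continuous (steady-state ODE) solver at loose tolerances
model = Lux.Chain(Lux.Dense(2, 2),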
                  DEQs.DeepEquilibriumNetwork(Lux.Parallel(+, Lux.Dense(2, 2; bias=false),
                                                           Lux.Dense(2, 2; bias=false)),
                                              DEQs.ContinuousDEQSolver(; abstol=0.1f0,
                                                                       reltol=0.1f0,
                                                                       abstol_termination=0.1f0,
                                                                       reltol_termination=0.1f0)))

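# Move parameters, states, and data to the GPU (assumes a working CUDA setup;
# drop the `Lux.gpu` calls to run on the CPU)
ps, st = Lux.gpu.(Lux.setup(rng, model))
x = Lux.gpu(rand(rng, Float32, 2, 1))
y = Lux.gpu(rand(rng, Float32, 2, 1))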

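# Gradient of a squared-error loss w.r.t. the parameters; `model(x, p, st)[1]` is the
# model output, whose first element is the DEQ's fixed-point solution
gs = Zygote.gradient(p -> sum(abs2, model(x, p, st)[1][1] .- y), ps)[1]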
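The `ContinuousDEQSolver` in the quickstart treats the equilibrium as the steady state of an ODE. A minimal plain-Julia sketch of that idea, assuming the formulation dz/dt = f(z, x) - z (whose steady state dz/dt = 0 is exactly a fixed point z = f(z, x)), integrates it with explicit Euler steps and stops once the derivative is small relative to `abstol + reltol * norm(z)`. The layer, step size `dt`, and this termination rule are assumptions chosen to loosely mirror the role of `abstol_termination`/`reltol_termination`; the package's actual integrator and termination criterion come from DifferentialEquations.jl and may differ.

```julia
using LinearAlgebra
using Random

# Integrate dz/dt = f(z) - z with explicit Euler until the derivative is small.
# At steady state dz/dt = 0, i.e. z = f(z): the same fixed point a DEQ seeks.
function steady_state(f, z0; dt=0.1f0, abstol=1f-4, reltol=1f-4, maxsteps=10_000)
    z = copy(z0)
    for step in 1:maxsteps
        dz = f(z) .- z
        # Hypothetical termination rule based on the size of dz/dt
        if norm(dz) <= abstol + reltol * norm(z)
            return z, step
        end
        z .+= dt .* dz
    end
    return z, maxsteps
end

rng = Random.MersenneTwister(0)
W = randn(rng, Float32, 2, 2)
W .*= 0.5f0 / opnorm(W)   # keep f a contraction so the steady state is stable
U = randn(rng, Float32, 2, 2)
x = rand(rng, Float32, 2)

# Hypothetical implicit layer (x is held fixed)
layer(z) = tanh.(W * z .+ U * x)

z_ss, steps = steady_state(layer, zeros(Float32, 2))
println("steady state after $steps Euler steps: ", z_ss)
```

Loose termination tolerances, like the `0.1f0` values in the quickstart, stop the integration earlier and trade accuracy of the equilibrium for speed.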

Citation

If you use this project for research or other academic purposes, please consider citing our paper:

@misc{pal2022mixing,
  title={Mixing Implicit and Explicit Deep Learning with Skip DEQs and Infinite Time Neural
         ODEs (Continuous DEQs)}, 
  author={Avik Pal and Alan Edelman and Christopher Rackauckas},
  year={2022},
  eprint={2201.12240},
  archivePrefix={arXiv},
  primaryClass={cs.LG}
}

For specific algorithms, check the respective documentation and cite the corresponding papers.