Training a Simple MNIST Classifier using Deep Equilibrium Models
We will train a simple Deep Equilibrium Model on MNIST. First we load a few packages.
using DeepEquilibriumNetworks, SciMLSensitivity, Lux, NonlinearSolve, OrdinaryDiffEq,
    Random, Optimisers, Zygote, LinearSolve, Dates, Printf, Setfield, OneHotArrays
using MLDatasets: MNIST
using MLUtils: DataLoader, splitobs
using LuxCUDA # For NVIDIA GPU support
CUDA.allowscalar(false)
ENV["DATADEPS_ALWAYS_ACCEPT"] = true
true
Set up the device functions from Lux. See GPU Management for more details.
const cdev = cpu_device()
const gdev = gpu_device()
(::MLDataDevices.CUDADevice{Nothing}) (generic function with 5 methods)
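These device functions move arrays and parameter structures between the host and the accelerator, falling back to the CPU when no GPU is available. A minimal, self-contained sketch of how they are used; the variable names below are illustrative and not part of the training script:
using Lux, LuxCUDA

demo_cdev = cpu_device()
demo_gdev = gpu_device()        # returns a CPU device if no functional GPU is found

x_host = rand(Float32, 4, 4)
x_dev = x_host |> demo_gdev     # CuArray on an NVIDIA GPU, plain Array otherwise
x_back = x_dev |> demo_cdev     # always back to a CPU Array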
We can now construct our dataloader. We use only a small subset of the data for demonstration purposes.
function loadmnist(batchsize, train_split)
    N = 2500
    dataset = MNIST(; split=:train)
    imgs = dataset.features[:, :, 1:N]
    labels_raw = dataset.targets[1:N]

    # Process images into (H, W, C, BS) batches
    x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))
    y_data = onehotbatch(labels_raw, 0:9)
    (x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at=train_split)

    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize, shuffle=true),
        # Don't shuffle the test data
        DataLoader(collect.((x_test, y_test)); batchsize, shuffle=false))
end
loadmnist (generic function with 1 method)
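As a quick sanity check (a hypothetical snippet, not part of the training script), we can peek at the shape of the first batch; this assumes the loadmnist function above is already defined and the dataset has been downloaded:
train_loader, test_loader = loadmnist(32, 0.8)
x1, y1 = first(train_loader)
size(x1)  # (28, 28, 1, 32): H × W × C × batch
size(y1)  # (10, 32): one-hot encoded labels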
Construct the Lux Neural Network containing a DEQ layer.
function construct_model(solver; model_type::Symbol=:deq)
    down = Chain(Conv((3, 3), 1 => 64, gelu; stride=1), GroupNorm(64, 64),
        Conv((4, 4), 64 => 64; stride=2, pad=1))

    # The fixed-point model of the DEQ: Parallel(+, ...) sums the transformed state
    # and the injected (down-sampled) input
    deq_model = Chain(
        Parallel(+,
            Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad(),
                init_weight=truncated_normal(; std=0.01), use_bias=false),
            Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad(),
                init_weight=truncated_normal(; std=0.01), use_bias=false)),
        Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad(),
            init_weight=truncated_normal(; std=0.01), use_bias=false))

    if model_type === :skipdeq
        init = Conv((3, 3), 64 => 64, gelu; stride=1, pad=SamePad())
    elseif model_type === :regdeq
        init = nothing
    else
        init = missing
    end

    deq = DeepEquilibriumNetwork(deq_model, solver; init, verbose=false,
        linsolve_kwargs=(; maxiters=10), maxiters=10)

    classifier = Chain(
        GroupNorm(64, 64, relu), GlobalMeanPool(), FlattenLayer(), Dense(64, 10))

    model = Chain(; down, deq, classifier)

    # For NVIDIA GPUs this directly generates the parameters on the GPU
    rng = Random.default_rng() |> gdev
    ps, st = Lux.setup(rng, model)

    # Warmup the forward and backward passes
    x = randn(rng, Float32, 28, 28, 1, 2)
    y = onehotbatch(rand(Random.default_rng(), 0:9, 2), 0:9) |> gdev

    @printf "[%s] warming up forward pass\n" string(now())
    loss_function(model, ps, st, (x, y))
    @printf "[%s] warming up backward pass\n" string(now())
    Zygote.gradient(first ∘ loss_function, model, ps, st, (x, y))
    @printf "[%s] warmup complete\n" string(now())

    return model, ps, st
end
construct_model (generic function with 1 method)
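Conceptually, the DEQ layer solves for a fixed point z* satisfying z* = deq_model((z*, x)), where the down-sampled image x is injected at every implicit "depth". A rough standalone sketch of the same idea with small dense layers; the names tiny_model, tiny_deq, etc. are illustrative and not part of the tutorial, and it assumes the default initial guess of DeepEquilibriumNetwork:
using DeepEquilibriumNetworks, Lux, NonlinearSolve, Random

tiny_model = Parallel(+,
    Dense(4 => 4, tanh; use_bias=false),   # acts on the current state z
    Dense(4 => 4, tanh; use_bias=false))   # acts on the injected input x
tiny_deq = DeepEquilibriumNetwork(tiny_model, NewtonRaphson(); verbose=false)

rng = Random.default_rng()
ps, st = Lux.setup(rng, tiny_deq)
x = randn(rng, Float32, 4, 3)
z_star, st = tiny_deq(x, ps, st)   # z_star approximately satisfies z = tanh.(W₁ z) .+ tanh.(W₂ x)
size(z_star)  # (4, 3)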
Define some helper functions to train the model.
const logit_cross_entropy = CrossEntropyLoss(; logits=Val(true))
const mse_loss = MSELoss()
function loss_function(model, ps, st, (x, y))
    ŷ, st = model(x, ps, st)
    l1 = logit_cross_entropy(ŷ, y)
    l2 = mse_loss(st.deq.solution.z_star, st.deq.solution.u0) # Add in some regularization
    return l1 + eltype(l2)(0.01) * l2, st, (;)
end

function accuracy(model, ps, st, dataloader)
    total_correct, total = 0, 0
    st = Lux.testmode(st)
    for (x, y) in dataloader
        target_class = onecold(y)
        predicted_class = onecold(first(model(x, ps, st)))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end

function train_model(solver, model_type)
    model, ps, st = construct_model(solver; model_type)
    train_dataloader, test_dataloader = loadmnist(32, 0.8) |> gdev

    tstate = Training.TrainState(model, ps, st, Adam(0.0005))

    @printf "[%s] Training Model: %s with Solver: %s\n" string(now()) model_type nameof(typeof(solver))

    @printf "[%s] Pretrain with unrolling to a depth of 5\n" string(now())
    @set! tstate.states = Lux.update_state(tstate.states, :fixed_depth, Val(5))

    for _ in 1:2, (i, (x, y)) in enumerate(train_dataloader)
        _, loss, _, tstate = Training.single_train_step!(
            AutoZygote(), loss_function, (x, y), tstate)
        if i % 10 == 1
            @printf "[%s] Pretraining Batch: [%4d/%4d] Loss: %.5f\n" string(now()) i length(train_dataloader) loss
        end
    end

    acc = accuracy(model, tstate.parameters, tstate.states, test_dataloader) * 100
    @printf "[%s] Pretraining complete. Accuracy: %.5f%%\n" string(now()) acc

    @set! tstate.states = Lux.update_state(tstate.states, :fixed_depth, Val(0))

    for epoch in 1:3
        for (i, (x, y)) in enumerate(train_dataloader)
            _, loss, _, tstate = Training.single_train_step!(
                AutoZygote(), loss_function, (x, y), tstate)
            if i % 10 == 1
                @printf "[%s] Epoch: [%d/%d] Batch: [%4d/%4d] Loss: %.5f\n" string(now()) epoch 3 i length(train_dataloader) loss
            end
        end

        acc = accuracy(model, tstate.parameters, tstate.states, test_dataloader) * 100
        @printf "[%s] Epoch: [%d/%d] Accuracy: %.5f%%\n" string(now()) epoch 3 acc
    end

    @printf "[%s] Training complete.\n" string(now())

    return model, ps, tstate.states
end
train_model (generic function with 1 method)
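Two details of train_model are worth calling out. First, the pretraining phase sets :fixed_depth in the model state, which unrolls the DEQ for a fixed number of iterations instead of calling the nonlinear solver; setting it back to Val(0) restores fixed-point solving. Second, after every forward pass the DEQ stores its solution in the state, which is what loss_function reads via st.deq.solution.z_star and st.deq.solution.u0. A hypothetical inspection snippet, reusing the tiny_deq sketch from above (not part of the training script):
# Unroll the DEQ for 5 steps (as done during pretraining), then switch back to solving
st_unrolled = Lux.update_state(st, :fixed_depth, Val(5))
st_solving = Lux.update_state(st_unrolled, :fixed_depth, Val(0))

# Run a forward pass and look at the stored solution
_, st_new = tiny_deq(x, ps, st_solving)
st_new.solution.z_star  # the converged fixed point
st_new.solution.u0      # the initial guess handed to the solver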
Now we can train our model. First, we will train a discrete DEQ, which effectively means passing in a root-finding algorithm. Most frameworks ship with only simple fixed-point solvers such as Broyden's method, but here we can plug in any of the solvers from NonlinearSolve.jl. We will use a Newton-Krylov method:
train_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()), :regdeq);
[2024-09-10T14:58:50.387] warming up forward pass
[2024-09-10T14:59:27.548] warming up backward pass
[2024-09-10T15:00:32.510] warmup complete
[2024-09-10T15:00:35.290] Training Model: regdeq with Solver: GeneralizedFirstOrderAlgorithm
[2024-09-10T15:00:35.529] Pretrain with unrolling to a depth of 5
[2024-09-10T15:01:18.411] Pretraining Batch: [ 1/ 63] Loss: 2.29847
[2024-09-10T15:01:19.315] Pretraining Batch: [ 11/ 63] Loss: 2.19594
[2024-09-10T15:01:19.786] Pretraining Batch: [ 21/ 63] Loss: 2.30583
[2024-09-10T15:01:19.964] Pretraining Batch: [ 31/ 63] Loss: 2.21255
[2024-09-10T15:01:20.140] Pretraining Batch: [ 41/ 63] Loss: 2.17548
[2024-09-10T15:01:20.290] Pretraining Batch: [ 51/ 63] Loss: 2.19290
[2024-09-10T15:01:20.457] Pretraining Batch: [ 61/ 63] Loss: 2.14694
[2024-09-10T15:01:20.606] Pretraining Batch: [ 1/ 63] Loss: 2.15881
[2024-09-10T15:01:20.780] Pretraining Batch: [ 11/ 63] Loss: 2.18139
[2024-09-10T15:01:20.930] Pretraining Batch: [ 21/ 63] Loss: 2.14976
[2024-09-10T15:01:21.088] Pretraining Batch: [ 31/ 63] Loss: 2.09837
[2024-09-10T15:01:21.239] Pretraining Batch: [ 41/ 63] Loss: 2.19074
[2024-09-10T15:01:21.397] Pretraining Batch: [ 51/ 63] Loss: 2.09138
[2024-09-10T15:01:21.565] Pretraining Batch: [ 61/ 63] Loss: 2.03783
[2024-09-10T15:01:31.011] Pretraining complete. Accuracy: 66.60000%
[2024-09-10T15:01:35.177] Epoch: [1/3] Batch: [ 1/ 63] Loss: 1.99373
[2024-09-10T15:01:36.339] Epoch: [1/3] Batch: [ 11/ 63] Loss: 2.02998
[2024-09-10T15:01:36.802] Epoch: [1/3] Batch: [ 21/ 63] Loss: 2.06365
[2024-09-10T15:01:37.342] Epoch: [1/3] Batch: [ 31/ 63] Loss: 2.06225
[2024-09-10T15:01:37.877] Epoch: [1/3] Batch: [ 41/ 63] Loss: 2.01902
[2024-09-10T15:01:38.278] Epoch: [1/3] Batch: [ 51/ 63] Loss: 2.03933
[2024-09-10T15:01:38.816] Epoch: [1/3] Batch: [ 61/ 63] Loss: 2.00332
[2024-09-10T15:01:39.694] Epoch: [1/3] Accuracy: 68.40000%
[2024-09-10T15:01:40.011] Epoch: [2/3] Batch: [ 1/ 63] Loss: 2.03461
[2024-09-10T15:01:41.023] Epoch: [2/3] Batch: [ 11/ 63] Loss: 2.05679
[2024-09-10T15:01:41.501] Epoch: [2/3] Batch: [ 21/ 63] Loss: 2.01941
[2024-09-10T15:01:42.114] Epoch: [2/3] Batch: [ 31/ 63] Loss: 2.00956
[2024-09-10T15:01:42.773] Epoch: [2/3] Batch: [ 41/ 63] Loss: 1.97900
[2024-09-10T15:01:43.446] Epoch: [2/3] Batch: [ 51/ 63] Loss: 1.98009
[2024-09-10T15:01:44.184] Epoch: [2/3] Batch: [ 61/ 63] Loss: 2.00311
[2024-09-10T15:01:44.975] Epoch: [2/3] Accuracy: 59.60000%
[2024-09-10T15:01:45.045] Epoch: [3/3] Batch: [ 1/ 63] Loss: 2.03107
[2024-09-10T15:01:46.087] Epoch: [3/3] Batch: [ 11/ 63] Loss: 1.93383
[2024-09-10T15:01:46.743] Epoch: [3/3] Batch: [ 21/ 63] Loss: 1.98916
[2024-09-10T15:01:47.392] Epoch: [3/3] Batch: [ 31/ 63] Loss: 1.97544
[2024-09-10T15:01:48.187] Epoch: [3/3] Batch: [ 41/ 63] Loss: 1.90061
[2024-09-10T15:01:48.829] Epoch: [3/3] Batch: [ 51/ 63] Loss: 1.89644
[2024-09-10T15:01:49.525] Epoch: [3/3] Batch: [ 61/ 63] Loss: 1.92080
[2024-09-10T15:01:50.557] Epoch: [3/3] Accuracy: 76.20000%
[2024-09-10T15:01:50.557] Training complete.
We can also train a continuous DEQ by passing in an ODE solver. Here we will use VCAB3(), which tends to be quite fast for continuous neural network problems.
train_model(VCAB3(), :deq);
[2024-09-10T15:01:50.987] warming up forward pass
[2024-09-10T15:02:17.293] warming up backward pass
[2024-09-10T15:02:33.292] warmup complete
[2024-09-10T15:02:34.807] Training Model: deq with Solver: VCAB3
[2024-09-10T15:02:34.807] Pretrain with unrolling to a depth of 5
[2024-09-10T15:03:12.338] Pretraining Batch: [ 1/ 63] Loss: 2.49227
[2024-09-10T15:03:12.883] Pretraining Batch: [ 11/ 63] Loss: 2.31392
[2024-09-10T15:03:13.278] Pretraining Batch: [ 21/ 63] Loss: 2.25770
[2024-09-10T15:03:13.462] Pretraining Batch: [ 31/ 63] Loss: 2.22166
[2024-09-10T15:03:13.810] Pretraining Batch: [ 41/ 63] Loss: 2.23305
[2024-09-10T15:03:14.039] Pretraining Batch: [ 51/ 63] Loss: 2.21107
[2024-09-10T15:03:15.071] Pretraining Batch: [ 61/ 63] Loss: 2.20224
[2024-09-10T15:03:15.112] Pretraining Batch: [ 1/ 63] Loss: 2.18605
[2024-09-10T15:03:15.257] Pretraining Batch: [ 11/ 63] Loss: 2.18501
[2024-09-10T15:03:15.396] Pretraining Batch: [ 21/ 63] Loss: 2.20914
[2024-09-10T15:03:16.276] Pretraining Batch: [ 31/ 63] Loss: 2.14085
[2024-09-10T15:03:16.511] Pretraining Batch: [ 41/ 63] Loss: 2.20098
[2024-09-10T15:03:16.676] Pretraining Batch: [ 51/ 63] Loss: 2.15264
[2024-09-10T15:03:16.857] Pretraining Batch: [ 61/ 63] Loss: 2.13023
[2024-09-10T15:03:18.290] Pretraining complete. Accuracy: 64.40000%
[2024-09-10T15:03:20.844] Epoch: [1/3] Batch: [ 1/ 63] Loss: 2.09923
[2024-09-10T15:03:23.869] Epoch: [1/3] Batch: [ 11/ 63] Loss: 2.13908
[2024-09-10T15:03:25.071] Epoch: [1/3] Batch: [ 21/ 63] Loss: 2.07400
[2024-09-10T15:03:26.421] Epoch: [1/3] Batch: [ 31/ 63] Loss: 2.12320
[2024-09-10T15:03:29.206] Epoch: [1/3] Batch: [ 41/ 63] Loss: 2.05816
[2024-09-10T15:03:30.425] Epoch: [1/3] Batch: [ 51/ 63] Loss: 2.06710
[2024-09-10T15:03:32.666] Epoch: [1/3] Batch: [ 61/ 63] Loss: 2.09572
[2024-09-10T15:03:34.783] Epoch: [1/3] Accuracy: 67.60000%
[2024-09-10T15:03:34.822] Epoch: [2/3] Batch: [ 1/ 63] Loss: 2.10152
[2024-09-10T15:03:37.124] Epoch: [2/3] Batch: [ 11/ 63] Loss: 2.07261
[2024-09-10T15:03:38.563] Epoch: [2/3] Batch: [ 21/ 63] Loss: 2.09647
[2024-09-10T15:03:40.756] Epoch: [2/3] Batch: [ 31/ 63] Loss: 2.05842
[2024-09-10T15:03:42.405] Epoch: [2/3] Batch: [ 41/ 63] Loss: 2.11158
[2024-09-10T15:03:45.101] Epoch: [2/3] Batch: [ 51/ 63] Loss: 2.08569
[2024-09-10T15:03:46.649] Epoch: [2/3] Batch: [ 61/ 63] Loss: 2.04567
[2024-09-10T15:03:49.122] Epoch: [2/3] Accuracy: 66.80000%
[2024-09-10T15:03:49.176] Epoch: [3/3] Batch: [ 1/ 63] Loss: 2.13084
[2024-09-10T15:03:49.754] Epoch: [3/3] Batch: [ 11/ 63] Loss: 2.06870
[2024-09-10T15:03:52.001] Epoch: [3/3] Batch: [ 21/ 63] Loss: 2.01065
[2024-09-10T15:03:52.675] Epoch: [3/3] Batch: [ 31/ 63] Loss: 2.03794
[2024-09-10T15:03:55.171] Epoch: [3/3] Batch: [ 41/ 63] Loss: 2.03897
[2024-09-10T15:03:56.712] Epoch: [3/3] Batch: [ 51/ 63] Loss: 2.10170
[2024-09-10T15:03:57.988] Epoch: [3/3] Batch: [ 61/ 63] Loss: 2.03757
[2024-09-10T15:04:00.434] Epoch: [3/3] Accuracy: 70.80000%
[2024-09-10T15:04:00.434] Training complete.
This code is set up to allow experimenting with different DEQ models. Try modifying the model_type argument of train_model to :skipdeq or :deq to see how the model behaves. You can also try different solvers from NonlinearSolve.jl and OrdinaryDiffEq.jl! Even third-party solvers from Sundials.jl will work; just remember to use the CPU for those.
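For instance, the following calls are plausible variations to try (illustrative only, not run here; they assume the same session as above):
# A limited-memory Broyden solver from NonlinearSolve.jl with a Skip-DEQ
train_model(LimitedMemoryBroyden(), :skipdeq)

# A different explicit ODE solver from OrdinaryDiffEq.jl with a continuous DEQ
train_model(Tsit5(), :deq)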