Training a Simple MNIST Classifier using Deep Equilibrium Models

We will train a simple Deep Equilibrium Model on MNIST. First we load a few packages.

using DeepEquilibriumNetworks, SciMLSensitivity, Lux, NonlinearSolve, OrdinaryDiffEq,
      Random, Optimisers, Zygote, LinearSolve, Dates, Printf, Setfield, OneHotArrays
using MLDatasets: MNIST
using MLUtils: DataLoader, splitobs
using LuxCUDA # For NVIDIA GPU support

# Make scalar indexing of CUDA arrays an error, so accidental element-by-element
# access to GPU data is caught instead of silently running slowly.
CUDA.allowscalar(false)
# Presumably lets MLDatasets/DataDeps download MNIST without an interactive prompt.
ENV["DATADEPS_ALWAYS_ACCEPT"] = true
true

Set up the device functions from Lux. See the GPU Management section of the Lux documentation for more details.

# Device-transfer functions: `cdev(x)` moves `x` to the CPU, `gdev(x)` to the GPU
# selected by Lux (a CUDA device in this run).
const cdev = cpu_device()
const gdev = gpu_device()
(::MLDataDevices.CUDADevice{Nothing}) (generic function with 5 methods)

We can now construct our data loaders. For demonstration purposes we use only a limited part of the data.

"""
    loadmnist(batchsize, train_split; n_samples=2500)

Load the first `n_samples` images of the MNIST training split and return a
`(train_loader, test_loader)` pair of `DataLoader`s.

# Arguments
- `batchsize`: number of observations per minibatch in both loaders.
- `train_split`: fraction of the `n_samples` observations assigned to the training
  loader (passed as `at` to `splitobs`); the rest go to the test loader.

# Keywords
- `n_samples::Integer=2500`: how many MNIST training images to use. The default is a
  small subset to keep the demonstration fast.

Images are converted to `Float32` and given a singleton channel dimension, i.e.
shape `(H, W, 1, N)`; labels are one-hot encoded over the classes `0:9`. The training
loader shuffles its data, the test loader does not.

# Throws
- `ArgumentError` if `n_samples` is not in `1:length(dataset.targets)`.
"""
function loadmnist(batchsize, train_split; n_samples::Integer=2500)
    dataset = MNIST(; split=:train)
    n_total = length(dataset.targets)
    1 ≤ n_samples ≤ n_total ||
        throw(ArgumentError("n_samples must be in 1:$n_total, got $n_samples"))

    imgs = dataset.features[:, :, 1:n_samples]
    labels_raw = dataset.targets[1:n_samples]

    # Add a singleton channel dimension: (H, W, N) -> (H, W, C=1, N)
    x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))
    y_data = onehotbatch(labels_raw, 0:9)
    (x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at=train_split)

    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize, shuffle=true),
        # Don't shuffle the test data
        DataLoader(collect.((x_test, y_test)); batchsize, shuffle=false))
end
loadmnist (generic function with 1 method)

Construct the Lux Neural Network containing a DEQ layer.

"""
    construct_model(solver; model_type::Symbol=:deq)

Build the DEQ-based MNIST classifier, create its parameters and states, and run one
forward and one backward pass on a tiny random batch as a warmup.

`solver` is forwarded to `DeepEquilibriumNetwork` (in this tutorial either a
NonlinearSolve.jl root finder or an OrdinaryDiffEq.jl ODE solver). `model_type`
chooses the `init` keyword given to the DEQ layer:

- `:skipdeq` — a learnable `Conv` layer;
- `:regdeq`  — `nothing`;
- anything else — `missing`.

Returns `(model, ps, st)`. Parameters and states are created with an RNG passed
through `gdev`, so they live on the GPU. The warmup calls `loss_function`, which
must therefore be defined before this function is invoked.
"""
function construct_model(solver; model_type::Symbol=:deq)
    # Every convolution inside the fixed-point map uses this exact configuration:
    # 3x3 kernel, 64 -> 64 channels, tanh, same padding, small-std init, no bias.
    fixed_point_conv() = Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad(),
        init_weight=truncated_normal(; std=0.01), use_bias=false)

    # Stem applied once before the DEQ; the stride-2 conv halves the spatial size.
    down = Chain(Conv((3, 3), 1 => 64, gelu; stride=1), GroupNorm(64, 64),
        Conv((4, 4), 64 => 64; stride=2, pad=1))

    # The map whose fixed point the DEQ solves for. The Parallel layer sums one conv
    # per input branch; presumably the DEQ feeds it a tuple like (z, x) — confirm in
    # the DeepEquilibriumNetworks docs.
    deq_model = Chain(Parallel(+, fixed_point_conv(), fixed_point_conv()),
        fixed_point_conv())

    init = if model_type === :skipdeq
        # Learnable initial guess for the solver
        Conv((3, 3), 64 => 64, gelu; stride=1, pad=SamePad())
    elseif model_type === :regdeq
        nothing
    else
        missing
    end

    deq = DeepEquilibriumNetwork(deq_model, solver; init, verbose=false,
        linsolve_kwargs=(; maxiters=10), maxiters=10)

    # Normalise, pool to one 64-vector per sample, and map to 10 class logits.
    classifier = Chain(
        GroupNorm(64, 64, relu), GlobalMeanPool(), FlattenLayer(), Dense(64, 10))

    model = Chain(; down, deq, classifier)

    # Passing the RNG through `gdev` makes Lux.setup allocate ps/st on the GPU
    gpu_rng = gdev(Random.default_rng())
    ps, st = Lux.setup(gpu_rng, model)

    # Two random 28x28 single-channel images with random labels, used only to
    # trigger compilation of the forward and backward passes.
    x_warm = randn(gpu_rng, Float32, 28, 28, 1, 2)
    y_warm = gdev(onehotbatch(rand(Random.default_rng(), 0:9, 2), 0:9))
    warm_batch = (x_warm, y_warm)

    @printf "[%s] warming up forward pass\n" string(now())
    loss_function(model, ps, st, warm_batch)
    @printf "[%s] warming up backward pass\n" string(now())
    Zygote.gradient(first ∘ loss_function, model, ps, st, warm_batch)
    @printf "[%s] warmup complete\n" string(now())

    return model, ps, st
end
construct_model (generic function with 1 method)

Define some helper functions to train the model.

# Cross-entropy applied to raw logits (the classifier's final Dense layer has no
# softmax activation).
const logit_cross_entropy = CrossEntropyLoss(; logits=Val(true))
# Mean-squared error, used by `loss_function` as a regulariser on the DEQ solution.
const mse_loss = MSELoss()

"""
    loss_function(model, ps, st, (x, y))

Training objective with the `(model, ps, st, data) -> (loss, st, stats)` shape used
with `Training.single_train_step!`.

The loss is the logit cross-entropy between `model(x)` and the one-hot targets `y`,
plus `0.01` times the mean-squared difference between the DEQ solution's `z_star`
and `u0` (presumably the fixed point and the initial guess). Returns the loss, the
updated model state, and an empty `NamedTuple` of stats.
"""
function loss_function(model, ps, st, (x, y))
    logits, st_out = model(x, ps, st)
    solution = st_out.deq.solution
    # Regularisation: penalise the distance between `z_star` and `u0`
    reg = mse_loss(solution.z_star, solution.u0)
    weight = eltype(reg)(0.01)  # same precision as the loss value
    return logit_cross_entropy(logits, y) + weight * reg, st_out, (;)
end

"""
    accuracy(model, ps, st, dataloader)

Return the fraction of observations in `dataloader` whose predicted class (the
`onecold` of the model output) matches the `onecold` of the one-hot target.

The model is evaluated with `Lux.testmode(st)`. If `dataloader` yields no
observations the result is `0 / 0`, i.e. `NaN`.
"""
function accuracy(model, ps, st, dataloader)
    eval_st = Lux.testmode(st)
    hits = 0
    seen = 0
    for batch in dataloader
        x, y = batch
        labels = onecold(y)
        preds = onecold(first(model(x, ps, eval_st)))
        hits += sum(labels .== preds)
        seen += length(labels)
    end
    return hits / seen
end

"""
    train_model(solver, model_type; pretrain_epochs=2, epochs=3, batchsize=32,
                train_split=0.8, learning_rate=0.0005, log_every=10)

Build a DEQ classifier with `construct_model(solver; model_type)`, train it with
Adam on the MNIST subset returned by `loadmnist`, and print progress and test
accuracy.

Training has two phases:
1. `pretrain_epochs` epochs with the state's `:fixed_depth` set to `Val(5)`
   (logged as "Pretrain with unrolling to a depth of 5").
2. `epochs` epochs with `:fixed_depth` reset to `Val(0)`.

The training loss is printed every `log_every` batches (batches 1, 1 + log_every,
...). Test accuracy is printed after pretraining and after every epoch.

# Returns
`(model, ps, st)`, where `ps` and `st` are the trained parameters and states held by
the final `Training.TrainState`.

# Throws
- `ArgumentError` if `log_every < 1`.
"""
function train_model(solver, model_type; pretrain_epochs::Integer=2,
        epochs::Integer=3, batchsize::Integer=32, train_split=0.8,
        learning_rate=0.0005, log_every::Integer=10)
    log_every ≥ 1 || throw(ArgumentError("log_every must be ≥ 1, got $log_every"))

    model, ps, st = construct_model(solver; model_type)

    train_dataloader, test_dataloader = loadmnist(batchsize, train_split) |> gdev
    nbatches = length(train_dataloader)

    tstate = Training.TrainState(model, ps, st, Adam(learning_rate))

    @printf "[%s] Training Model: %s with Solver: %s\n" string(now()) model_type nameof(typeof(solver))

    @printf "[%s] Pretrain with unrolling to a depth of 5\n" string(now())
    @set! tstate.states = Lux.update_state(tstate.states, :fixed_depth, Val(5))

    for _ in 1:pretrain_epochs, (i, (x, y)) in enumerate(train_dataloader)
        _, loss, _, tstate = Training.single_train_step!(
            AutoZygote(), loss_function, (x, y), tstate)
        # Log batches 1, 1 + log_every, 1 + 2log_every, ...
        if (i - 1) % log_every == 0
            @printf "[%s] Pretraining Batch: [%4d/%4d] Loss: %.5f\n" string(now()) i nbatches loss
        end
    end

    acc = accuracy(model, tstate.parameters, tstate.states, test_dataloader) * 100
    @printf "[%s] Pretraining complete. Accuracy: %.5f%%\n" string(now()) acc

    # Turn off the fixed-depth unrolling for the main training phase
    @set! tstate.states = Lux.update_state(tstate.states, :fixed_depth, Val(0))

    for epoch in 1:epochs
        for (i, (x, y)) in enumerate(train_dataloader)
            _, loss, _, tstate = Training.single_train_step!(
                AutoZygote(), loss_function, (x, y), tstate)
            if (i - 1) % log_every == 0
                @printf "[%s] Epoch: [%d/%d] Batch: [%4d/%4d] Loss: %.5f\n" string(now()) epoch epochs i nbatches loss
            end
        end

        acc = accuracy(model, tstate.parameters, tstate.states, test_dataloader) * 100
        @printf "[%s] Epoch: [%d/%d] Accuracy: %.5f%%\n" string(now()) epoch epochs acc
    end

    @printf "[%s] Training complete.\n" string(now())

    # Return the trained parameters held by the TrainState. The initial `ps` would
    # only reflect training if the optimiser happened to update it in place.
    return model, tstate.parameters, tstate.states
end
train_model (generic function with 1 method)

Now we can train our model. First we will train a discrete DEQ, which effectively means passing in a root-finding algorithm. Most packages lack good nonlinear solvers and end up using solvers like Broyden's method, but we can simply plug in any of the advanced solvers from NonlinearSolve.jl. Here we will use a Newton–Krylov method (Newton–Raphson with GMRES as the linear solver):

# Discrete DEQ: Newton-Raphson root finding with GMRES (Krylov) linear solves
train_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()), :regdeq);
[2024-09-10T14:58:50.387] warming up forward pass
[2024-09-10T14:59:27.548] warming up backward pass
[2024-09-10T15:00:32.510] warmup complete
[2024-09-10T15:00:35.290] Training Model: regdeq with Solver: GeneralizedFirstOrderAlgorithm
[2024-09-10T15:00:35.529] Pretrain with unrolling to a depth of 5
[2024-09-10T15:01:18.411] Pretraining Batch: [   1/  63] Loss: 2.29847
[2024-09-10T15:01:19.315] Pretraining Batch: [  11/  63] Loss: 2.19594
[2024-09-10T15:01:19.786] Pretraining Batch: [  21/  63] Loss: 2.30583
[2024-09-10T15:01:19.964] Pretraining Batch: [  31/  63] Loss: 2.21255
[2024-09-10T15:01:20.140] Pretraining Batch: [  41/  63] Loss: 2.17548
[2024-09-10T15:01:20.290] Pretraining Batch: [  51/  63] Loss: 2.19290
[2024-09-10T15:01:20.457] Pretraining Batch: [  61/  63] Loss: 2.14694
[2024-09-10T15:01:20.606] Pretraining Batch: [   1/  63] Loss: 2.15881
[2024-09-10T15:01:20.780] Pretraining Batch: [  11/  63] Loss: 2.18139
[2024-09-10T15:01:20.930] Pretraining Batch: [  21/  63] Loss: 2.14976
[2024-09-10T15:01:21.088] Pretraining Batch: [  31/  63] Loss: 2.09837
[2024-09-10T15:01:21.239] Pretraining Batch: [  41/  63] Loss: 2.19074
[2024-09-10T15:01:21.397] Pretraining Batch: [  51/  63] Loss: 2.09138
[2024-09-10T15:01:21.565] Pretraining Batch: [  61/  63] Loss: 2.03783
[2024-09-10T15:01:31.011] Pretraining complete. Accuracy: 66.60000%
[2024-09-10T15:01:35.177] Epoch: [1/3] Batch: [   1/  63] Loss: 1.99373
[2024-09-10T15:01:36.339] Epoch: [1/3] Batch: [  11/  63] Loss: 2.02998
[2024-09-10T15:01:36.802] Epoch: [1/3] Batch: [  21/  63] Loss: 2.06365
[2024-09-10T15:01:37.342] Epoch: [1/3] Batch: [  31/  63] Loss: 2.06225
[2024-09-10T15:01:37.877] Epoch: [1/3] Batch: [  41/  63] Loss: 2.01902
[2024-09-10T15:01:38.278] Epoch: [1/3] Batch: [  51/  63] Loss: 2.03933
[2024-09-10T15:01:38.816] Epoch: [1/3] Batch: [  61/  63] Loss: 2.00332
[2024-09-10T15:01:39.694] Epoch: [1/3] Accuracy: 68.40000%
[2024-09-10T15:01:40.011] Epoch: [2/3] Batch: [   1/  63] Loss: 2.03461
[2024-09-10T15:01:41.023] Epoch: [2/3] Batch: [  11/  63] Loss: 2.05679
[2024-09-10T15:01:41.501] Epoch: [2/3] Batch: [  21/  63] Loss: 2.01941
[2024-09-10T15:01:42.114] Epoch: [2/3] Batch: [  31/  63] Loss: 2.00956
[2024-09-10T15:01:42.773] Epoch: [2/3] Batch: [  41/  63] Loss: 1.97900
[2024-09-10T15:01:43.446] Epoch: [2/3] Batch: [  51/  63] Loss: 1.98009
[2024-09-10T15:01:44.184] Epoch: [2/3] Batch: [  61/  63] Loss: 2.00311
[2024-09-10T15:01:44.975] Epoch: [2/3] Accuracy: 59.60000%
[2024-09-10T15:01:45.045] Epoch: [3/3] Batch: [   1/  63] Loss: 2.03107
[2024-09-10T15:01:46.087] Epoch: [3/3] Batch: [  11/  63] Loss: 1.93383
[2024-09-10T15:01:46.743] Epoch: [3/3] Batch: [  21/  63] Loss: 1.98916
[2024-09-10T15:01:47.392] Epoch: [3/3] Batch: [  31/  63] Loss: 1.97544
[2024-09-10T15:01:48.187] Epoch: [3/3] Batch: [  41/  63] Loss: 1.90061
[2024-09-10T15:01:48.829] Epoch: [3/3] Batch: [  51/  63] Loss: 1.89644
[2024-09-10T15:01:49.525] Epoch: [3/3] Batch: [  61/  63] Loss: 1.92080
[2024-09-10T15:01:50.557] Epoch: [3/3] Accuracy: 76.20000%
[2024-09-10T15:01:50.557] Training complete.

We can also train a continuous DEQ by passing in an ODE solver. Here we will use VCAB3(), which tends to be quite fast for continuous neural network problems.

# Continuous DEQ: pass the VCAB3 ODE solver instead of a root finder
train_model(VCAB3(), :deq);
[2024-09-10T15:01:50.987] warming up forward pass
[2024-09-10T15:02:17.293] warming up backward pass
[2024-09-10T15:02:33.292] warmup complete
[2024-09-10T15:02:34.807] Training Model: deq with Solver: VCAB3
[2024-09-10T15:02:34.807] Pretrain with unrolling to a depth of 5
[2024-09-10T15:03:12.338] Pretraining Batch: [   1/  63] Loss: 2.49227
[2024-09-10T15:03:12.883] Pretraining Batch: [  11/  63] Loss: 2.31392
[2024-09-10T15:03:13.278] Pretraining Batch: [  21/  63] Loss: 2.25770
[2024-09-10T15:03:13.462] Pretraining Batch: [  31/  63] Loss: 2.22166
[2024-09-10T15:03:13.810] Pretraining Batch: [  41/  63] Loss: 2.23305
[2024-09-10T15:03:14.039] Pretraining Batch: [  51/  63] Loss: 2.21107
[2024-09-10T15:03:15.071] Pretraining Batch: [  61/  63] Loss: 2.20224
[2024-09-10T15:03:15.112] Pretraining Batch: [   1/  63] Loss: 2.18605
[2024-09-10T15:03:15.257] Pretraining Batch: [  11/  63] Loss: 2.18501
[2024-09-10T15:03:15.396] Pretraining Batch: [  21/  63] Loss: 2.20914
[2024-09-10T15:03:16.276] Pretraining Batch: [  31/  63] Loss: 2.14085
[2024-09-10T15:03:16.511] Pretraining Batch: [  41/  63] Loss: 2.20098
[2024-09-10T15:03:16.676] Pretraining Batch: [  51/  63] Loss: 2.15264
[2024-09-10T15:03:16.857] Pretraining Batch: [  61/  63] Loss: 2.13023
[2024-09-10T15:03:18.290] Pretraining complete. Accuracy: 64.40000%
[2024-09-10T15:03:20.844] Epoch: [1/3] Batch: [   1/  63] Loss: 2.09923
[2024-09-10T15:03:23.869] Epoch: [1/3] Batch: [  11/  63] Loss: 2.13908
[2024-09-10T15:03:25.071] Epoch: [1/3] Batch: [  21/  63] Loss: 2.07400
[2024-09-10T15:03:26.421] Epoch: [1/3] Batch: [  31/  63] Loss: 2.12320
[2024-09-10T15:03:29.206] Epoch: [1/3] Batch: [  41/  63] Loss: 2.05816
[2024-09-10T15:03:30.425] Epoch: [1/3] Batch: [  51/  63] Loss: 2.06710
[2024-09-10T15:03:32.666] Epoch: [1/3] Batch: [  61/  63] Loss: 2.09572
[2024-09-10T15:03:34.783] Epoch: [1/3] Accuracy: 67.60000%
[2024-09-10T15:03:34.822] Epoch: [2/3] Batch: [   1/  63] Loss: 2.10152
[2024-09-10T15:03:37.124] Epoch: [2/3] Batch: [  11/  63] Loss: 2.07261
[2024-09-10T15:03:38.563] Epoch: [2/3] Batch: [  21/  63] Loss: 2.09647
[2024-09-10T15:03:40.756] Epoch: [2/3] Batch: [  31/  63] Loss: 2.05842
[2024-09-10T15:03:42.405] Epoch: [2/3] Batch: [  41/  63] Loss: 2.11158
[2024-09-10T15:03:45.101] Epoch: [2/3] Batch: [  51/  63] Loss: 2.08569
[2024-09-10T15:03:46.649] Epoch: [2/3] Batch: [  61/  63] Loss: 2.04567
[2024-09-10T15:03:49.122] Epoch: [2/3] Accuracy: 66.80000%
[2024-09-10T15:03:49.176] Epoch: [3/3] Batch: [   1/  63] Loss: 2.13084
[2024-09-10T15:03:49.754] Epoch: [3/3] Batch: [  11/  63] Loss: 2.06870
[2024-09-10T15:03:52.001] Epoch: [3/3] Batch: [  21/  63] Loss: 2.01065
[2024-09-10T15:03:52.675] Epoch: [3/3] Batch: [  31/  63] Loss: 2.03794
[2024-09-10T15:03:55.171] Epoch: [3/3] Batch: [  41/  63] Loss: 2.03897
[2024-09-10T15:03:56.712] Epoch: [3/3] Batch: [  51/  63] Loss: 2.10170
[2024-09-10T15:03:57.988] Epoch: [3/3] Batch: [  61/  63] Loss: 2.03757
[2024-09-10T15:04:00.434] Epoch: [3/3] Accuracy: 70.80000%
[2024-09-10T15:04:00.434] Training complete.

This code is set up to make it easy to experiment with different DEQ models. Try changing the model_type argument of train_model to :skipdeq or :deq to see how the model behaves. You can also try different solvers from NonlinearSolve.jl and OrdinaryDiffEq.jl! Even third-party solvers from Sundials.jl will work; just remember to run those on the CPU.