Modelling Equilibrium Models with Reduced State Size

Sometimes we don't want to solve a root finding problem with the full state size. Solving in a reduced state space is often faster, since the root finding problem becomes smaller. In this tutorial the fixed point lives in a 128-dimensional space, while the injected (downsampled image) features have 512 dimensions. We will use the same MNIST example as before, but this time with a reduced state size.

using DeepEquilibriumNetworks, SciMLSensitivity, Lux, NonlinearSolve, OrdinaryDiffEq,
      Random, Optimisers, Zygote, LinearSolve, Dates, Printf, Setfield, OneHotArrays
using MLDatasets: MNIST
using MLUtils: DataLoader, splitobs
using LuxCUDA # For NVIDIA GPU support

CUDA.allowscalar(false)
ENV["DATADEPS_ALWAYS_ACCEPT"] = true

const cdev = cpu_device()
const gdev = gpu_device()

function loadmnist(batchsize, train_split)
    N = 2500
    dataset = MNIST(; split=:train)
    imgs = dataset.features[:, :, 1:N]
    labels_raw = dataset.targets[1:N]

    # Process images into (H,W,C,BS) batches
    x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))
    y_data = onehotbatch(labels_raw, 0:9)
    (x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at=train_split)

    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize, shuffle=true),
        # Don't shuffle the test data
        DataLoader(collect.((x_test, y_test)); batchsize, shuffle=false))
end
loadmnist (generic function with 1 method)
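
As a quick sanity check (an illustration, not part of the original tutorial), we can pull a single batch from the training loader and confirm the layout described above:

train_loader, _ = loadmnist(32, 0.8)
x, y = first(train_loader)
size(x)  # (28, 28, 1, 32): images in (H, W, C, BS) layout
size(y)  # (10, 32): one-hot encoded digit labels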

Now we will define the construct_model function. Here we will use Dense layers and downsample the features to the reduced state size using the init keyword argument.

function construct_model(solver; model_type::Symbol=:regdeq)
    down = Chain(FlattenLayer(), Dense(784 => 512, gelu))

    # The input layer of the DEQ
    deq_model = Chain(
        Parallel(+,
            Dense(
                128 => 64, tanh; use_bias=false, init_weight=truncated_normal(; std=0.01)),   # Reduced dim of `128`
            Dense(
                512 => 64, tanh; use_bias=false, init_weight=truncated_normal(; std=0.01))),  # Original dim of `512`
        Dense(64 => 64, tanh; use_bias=false, init_weight=truncated_normal(; std=0.01)),
        Dense(64 => 128; use_bias=false, init_weight=truncated_normal(; std=0.01)))       # Return the reduced dim of `128`

    if model_type === :skipdeq
        init = Dense(
            512 => 128, tanh; use_bias=false, init_weight=truncated_normal(; std=0.01))
    elseif model_type === :regdeq
        error(":regdeq is not supported for reduced dim models")
    else
        # This should preferably be done via `ChainRulesCore.@ignore_derivatives`. But here
        # we are only using Zygote, so this is fine.
        init = WrappedFunction(x -> Zygote.@ignore(fill!(
            similar(x, 128, size(x, 2)), false)))
    end

    deq = DeepEquilibriumNetwork(deq_model, solver; init, verbose=false,
        linsolve_kwargs=(; maxiters=10), maxiters=10)

    classifier = Chain(Dense(128 => 128, gelu), Dense(128 => 10))

    model = Chain(; down, deq, classifier)

    # For NVIDIA GPUs this directly generates the parameters on the GPU
    rng = Random.default_rng() |> gdev
    ps, st = Lux.setup(rng, model)

    # Warmup the forward and backward passes
    x = randn(rng, Float32, 28, 28, 1, 2)
    y = onehotbatch(rand(Random.default_rng(), 0:9, 2), 0:9) |> gdev

    @printf "[%s] warming up forward pass\n" string(now())
    loss_function(model, ps, st, (x, y))
    @printf "[%s] warming up backward pass\n" string(now())
    Zygote.gradient(first ∘ loss_function, model, ps, st, (x, y))
    @printf "[%s] warmup complete\n" string(now())

    return model, ps, st
end
construct_model (generic function with 1 method)

Define some helper functions to train the model.

const logit_cross_entropy = CrossEntropyLoss(; logits=Val(true))
const mse_loss = MSELoss()

function loss_function(model, ps, st, (x, y))
    ŷ, st = model(x, ps, st)
    l1 = logit_cross_entropy(ŷ, y)
    # Regularize by penalizing the distance between the initial guess u0 and the fixed point z_star
    l2 = mse_loss(st.deq.solution.z_star, st.deq.solution.u0)
    return l1 + eltype(l2)(0.01) * l2, st, (;)
end

function accuracy(model, ps, st, dataloader)
    total_correct, total = 0, 0
    st = Lux.testmode(st)
    for (x, y) in dataloader
        target_class = onecold(y)
        predicted_class = onecold(first(model(x, ps, st)))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
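
Before training, the pieces can be exercised end to end. The following sketch (an illustration, not part of the original tutorial) builds a model, runs one forward pass to inspect the reduced 128-dimensional fixed point, and measures the roughly chance-level accuracy of the untrained network:

model, ps, st = construct_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()); model_type=:skipdeq)
_, test_dataloader = loadmnist(32, 0.8) |> gdev

x, y = first(test_dataloader)
_, st_new = model(x, ps, st)
size(st_new.deq.solution.z_star)  # (128, 32): the fixed point lives in the reduced space

accuracy(model, ps, st, test_dataloader)  # roughly chance level (~0.1) before training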

function train_model(solver, model_type)
    model, ps, st = construct_model(solver; model_type)

    train_dataloader, test_dataloader = loadmnist(32, 0.8) |> gdev

    tstate = Training.TrainState(model, ps, st, Adam(0.0005))

    @printf "[%s] Training Model: %s with Solver: %s\n" string(now()) model_type nameof(typeof(solver))

    @printf "[%s] Pretrain with unrolling to a depth of 5\n" string(now())
    @set! tstate.states = Lux.update_state(tstate.states, :fixed_depth, Val(5))

    for _ in 1:2, (i, (x, y)) in enumerate(train_dataloader)
        _, loss, _, tstate = Training.single_train_step!(
            AutoZygote(), loss_function, (x, y), tstate)
        if i % 10 == 1
            @printf "[%s] Pretraining Batch: [%4d/%4d] Loss: %.5f\n" string(now()) i length(train_dataloader) loss
        end
    end

    acc = accuracy(model, tstate.parameters, tstate.states, test_dataloader) * 100
    @printf "[%s] Pretraining complete. Accuracy: %.5f%%\n" string(now()) acc

    @set! tstate.states = Lux.update_state(tstate.states, :fixed_depth, Val(0))

    for epoch in 1:3
        for (i, (x, y)) in enumerate(train_dataloader)
            _, loss, _, tstate = Training.single_train_step!(
                AutoZygote(), loss_function, (x, y), tstate)
            if i % 10 == 1
                @printf "[%s] Epoch: [%d/%d] Batch: [%4d/%4d] Loss: %.5f\n" string(now()) epoch 3 i length(train_dataloader) loss
            end
        end

        acc = accuracy(model, tstate.parameters, tstate.states, test_dataloader) * 100
        @printf "[%s] Epoch: [%d/%d] Accuracy: %.5f%%\n" string(now()) epoch 3 acc
    end

    @printf "[%s] Training complete.\n" string(now())

    return model, ps, tstate.states
end
train_model (generic function with 1 method)

Now we can train our model. Note that :regdeq is currently not supported for reduced state sizes, but we plan to support it in the future.

train_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()), :skipdeq)
[2024-09-10T17:07:02.084] warming up forward pass
[2024-09-10T17:07:09.359] warming up backward pass
[2024-09-10T17:07:41.087] warmup complete
[2024-09-10T17:07:42.531] Training Model: skipdeq with Solver: GeneralizedFirstOrderAlgorithm
[2024-09-10T17:07:42.531] Pretrain with unrolling to a depth of 5
[2024-09-10T17:08:44.039] Pretraining Batch: [   1/  63] Loss: 2.29567
[2024-09-10T17:08:44.074] Pretraining Batch: [  11/  63] Loss: 2.29412
[2024-09-10T17:08:44.107] Pretraining Batch: [  21/  63] Loss: 2.15275
[2024-09-10T17:08:44.140] Pretraining Batch: [  31/  63] Loss: 2.08518
[2024-09-10T17:08:44.173] Pretraining Batch: [  41/  63] Loss: 1.60935
[2024-09-10T17:08:44.205] Pretraining Batch: [  51/  63] Loss: 1.52030
[2024-09-10T17:08:44.238] Pretraining Batch: [  61/  63] Loss: 1.54317
[2024-09-10T17:08:44.249] Pretraining Batch: [   1/  63] Loss: 1.33805
[2024-09-10T17:08:44.282] Pretraining Batch: [  11/  63] Loss: 1.67373
[2024-09-10T17:08:44.315] Pretraining Batch: [  21/  63] Loss: 1.40876
[2024-09-10T17:08:44.349] Pretraining Batch: [  31/  63] Loss: 1.24748
[2024-09-10T17:08:44.383] Pretraining Batch: [  41/  63] Loss: 0.89204
[2024-09-10T17:08:44.417] Pretraining Batch: [  51/  63] Loss: 1.16815
[2024-09-10T17:08:44.451] Pretraining Batch: [  61/  63] Loss: 0.97810
[2024-09-10T17:08:44.940] Pretraining complete. Accuracy: 63.00000%
[2024-09-10T17:08:46.883] Epoch: [1/3] Batch: [   1/  63] Loss: 1.29054
[2024-09-10T17:08:46.990] Epoch: [1/3] Batch: [  11/  63] Loss: 1.03526
[2024-09-10T17:08:47.108] Epoch: [1/3] Batch: [  21/  63] Loss: 0.97615
[2024-09-10T17:08:47.252] Epoch: [1/3] Batch: [  31/  63] Loss: 1.08918
[2024-09-10T17:08:47.388] Epoch: [1/3] Batch: [  41/  63] Loss: 0.81719
[2024-09-10T17:08:47.519] Epoch: [1/3] Batch: [  51/  63] Loss: 0.65479
[2024-09-10T17:08:47.655] Epoch: [1/3] Batch: [  61/  63] Loss: 0.88301
[2024-09-10T17:08:48.093] Epoch: [1/3] Accuracy: 69.40000%
[2024-09-10T17:08:48.107] Epoch: [2/3] Batch: [   1/  63] Loss: 0.70414
[2024-09-10T17:08:48.241] Epoch: [2/3] Batch: [  11/  63] Loss: 0.99730
[2024-09-10T17:08:48.387] Epoch: [2/3] Batch: [  21/  63] Loss: 0.77190
[2024-09-10T17:08:48.529] Epoch: [2/3] Batch: [  31/  63] Loss: 0.54059
[2024-09-10T17:08:48.671] Epoch: [2/3] Batch: [  41/  63] Loss: 0.60561
[2024-09-10T17:08:48.814] Epoch: [2/3] Batch: [  51/  63] Loss: 0.41440
[2024-09-10T17:08:48.955] Epoch: [2/3] Batch: [  61/  63] Loss: 0.65514
[2024-09-10T17:08:49.121] Epoch: [2/3] Accuracy: 78.40000%
[2024-09-10T17:08:49.136] Epoch: [3/3] Batch: [   1/  63] Loss: 0.35704
[2024-09-10T17:08:49.273] Epoch: [3/3] Batch: [  11/  63] Loss: 0.38354
[2024-09-10T17:08:49.415] Epoch: [3/3] Batch: [  21/  63] Loss: 0.70121
[2024-09-10T17:08:49.559] Epoch: [3/3] Batch: [  31/  63] Loss: 0.33693
[2024-09-10T17:08:49.700] Epoch: [3/3] Batch: [  41/  63] Loss: 0.36984
[2024-09-10T17:08:49.839] Epoch: [3/3] Batch: [  51/  63] Loss: 0.53535
[2024-09-10T17:08:49.981] Epoch: [3/3] Batch: [  61/  63] Loss: 0.29085
[2024-09-10T17:08:50.132] Epoch: [3/3] Accuracy: 85.60000%
[2024-09-10T17:08:50.132] Training complete.
train_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()), :deq)
[2024-09-10T17:08:50.450] warming up forward pass
[2024-09-10T17:08:50.599] warming up backward pass
[2024-09-10T17:08:58.879] warmup complete
[2024-09-10T17:09:00.205] Training Model: deq with Solver: GeneralizedFirstOrderAlgorithm
[2024-09-10T17:09:00.205] Pretrain with unrolling to a depth of 5
[2024-09-10T17:09:54.821] Pretraining Batch: [   1/  63] Loss: 2.30717
[2024-09-10T17:09:54.854] Pretraining Batch: [  11/  63] Loss: 2.28074
[2024-09-10T17:09:54.886] Pretraining Batch: [  21/  63] Loss: 2.16493
[2024-09-10T17:09:54.917] Pretraining Batch: [  31/  63] Loss: 1.88609
[2024-09-10T17:09:54.947] Pretraining Batch: [  41/  63] Loss: 1.73695
[2024-09-10T17:09:54.978] Pretraining Batch: [  51/  63] Loss: 1.62721
[2024-09-10T17:09:55.009] Pretraining Batch: [  61/  63] Loss: 1.22730
[2024-09-10T17:09:55.018] Pretraining Batch: [   1/  63] Loss: 1.43074
[2024-09-10T17:09:55.049] Pretraining Batch: [  11/  63] Loss: 1.31279
[2024-09-10T17:09:55.080] Pretraining Batch: [  21/  63] Loss: 1.37096
[2024-09-10T17:09:55.112] Pretraining Batch: [  31/  63] Loss: 1.25390
[2024-09-10T17:09:55.142] Pretraining Batch: [  41/  63] Loss: 0.48196
[2024-09-10T17:09:55.173] Pretraining Batch: [  51/  63] Loss: 1.06383
[2024-09-10T17:09:55.204] Pretraining Batch: [  61/  63] Loss: 0.58507
[2024-09-10T17:09:55.334] Pretraining complete. Accuracy: 82.00000%
[2024-09-10T17:09:56.593] Epoch: [1/3] Batch: [   1/  63] Loss: 0.71791
[2024-09-10T17:09:56.705] Epoch: [1/3] Batch: [  11/  63] Loss: 0.57352
[2024-09-10T17:09:56.825] Epoch: [1/3] Batch: [  21/  63] Loss: 0.60163
[2024-09-10T17:09:56.944] Epoch: [1/3] Batch: [  31/  63] Loss: 0.49171
[2024-09-10T17:09:57.064] Epoch: [1/3] Batch: [  41/  63] Loss: 0.47504
[2024-09-10T17:09:57.183] Epoch: [1/3] Batch: [  51/  63] Loss: 0.59566
[2024-09-10T17:09:57.303] Epoch: [1/3] Batch: [  61/  63] Loss: 0.40391
[2024-09-10T17:09:57.478] Epoch: [1/3] Accuracy: 83.60000%
[2024-09-10T17:09:57.491] Epoch: [2/3] Batch: [   1/  63] Loss: 0.30122
[2024-09-10T17:09:57.613] Epoch: [2/3] Batch: [  11/  63] Loss: 0.77343
[2024-09-10T17:09:57.731] Epoch: [2/3] Batch: [  21/  63] Loss: 0.35490
[2024-09-10T17:09:58.002] Epoch: [2/3] Batch: [  31/  63] Loss: 0.49812
[2024-09-10T17:09:58.124] Epoch: [2/3] Batch: [  41/  63] Loss: 0.19764
[2024-09-10T17:09:58.246] Epoch: [2/3] Batch: [  51/  63] Loss: 0.38171
[2024-09-10T17:09:58.365] Epoch: [2/3] Batch: [  61/  63] Loss: 0.29050
[2024-09-10T17:09:58.489] Epoch: [2/3] Accuracy: 86.20000%
[2024-09-10T17:09:58.502] Epoch: [3/3] Batch: [   1/  63] Loss: 0.45327
[2024-09-10T17:09:58.619] Epoch: [3/3] Batch: [  11/  63] Loss: 0.28231
[2024-09-10T17:09:58.735] Epoch: [3/3] Batch: [  21/  63] Loss: 0.28942
[2024-09-10T17:09:58.853] Epoch: [3/3] Batch: [  31/  63] Loss: 0.28557
[2024-09-10T17:09:58.970] Epoch: [3/3] Batch: [  41/  63] Loss: 0.29017
[2024-09-10T17:09:59.090] Epoch: [3/3] Batch: [  51/  63] Loss: 0.18252
[2024-09-10T17:09:59.211] Epoch: [3/3] Batch: [  61/  63] Loss: 0.47422
[2024-09-10T17:09:59.341] Epoch: [3/3] Accuracy: 88.20000%
[2024-09-10T17:09:59.341] Training complete.
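
The trained model, parameters, and states returned by train_model can be reused for inference. A minimal sketch of assumed usage (not part of the tutorial output above):

model, ps, st = train_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()), :skipdeq)
_, test_dataloader = loadmnist(32, 0.8) |> gdev

st_inf = Lux.testmode(st)
x, _ = first(test_dataloader)
ŷ, _ = model(x, ps, st_inf)
onecold(ŷ |> cdev, 0:9)  # predicted digits (0-9) for the batch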