Modelling Equilibrium Models with Reduced State Size
Sometimes we don't want to solve the root-finding problem with the full state size. Solving for a reduced state is often faster, since the root-finding problem handed to the nonlinear solver is smaller. We will use the same MNIST example as before, but this time with a reduced state size.
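To make the idea concrete, here is a minimal standalone sketch (not part of the tutorial code; the dimensions are chosen to match the model below). The fixed-point state z only has 128 features, while the injected input x has 512, so the root-finding problem the nonlinear solver sees is 128 × batchsize instead of 512 × batchsize.
using Lux, Random

# One iteration map of a hypothetical reduced-state DEQ: the first branch acts on the
# 128-dimensional state z, the second injects the 512-dimensional input features x.
f = Parallel(+, Dense(128 => 128, tanh), Dense(512 => 128, tanh))
ps, st = Lux.setup(Random.default_rng(), f)

x = randn(Float32, 512, 4)   # input features for a batch of 4
z0 = zeros(Float32, 128, 4)  # initial guess for the reduced state
z1, _ = f((z0, x), ps, st)   # one fixed-point iteration; size(z1) == (128, 4)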
using DeepEquilibriumNetworks, SciMLSensitivity, Lux, NonlinearSolve, OrdinaryDiffEq,
Random, Optimisers, Zygote, LinearSolve, Dates, Printf, Setfield, OneHotArrays
using MLDatasets: MNIST
using MLUtils: DataLoader, splitobs
using LuxCUDA # For NVIDIA GPU support
CUDA.allowscalar(false)
ENV["DATADEPS_ALWAYS_ACCEPT"] = true
const cdev = cpu_device()
const gdev = gpu_device()
function loadmnist(batchsize, train_split)
N = 2500
dataset = MNIST(; split=:train)
imgs = dataset.features[:, :, 1:N]
labels_raw = dataset.targets[1:N]
# Process images into (H,W,C,BS) batches
x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))
y_data = onehotbatch(labels_raw, 0:9)
(x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at=train_split)
return (
# Use DataLoader to automatically minibatch and shuffle the data
DataLoader(collect.((x_train, y_train)); batchsize, shuffle=true),
# Don't shuffle the test data
DataLoader(collect.((x_test, y_test)); batchsize, shuffle=false))
end
loadmnist (generic function with 1 method)
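As a quick sanity check (the first call downloads MNIST if it is not already cached locally), we can grab a batch and confirm the shapes:
train_loader, test_loader = loadmnist(32, 0.8)
x, y = first(train_loader)
size(x), size(y)  # ((28, 28, 1, 32), (10, 32))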
Now we will define the construct_model function. Here we will use dense layers and downsample the features to the reduced state size via the init kwarg.
function construct_model(solver; model_type::Symbol=:regdeq)
down = Chain(FlattenLayer(), Dense(784 => 512, gelu))
# The input layer of the DEQ
deq_model = Chain(
Parallel(+,
Dense(
128 => 64, tanh; use_bias=false, init_weight=truncated_normal(; std=0.01)), # Reduced dim of `128`
Dense(
512 => 64, tanh; use_bias=false, init_weight=truncated_normal(; std=0.01))), # Original dim of `512`
Dense(64 => 64, tanh; use_bias=false, init_weight=truncated_normal(; std=0.01)),
Dense(64 => 128; use_bias=false, init_weight=truncated_normal(; std=0.01))) # Return the reduced dim of `128`
if model_type === :skipdeq
init = Dense(
512 => 128, tanh; use_bias=false, init_weight=truncated_normal(; std=0.01))
elseif model_type === :regdeq
error(":regdeq is not supported for reduced dim models")
else
# This should preferably be done via `ChainRulesCore.@ignore_derivatives`. But here
# we are only using Zygote, so this is fine.
init = WrappedFunction(x -> Zygote.@ignore(fill!(
similar(x, 128, size(x, 2)), false)))
end
deq = DeepEquilibriumNetwork(deq_model, solver; init, verbose=false,
linsolve_kwargs=(; maxiters=10), maxiters=10)
classifier = Chain(Dense(128 => 128, gelu), Dense(128 => 10))
model = Chain(; down, deq, classifier)
# For NVIDIA GPUs this directly generates the parameters on the GPU
rng = Random.default_rng() |> gdev
ps, st = Lux.setup(rng, model)
# Warmup the forward and backward passes
x = randn(rng, Float32, 28, 28, 1, 2)
y = onehotbatch(rand(Random.default_rng(), 0:9, 2), 0:9) |> gdev
@printf "[%s] warming up forward pass\n" string(now())
loss_function(model, ps, st, (x, y))
@printf "[%s] warming up backward pass\n" string(now())
Zygote.gradient(first ∘ loss_function, model, ps, st, (x, y))
@printf "[%s] warmup complete\n" string(now())
return model, ps, st
end
construct_model (generic function with 1 method)
Define some helper functions to train the model.
const logit_cross_entropy = CrossEntropyLoss(; logits=Val(true))
const mse_loss = MSELoss()
function loss_function(model, ps, st, (x, y))
ŷ, st = model(x, ps, st)
l1 = logit_cross_entropy(ŷ, y)
l2 = mse_loss(st.deq.solution.z_star, st.deq.solution.u0) # Regularize the fixed point towards the initial guess
return l1 + eltype(l2)(0.01) * l2, st, (;)
end
function accuracy(model, ps, st, dataloader)
total_correct, total = 0, 0
st = Lux.testmode(st)
for (x, y) in dataloader
target_class = onecold(y)
predicted_class = onecold(first(model(x, ps, st)))
total_correct += sum(target_class .== predicted_class)
total += length(target_class)
end
return total_correct / total
end
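Before wiring everything into the training loop, the pieces can also be exercised on their own. Here is a small sketch reusing the solver from the runs below; before training, the accuracy should sit near chance (about 10%):
model, ps, st = construct_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()); model_type=:skipdeq)
_, test_loader = loadmnist(32, 0.8) |> gdev
accuracy(model, ps, st, test_loader)  # untrained model, roughly chance-level accuracy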
function train_model(solver, model_type)
model, ps, st = construct_model(solver; model_type)
train_dataloader, test_dataloader = loadmnist(32, 0.8) |> gdev
tstate = Training.TrainState(model, ps, st, Adam(0.0005))
@printf "[%s] Training Model: %s with Solver: %s\n" string(now()) model_type nameof(typeof(solver))
@printf "[%s] Pretrain with unrolling to a depth of 5\n" string(now())
@set! tstate.states = Lux.update_state(tstate.states, :fixed_depth, Val(5))
for _ in 1:2, (i, (x, y)) in enumerate(train_dataloader)
_, loss, _, tstate = Training.single_train_step!(
AutoZygote(), loss_function, (x, y), tstate)
if i % 10 == 1
@printf "[%s] Pretraining Batch: [%4d/%4d] Loss: %.5f\n" string(now()) i length(train_dataloader) loss
end
end
acc = accuracy(model, tstate.parameters, tstate.states, test_dataloader) * 100
@printf "[%s] Pretraining complete. Accuracy: %.5f%%\n" string(now()) acc
@set! tstate.states = Lux.update_state(tstate.states, :fixed_depth, Val(0))
for epoch in 1:3
for (i, (x, y)) in enumerate(train_dataloader)
_, loss, _, tstate = Training.single_train_step!(
AutoZygote(), loss_function, (x, y), tstate)
if i % 10 == 1
@printf "[%s] Epoch: [%d/%d] Batch: [%4d/%4d] Loss: %.5f\n" string(now()) epoch 3 i length(train_dataloader) loss
end
end
acc = accuracy(model, tstate.parameters, tstate.states, test_dataloader) * 100
@printf "[%s] Epoch: [%d/%d] Accuracy: %.5f%%\n" string(now()) epoch 3 acc
end
@printf "[%s] Training complete.\n" string(now())
return model, ps, tstate.states
end
train_model (generic function with 1 method)
Now we can train our model. We can't use :regdeq here currently, but we will support it in the future.
train_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()), :skipdeq)
[2024-09-10T17:07:02.084] warming up forward pass
[2024-09-10T17:07:09.359] warming up backward pass
[2024-09-10T17:07:41.087] warmup complete
[2024-09-10T17:07:42.531] Training Model: skipdeq with Solver: GeneralizedFirstOrderAlgorithm
[2024-09-10T17:07:42.531] Pretrain with unrolling to a depth of 5
[2024-09-10T17:08:44.039] Pretraining Batch: [ 1/ 63] Loss: 2.29567
[2024-09-10T17:08:44.074] Pretraining Batch: [ 11/ 63] Loss: 2.29412
[2024-09-10T17:08:44.107] Pretraining Batch: [ 21/ 63] Loss: 2.15275
[2024-09-10T17:08:44.140] Pretraining Batch: [ 31/ 63] Loss: 2.08518
[2024-09-10T17:08:44.173] Pretraining Batch: [ 41/ 63] Loss: 1.60935
[2024-09-10T17:08:44.205] Pretraining Batch: [ 51/ 63] Loss: 1.52030
[2024-09-10T17:08:44.238] Pretraining Batch: [ 61/ 63] Loss: 1.54317
[2024-09-10T17:08:44.249] Pretraining Batch: [ 1/ 63] Loss: 1.33805
[2024-09-10T17:08:44.282] Pretraining Batch: [ 11/ 63] Loss: 1.67373
[2024-09-10T17:08:44.315] Pretraining Batch: [ 21/ 63] Loss: 1.40876
[2024-09-10T17:08:44.349] Pretraining Batch: [ 31/ 63] Loss: 1.24748
[2024-09-10T17:08:44.383] Pretraining Batch: [ 41/ 63] Loss: 0.89204
[2024-09-10T17:08:44.417] Pretraining Batch: [ 51/ 63] Loss: 1.16815
[2024-09-10T17:08:44.451] Pretraining Batch: [ 61/ 63] Loss: 0.97810
[2024-09-10T17:08:44.940] Pretraining complete. Accuracy: 63.00000%
[2024-09-10T17:08:46.883] Epoch: [1/3] Batch: [ 1/ 63] Loss: 1.29054
[2024-09-10T17:08:46.990] Epoch: [1/3] Batch: [ 11/ 63] Loss: 1.03526
[2024-09-10T17:08:47.108] Epoch: [1/3] Batch: [ 21/ 63] Loss: 0.97615
[2024-09-10T17:08:47.252] Epoch: [1/3] Batch: [ 31/ 63] Loss: 1.08918
[2024-09-10T17:08:47.388] Epoch: [1/3] Batch: [ 41/ 63] Loss: 0.81719
[2024-09-10T17:08:47.519] Epoch: [1/3] Batch: [ 51/ 63] Loss: 0.65479
[2024-09-10T17:08:47.655] Epoch: [1/3] Batch: [ 61/ 63] Loss: 0.88301
[2024-09-10T17:08:48.093] Epoch: [1/3] Accuracy: 69.40000%
[2024-09-10T17:08:48.107] Epoch: [2/3] Batch: [ 1/ 63] Loss: 0.70414
[2024-09-10T17:08:48.241] Epoch: [2/3] Batch: [ 11/ 63] Loss: 0.99730
[2024-09-10T17:08:48.387] Epoch: [2/3] Batch: [ 21/ 63] Loss: 0.77190
[2024-09-10T17:08:48.529] Epoch: [2/3] Batch: [ 31/ 63] Loss: 0.54059
[2024-09-10T17:08:48.671] Epoch: [2/3] Batch: [ 41/ 63] Loss: 0.60561
[2024-09-10T17:08:48.814] Epoch: [2/3] Batch: [ 51/ 63] Loss: 0.41440
[2024-09-10T17:08:48.955] Epoch: [2/3] Batch: [ 61/ 63] Loss: 0.65514
[2024-09-10T17:08:49.121] Epoch: [2/3] Accuracy: 78.40000%
[2024-09-10T17:08:49.136] Epoch: [3/3] Batch: [ 1/ 63] Loss: 0.35704
[2024-09-10T17:08:49.273] Epoch: [3/3] Batch: [ 11/ 63] Loss: 0.38354
[2024-09-10T17:08:49.415] Epoch: [3/3] Batch: [ 21/ 63] Loss: 0.70121
[2024-09-10T17:08:49.559] Epoch: [3/3] Batch: [ 31/ 63] Loss: 0.33693
[2024-09-10T17:08:49.700] Epoch: [3/3] Batch: [ 41/ 63] Loss: 0.36984
[2024-09-10T17:08:49.839] Epoch: [3/3] Batch: [ 51/ 63] Loss: 0.53535
[2024-09-10T17:08:49.981] Epoch: [3/3] Batch: [ 61/ 63] Loss: 0.29085
[2024-09-10T17:08:50.132] Epoch: [3/3] Accuracy: 85.60000%
[2024-09-10T17:08:50.132] Training complete.
train_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()), :deq)
[2024-09-10T17:08:50.450] warming up forward pass
[2024-09-10T17:08:50.599] warming up backward pass
[2024-09-10T17:08:58.879] warmup complete
[2024-09-10T17:09:00.205] Training Model: deq with Solver: GeneralizedFirstOrderAlgorithm
[2024-09-10T17:09:00.205] Pretrain with unrolling to a depth of 5
[2024-09-10T17:09:54.821] Pretraining Batch: [ 1/ 63] Loss: 2.30717
[2024-09-10T17:09:54.854] Pretraining Batch: [ 11/ 63] Loss: 2.28074
[2024-09-10T17:09:54.886] Pretraining Batch: [ 21/ 63] Loss: 2.16493
[2024-09-10T17:09:54.917] Pretraining Batch: [ 31/ 63] Loss: 1.88609
[2024-09-10T17:09:54.947] Pretraining Batch: [ 41/ 63] Loss: 1.73695
[2024-09-10T17:09:54.978] Pretraining Batch: [ 51/ 63] Loss: 1.62721
[2024-09-10T17:09:55.009] Pretraining Batch: [ 61/ 63] Loss: 1.22730
[2024-09-10T17:09:55.018] Pretraining Batch: [ 1/ 63] Loss: 1.43074
[2024-09-10T17:09:55.049] Pretraining Batch: [ 11/ 63] Loss: 1.31279
[2024-09-10T17:09:55.080] Pretraining Batch: [ 21/ 63] Loss: 1.37096
[2024-09-10T17:09:55.112] Pretraining Batch: [ 31/ 63] Loss: 1.25390
[2024-09-10T17:09:55.142] Pretraining Batch: [ 41/ 63] Loss: 0.48196
[2024-09-10T17:09:55.173] Pretraining Batch: [ 51/ 63] Loss: 1.06383
[2024-09-10T17:09:55.204] Pretraining Batch: [ 61/ 63] Loss: 0.58507
[2024-09-10T17:09:55.334] Pretraining complete. Accuracy: 82.00000%
[2024-09-10T17:09:56.593] Epoch: [1/3] Batch: [ 1/ 63] Loss: 0.71791
[2024-09-10T17:09:56.705] Epoch: [1/3] Batch: [ 11/ 63] Loss: 0.57352
[2024-09-10T17:09:56.825] Epoch: [1/3] Batch: [ 21/ 63] Loss: 0.60163
[2024-09-10T17:09:56.944] Epoch: [1/3] Batch: [ 31/ 63] Loss: 0.49171
[2024-09-10T17:09:57.064] Epoch: [1/3] Batch: [ 41/ 63] Loss: 0.47504
[2024-09-10T17:09:57.183] Epoch: [1/3] Batch: [ 51/ 63] Loss: 0.59566
[2024-09-10T17:09:57.303] Epoch: [1/3] Batch: [ 61/ 63] Loss: 0.40391
[2024-09-10T17:09:57.478] Epoch: [1/3] Accuracy: 83.60000%
[2024-09-10T17:09:57.491] Epoch: [2/3] Batch: [ 1/ 63] Loss: 0.30122
[2024-09-10T17:09:57.613] Epoch: [2/3] Batch: [ 11/ 63] Loss: 0.77343
[2024-09-10T17:09:57.731] Epoch: [2/3] Batch: [ 21/ 63] Loss: 0.35490
[2024-09-10T17:09:58.002] Epoch: [2/3] Batch: [ 31/ 63] Loss: 0.49812
[2024-09-10T17:09:58.124] Epoch: [2/3] Batch: [ 41/ 63] Loss: 0.19764
[2024-09-10T17:09:58.246] Epoch: [2/3] Batch: [ 51/ 63] Loss: 0.38171
[2024-09-10T17:09:58.365] Epoch: [2/3] Batch: [ 61/ 63] Loss: 0.29050
[2024-09-10T17:09:58.489] Epoch: [2/3] Accuracy: 86.20000%
[2024-09-10T17:09:58.502] Epoch: [3/3] Batch: [ 1/ 63] Loss: 0.45327
[2024-09-10T17:09:58.619] Epoch: [3/3] Batch: [ 11/ 63] Loss: 0.28231
[2024-09-10T17:09:58.735] Epoch: [3/3] Batch: [ 21/ 63] Loss: 0.28942
[2024-09-10T17:09:58.853] Epoch: [3/3] Batch: [ 31/ 63] Loss: 0.28557
[2024-09-10T17:09:58.970] Epoch: [3/3] Batch: [ 41/ 63] Loss: 0.29017
[2024-09-10T17:09:59.090] Epoch: [3/3] Batch: [ 51/ 63] Loss: 0.18252
[2024-09-10T17:09:59.211] Epoch: [3/3] Batch: [ 61/ 63] Loss: 0.47422
[2024-09-10T17:09:59.341] Epoch: [3/3] Accuracy: 88.20000%
[2024-09-10T17:09:59.341] Training complete.