Convolutional Neural ODE MNIST Classifier on GPU
Train a convolutional neural network classifier for MNIST using a neural ordinary differential equation (`NeuralODE`) on GPUs with minibatching.
For a step-by-step walkthrough, see the MNIST Neural ODE Classification Tutorial using Fully Connected Layers.
using DiffEqFlux, ComponentArrays, CUDA, Zygote, MLDatasets, OrdinaryDiffEq,
Printf, LuxCUDA, Random, MLUtils, OneHotArrays
using Optimization, OptimizationOptimisers
using MLDatasets: MNIST
# Device movers: `cdev` is used to bring arrays back to the CPU and `gdev` to move
# data, parameters and state to the device returned by `gpu_device()`.
const cdev = cpu_device()
const gdev = gpu_device()

# Cross-entropy loss applied to raw logits. Declared `const` (like the device
# movers above) so that calls from `loss_function` do not go through an untyped
# global binding on every training iteration.
const logitcrossentropy = CrossEntropyLoss(; logits = Val(true))

# Disallow scalar indexing of GPU arrays so accidental element-wise access
# raises an error instead of running slowly.
CUDA.allowscalar(false)

# Accept dataset download prompts non-interactively (the value is stored in
# `ENV` as the string "true").
ENV["DATADEPS_ALWAYS_ACCEPT"] = true
"""
    loadmnist(batchsize)

Build a shuffled `DataLoader` over the first 2000 samples of the MNIST training split.

The images are reshaped to a 4-D `Float32` array with a singleton third (channel)
dimension, and the labels are one-hot encoded over the classes `0:9`. The data is
wrapped with `mapobs(gdev, ...)` before being handed to the `DataLoader`.

# Arguments
- `batchsize`: forwarded to `DataLoader` as its `batchsize` keyword.
"""
function loadmnist(batchsize)
    # Only a slice of the training split is used to keep the demonstration quick.
    subset = MNIST(; split = :train)[1:2000]
    features, targets = subset.features, subset.targets

    # Insert a singleton channel axis: (H, W, N) -> (H, W, 1, N).
    height, width, nsamples = size(features, 1), size(features, 2), size(features, 3)
    x_data = Float32.(reshape(features, height, width, 1, nsamples))
    y_data = onehotbatch(targets, 0:9)

    device_data = mapobs(gdev, (x_data, y_data))
    return DataLoader(device_data; batchsize = batchsize, shuffle = true)
end
# Minibatch iterator (batch size 128) used below for the sample forward pass,
# training, and accuracy evaluation.
dataloader = loadmnist(128)
# Downsampling front end: three convolutions taking 1 input channel to 256
# channels, with group normalisation after the first two.
# NOTE(review): assuming 28×28 inputs, the spatial size would go 26 -> 13 -> 6,
# which matches the `MeanPool((6, 6))` window in `fc` — confirm against the data.
down = Chain(
Conv((3, 3), 1 => 12, tanh; stride = 1),
GroupNorm(12, 3),
Conv((4, 4), 12 => 64, tanh; stride = 2, pad = 1),
GroupNorm(64, 4),
Conv((4, 4), 64 => 256; stride = 2, pad = 1)
)
# Right-hand side of the ODE: 256 -> 64 -> 256 channels with `SamePad()` on both
# convolutions, so the output has the same channel count as the input.
dudt = Chain(
Conv((3, 3), 256 => 64, tanh; pad = SamePad()),
Conv((3, 3), 64 => 256, tanh; pad = SamePad())
)
# Classification head: group-normalise with tanh, mean-pool with a 6×6 window,
# flatten, and map 256 features to 10 outputs.
fc = Chain(GroupNorm(256, 4, tanh), MeanPool((6, 6)), FlattenLayer(), Dense(256, 10))
# Neural ODE layer integrating `dudt` over the time span (0, 1) with `Tsit5`.
# `save_everystep = false` and `save_start = false` request that intermediate
# steps and the initial state are not saved; sensitivities use `BacksolveAdjoint`
# with Zygote vector-Jacobian products.
nn_ode = NeuralODE(dudt, (0.0f0, 1.0f0), Tsit5(); save_everystep = false,
sensealg = BacksolveAdjoint(; autojacvec = ZygoteVJP()),
reltol = 1e-5, abstol = 1e-6, save_start = false)
"""
    solution_to_array(sol)

Return the last entry of `sol.u`.
"""
function solution_to_array(sol)
    states = sol.u
    return states[end]
end
# Full classifier: downsample -> neural ODE -> take the last saved ODE state ->
# classification head.
m = Chain(
down,
nn_ode,
solution_to_array,
fc
)
# Initialise parameters and state with a fixed-seed RNG, wrap the parameters in a
# `ComponentArray`, and move both parameters and state through `gdev`.
ps, st = Lux.setup(Xoshiro(0), m);
ps = ComponentArray(ps) |> gdev;
st = st |> gdev;
# To understand the intermediate NN-ODE layer, we can examine its dimensionality
# Take the first minibatch and run one forward pass through the whole model;
# `x_m` holds the model output and the returned state is discarded.
img, lab = first(dataloader);
x_m, _ = m(img, ps, st);
"""
    classify(x)

Return, for every column of `x`, the row index of that column's largest entry.
"""
classify(x) = map(argmax, eachcol(x))
"""
    accuracy(model, data, ps, st; n_batches = 10)

Return the fraction of correctly classified samples over at most the first
`n_batches` batches of `data`.

# Arguments
- `model`: callable invoked as `model(x, ps, st)`; the first element of its return
  value is passed to `classify` as the score matrix.
- `data`: iterable yielding `(x, y)` batches.
- `ps`: model parameters.
- `st`: model state; it is switched to test mode with `Lux.testmode` before use.

# Keywords
- `n_batches`: upper bound on the number of batches evaluated.

# Returns
`total_correct / total` as a floating-point number. When no batch is evaluated
(empty `data` or `n_batches <= 0`) this is `0 / 0`, i.e. `NaN`.
"""
function accuracy(model, data, ps, st; n_batches = 10)
    total_correct = 0
    total = 0
    st = Lux.testmode(st)
    # `Iterators.take` stops after the requested number of batches, so the rest of
    # `data` is never loaded (collecting the whole loader first would materialise
    # every batch). `max(..., 0)` keeps a non-positive `n_batches` meaning
    # "evaluate nothing" rather than raising an error.
    for (x, y) in Iterators.take(data, max(n_batches, 0))
        target_class = classify(cdev(y))
        predicted_class = classify(cdev(first(model(x, ps, st))))
        total_correct += count(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end
# burn in accuracy: evaluate once on the single (img, lab) batch, wrapped in a
# one-element tuple so it iterates like a collection of batches.
# NOTE(review): presumably run here to trigger compilation before training — confirm.
accuracy(m, ((img, lab),), ps, st)
"""
    loss_function(ps, data)

Evaluate the logit cross-entropy of the global model `m` on one batch.

# Arguments
- `ps`: model parameters.
- `data`: a two-element collection `(x, y)` of inputs and targets.

The model state is read from the global `st`; the state returned by the model is
discarded.
"""
function loss_function(ps, data)
    x, y = data
    # The model returns (output, state); only the output is needed here.
    pred = first(m(x, ps, st))
    return logitcrossentropy(pred, y)
end
# burn in loss: one evaluation of the loss on the sample batch
loss_function(ps, (img, lab))
# Adam optimiser with learning rate 0.005.
opt = OptimizationOptimisers.Adam(0.005)
# Global iteration counter; incremented by `callback` on every call.
iter = 0
# Optimisation problem: minimise `loss_function` starting from `ps`, using Zygote
# for automatic differentiation; `dataloader` is passed as the problem's data
# argument.
opt_func = OptimizationFunction(loss_function, Optimization.AutoZygote())
opt_prob = OptimizationProblem(opt_func, ps, dataloader);
"""
    callback(state, l)

Optimisation callback: bump the global `iter` counter and, on every 10th call, log
the accuracy of `m` on `dataloader` using the current parameters `state.u`.

`l` is accepted but not used. Always returns `false`.
"""
function callback(state, l)
    global iter
    iter += 1
    if iter % 10 == 0
        @info "[MNIST Conv GPU] Accuracy: $(accuracy(m, dataloader, state.u, st))"
    end
    return false
end
# Train the NN-ODE for 5 epochs; `callback` is invoked during the solve and logs
# the accuracy on every 10th call.
res = Optimization.solve(opt_prob, opt; epochs = 5, callback)
# Accuracy with the trained parameters `res.u`, evaluated on `dataloader` with the
# default `n_batches` of `accuracy`.
acc = accuracy(m, dataloader, res.u, st)
0.815625