Hamiltonian Neural Network

Hamiltonian Neural Networks (HNNs), introduced in [1], allow models to "learn and respect exact conservation laws in an unsupervised manner". In this example, we will train a model to learn the Hamiltonian of a 1D spring-mass system, which is described by the equation:

\[m\ddot x + kx = 0\]

Now we make some simplifying assumptions and set $m = 1$ and $k = 1$. Solving this equation analytically gives $x = \sin(t)$, so the coordinates are $q = \sin(t)$ and $p = \cos(t)$. In the code below, time is rescaled so that one full period fits in $t \in [0, 1]$, i.e. we sample $q = \sin(2\pi t)$ and $p = \cos(2\pi t)$ together with their time derivatives. Using these solutions, we generate our dataset and fit the HamiltonianNN to learn the dynamics of this system.
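
For reference, with $m = k = 1$ the Hamiltonian of this system and the corresponding Hamilton's equations (which the trained network is expected to recover up to the time rescaling and an additive constant) are:

\[H(q, p) = \frac{p^2 + q^2}{2}, \qquad \dot q = \frac{\partial H}{\partial p} = p, \qquad \dot p = -\frac{\partial H}{\partial q} = -q\]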

Copy-Pasteable Code

Before getting to the explanation, here's some code to start with. A full, step-by-step explanation of the model definition and training process follows below:

using Lux, DiffEqFlux, OrdinaryDiffEq, Statistics, Plots, Zygote, ForwardDiff, Random,
      ComponentArrays, Optimization, OptimizationOptimisers, MLUtils

t = range(0.0f0, 1.0f0; length = 1024)
π_32 = Float32(π)
q_t = reshape(sin.(2π_32 * t), 1, :)
p_t = reshape(cos.(2π_32 * t), 1, :)
dqdt = 2π_32 .* p_t
dpdt = -2π_32 .* q_t

data = cat(q_t, p_t; dims = 1)
target = cat(dqdt, dpdt; dims = 1)
B = 256
NEPOCHS = 500
dataloader = DataLoader((data, target); batchsize = B)

hnn = Layers.HamiltonianNN{true}(Layers.MLP(2, (1028, 1)); autodiff = AutoZygote())
ps, st = Lux.setup(Xoshiro(0), hnn)
ps_c = ps |> ComponentArray

opt = OptimizationOptimisers.Adam(0.01f0)

function loss_function(ps, databatch)
    data, target = databatch
    pred, st_ = hnn(data, ps, st)
    return mean(abs2, pred .- target)
end

function callback(state, loss)
    println("[Hamiltonian NN] Loss: ", loss)
    return false
end

opt_func = OptimizationFunction(loss_function, Optimization.AutoForwardDiff())
opt_prob = OptimizationProblem(opt_func, ps_c, dataloader)

res = Optimization.solve(opt_prob, opt; callback, epochs = NEPOCHS)

ps_trained = res.u

model = NeuralODE(
    hnn, (0.0f0, 1.0f0), Tsit5(); save_everystep = false, save_start = true, saveat = t)

pred = Array(first(model(data[:, 1], ps_trained, st)))
plot(data[1, :], data[2, :]; lw = 4, label = "Original")
plot!(pred[1, :], pred[2, :]; lw = 4, label = "Predicted")
xlabel!("Position (q)")
ylabel!("Momentum (p)")
(Plot output: original vs. predicted phase-space trajectory, position q vs. momentum p)

Step by Step Explanation

Data Generation

The HNN predicts the time derivatives $(\dot q, \dot p)$ given $(q, p)$. Hence, we generate the pairs $(q, p)$ using the equations given at the top. Additionally, to supervise the training, we also generate the corresponding derivatives as targets. Finally, we use the DataLoader from MLUtils to automatically batch the dataset.

using Lux, DiffEqFlux, OrdinaryDiffEq, Statistics, Plots, Zygote, ForwardDiff, Random,
      ComponentArrays, Optimization, OptimizationOptimisers, MLUtils

t = range(0.0f0, 1.0f0; length = 1024)
π_32 = Float32(π)
q_t = reshape(sin.(2π_32 * t), 1, :)
p_t = reshape(cos.(2π_32 * t), 1, :)
dqdt = 2π_32 .* p_t
dpdt = -2π_32 .* q_t

data = cat(q_t, p_t; dims = 1)
target = cat(dqdt, dpdt; dims = 1)
B = 256
NEPOCHS = 500
dataloader = DataLoader((data, target); batchsize = B)
4-element DataLoader(::Tuple{Matrix{Float32}, Matrix{Float32}}, batchsize=256)
  with first element:
  (2×256 Matrix{Float32}, 2×256 Matrix{Float32},)
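
With 1024 samples and a batch size of 256, the loader yields four full batches. To sanity-check the batch shapes, here is a minimal sketch that uses only the dataloader and B defined above:

for (x, y) in dataloader
    @assert size(x) == (2, B)  # columns are (q, p) pairs
    @assert size(y) == (2, B)  # columns are (dq/dt, dp/dt) targets
end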

Training the HamiltonianNN

We parameterize the Hamiltonian with a small multilayer perceptron. Training an HNN requires nested automatic differentiation: the layer differentiates the learned Hamiltonian with respect to its inputs to produce $(\dot q, \dot p)$, and the optimizer then differentiates the loss with respect to the parameters. Since Zygote currently doesn't support differentiating through itself, the copy-pasteable code above uses ForwardDiff (Optimization.AutoForwardDiff()) for the outer gradient; below we instead wrap the layer in a StatefulLuxLayer, letting Lux handle the nested derivatives, and use Optimization.AutoZygote() for the outer gradient.
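
To make the role of these gradients concrete, here is a minimal, self-contained sketch (not part of the tutorial pipeline) of how a scalar Hamiltonian induces the vector field that the HNN layer outputs; the analytic H below is only a stand-in for the MLP that the HNN actually learns:

using Zygote

# Stand-in Hamiltonian; the HNN replaces this with an MLP.
H(q, p) = (q^2 + p^2) / 2

# Hamilton's equations: dq/dt = ∂H/∂p, dp/dt = -∂H/∂q
function symplectic_gradient(q, p)
    dHdq, dHdp = Zygote.gradient(H, q, p)
    return (dHdp, -dHdq)
end

symplectic_gradient(0.0f0, 1.0f0)  # (dq/dt, dp/dt) ≈ (1.0, 0.0)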

hnn = Layers.HamiltonianNN{true}(Layers.MLP(2, (1028, 1)); autodiff = AutoZygote())
ps, st = Lux.setup(Xoshiro(0), hnn)
ps_c = ps |> ComponentArray
hnn_stateful = StatefulLuxLayer{true}(hnn, ps_c, st)

opt = OptimizationOptimisers.Adam(0.005f0)

function loss_function(ps, databatch)
    (data, target) = databatch
    pred = hnn_stateful(data, ps)
    return mean(abs2, pred .- target)
end

function callback(state, loss)
    println("[Hamiltonian NN] Loss: ", loss)
    return false
end

opt_func = OptimizationFunction(loss_function, Optimization.AutoZygote())
opt_prob = OptimizationProblem(opt_func, ps_c, dataloader)

res = Optimization.solve(opt_prob, opt; callback, epochs = NEPOCHS)

ps_trained = res.u
ComponentVector{Float32}(block1 = (dense = (weight = Float32[-2.1370213 0.0026516672; -0.14739923 -1.3914757; … ; 1.1863977 0.039584033; 2.157514 -1.4245019], bias = Float32[0.33639288, 0.22010256, -0.12450862, 0.3884359, 0.5799375, 0.39842856, -0.6958851, 0.07831879, -0.30917966, -0.5286344  …  -0.009497202, 0.29972395, -0.21489885, -0.68336666, -0.44730315, -0.25739992, 0.39360783, 0.5259282, 0.4787088, -0.20453048])), block2 = (dense = (weight = Float32[-0.032559782 0.0072834897 … -0.24280524 0.004243767], bias = Float32[-0.009822125])))

Solving the ODE using trained HNN

In order to visualize the learned trajectories, we need to solve the ODE. We wrap the trained HamiltonianNN layer in a NeuralODE layer and solve it with Tsit5 over the same time span used to generate the data.

model = NeuralODE(
    hnn, (0.0f0, 1.0f0), Tsit5(); save_everystep = false, save_start = true, saveat = t)

pred = Array(first(model(data[:, 1], ps_trained, st)))
plot(data[1, :], data[2, :]; lw = 4, label = "Original")
plot!(pred[1, :], pred[2, :]; lw = 4, label = "Predicted")
xlabel!("Position (q)")
ylabel!("Momentum (p)")
(Plot output: original vs. predicted phase-space trajectory, position q vs. momentum p)
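
Since the main appeal of an HNN is that it respects conservation laws, a quick sanity check is to evaluate the true spring energy $(q^2 + p^2)/2$ along the predicted trajectory and confirm that it stays roughly constant. This sketch uses only the pred array computed above:

# Energy of the true system evaluated along the predicted trajectory.
energy = vec(sum(abs2, pred; dims = 1) ./ 2)
println("Max energy drift along the trajectory: ", maximum(energy) - minimum(energy))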

References

[1] Greydanus, Samuel, Misko Dzamba, and Jason Yosinski. "Hamiltonian Neural Networks." Advances in Neural Information Processing Systems 32 (2019): 15379-15389.