Saving and loading models

ReservoirComputing.jl follows the same structure as Lux.jl: a model's parameters ps and states st are kept outside the model object and passed to it explicitly, as in setup, train!, and predict. Therefore, to save a model it suffices to save the hyperparameters that define it, together with ps and st. This example shows how to save and load a model with JLD2.jl.

Suppose you have trained an ESN and want to save it. Following the Getting Started example, we train the ESN on the Lorenz system:

using OrdinaryDiffEq
using Random
using ReservoirComputing

Random.seed!(42)
rng = MersenneTwister(17)

# Lorenz system with parameters p = (σ, ρ, β)
function lorenz(du, u, p, t)
    du[1] = p[1] * (u[2] - u[1])
    du[2] = u[1] * (p[2] - u[3]) - u[2]
    du[3] = u[1] * u[2] - p[3] * u[3]
end

# Integrate with a fixed step of 0.02 to generate the time series
prob = ODEProblem(lorenz, [1.0f0, 0.0f0, 0.0f0], (0.0, 200.0), Float32[10.0, 28.0, 8/3])
data = Array(solve(prob, ABM54(); dt=0.02f0))
shift = 300         # skip the initial transient
train_len = 5000
predict_len = 1250

# Targets are the inputs shifted one step ahead
input_data = data[:, shift:(shift + train_len - 1)]
target_data = data[:, (shift + 1):(shift + train_len)]
test = data[:, (shift + train_len):(shift + train_len + predict_len - 1)]

# 3 inputs, 300 reservoir units, 3 outputs
esn = ESN(3, 300, 3; init_reservoir=rand_sparse(; radius=1.2, sparsity=6/300),
    state_modifiers=NLAT2)

ps, st = setup(rng, esn)
ps, st = train!(esn, input_data, target_data, ps, st)
((reservoir = (input_matrix = Float32[0.07099056 0.04901688 0.031145955; 0.044342138 -0.081324674 0.07599063; … ; 0.04174912 0.05752921 -0.008878255; -0.017985344 -0.036644936 -0.06222496], reservoir_matrix = Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]), states_modifiers = (NamedTuple(),), readout = (weight = [-2.381352481636387 -3.746526159379389 … -0.15392978550987516 0.14061616801542676; -0.9846424350897656 -0.4212137157112401 … 0.32560012500937846 -0.432274471742516; -15.194790567151385 -12.533751629126279 … -0.2521016041391559 -2.55454306753343],)), (reservoir = (cell = (rng = Random.MersenneTwister(17, (0, 223340, 222338, 61)),), carry = (Float32[0.79934967; 0.9464035; … ; -0.17016153; -0.5525689;;],)), states_modifiers = (NamedTuple(),), readout = NamedTuple(), states = Float32[0.53167844 0.3456088 … 0.8024428 0.79934967; 0.7672529 0.9781053 … 0.9548178 0.9464035; … ; 0.1935248 -0.012674722 … -0.22481114 -0.18306246; -0.97238624 -0.07230291 … -0.6005286 -0.5525689]))

Now that we have a trained model, we need to save its parameters and states, as well as the hyperparameters that define it. The hyperparameters are not stored in ps or st, so we collect them in an additional NamedTuple:

spec = (
  in_size = 3,
  res_size = 300,
  out_size = 3,
  radius = 1.2,
  sparsity = 6/300,
  state_modifiers = :NLAT2,  # stored by name, resolved back to NLAT2 on load
  # also record any other non-default keyword arguments you used, e.g.
  # leak_coefficient = 1.0, input_scaling = 0.1, use_bias=false, etc.
)
(in_size = 3, res_size = 300, out_size = 3, radius = 1.2, sparsity = 0.02, state_modifiers = :NLAT2)
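Because the values in spec must match the constructor call exactly, one option is to build the ESN from spec in both places, when training and after loading, so the two cannot drift apart. The helper build_esn below is hypothetical (it is not part of ReservoirComputing.jl); it is a sketch that only reuses the ESN, rand_sparse, and getfield calls shown on this page, and it would need extending with any other keyword arguments you record in spec.

```julia
using ReservoirComputing

# Hypothetical helper, not part of ReservoirComputing.jl: rebuilds the ESN
# from the hyperparameter NamedTuple `spec`, mirroring the constructor above.
function build_esn(spec)
    return ESN(spec.in_size, spec.res_size, spec.out_size;
        init_reservoir=rand_sparse(; radius=spec.radius, sparsity=spec.sparsity),
        # spec stores the modifier's name as a Symbol; look the function up by name
        state_modifiers=getfield(ReservoirComputing, spec.state_modifiers))
end
```

With this helper, training would start from `esn = build_esn(spec)`, and the loading step would rebuild the model with the same call.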

We can now save ps, st, and spec to a single file:

using JLD2
@save "esn_trained.jld2" ps st spec
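The @save and @load macros are convenient at the top level. JLD2 also offers function-based equivalents, jldsave for writing and jldopen for reading, which are handy when you want to choose the variable names yourself, for example to compare the loaded values against the in-memory ones before overwriting them. The sketch below is self-contained: small placeholder NamedTuples stand in for the trained ps, st, and spec, and the temporary file path is hypothetical.

```julia
using JLD2

# Placeholder stand-ins for the trained ps, st and spec (illustration only)
ps = (reservoir = (input_matrix = rand(Float32, 4, 3),),
      readout = (weight = rand(Float32, 3, 4),))
st = (states = rand(Float32, 4, 10),)
spec = (in_size = 3, res_size = 4, out_size = 3, state_modifiers = :NLAT2)

# Hypothetical file location in a temporary directory
file = joinpath(mktempdir(), "esn_trained.jld2")

# Function form of `@save`: each keyword becomes a named dataset in the file
jldsave(file; ps, st, spec)

# Read each dataset back by name; the caller chooses the variable names
ps2, st2, spec2 = jldopen(file, "r") do f
    f["ps"], f["st"], f["spec"]
end
```

Checking `ps2 == ps` (and likewise for st and spec) is a quick way to confirm that the round trip was lossless.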

To load the model and use it, we again rely on JLD2.jl, this time with the @load macro:

# In a new Julia session, run `using JLD2, ReservoirComputing` first
@load "esn_trained.jld2" ps st spec

# Rebuild the same ESN architecture (must match ps structure)
esn = ESN(spec.in_size, spec.res_size, spec.out_size;
          init_reservoir=rand_sparse(; radius=spec.radius, sparsity=spec.sparsity),
          state_modifiers = getfield(ReservoirComputing, spec.state_modifiers))

# Predict with the loaded ps/st (predict_len and test are not stored in
# the file; here they come from the training code above)
output, st = predict(esn, predict_len, ps, st; initialdata=test[:, 1])
(Float32[-1.1813505 -0.837569 … 6.887872 8.51222; 0.77978486 0.65421754 … 12.693543 15.229715; 23.08223 21.868599 … 7.9276137 9.036165], (reservoir = (cell = (rng = Random.MersenneTwister(17, (0, 223340, 222338, 61)),), carry = ([0.9469336511112671; -0.7048989240985005; … ; 0.9114587730213789; -0.5156708007345061;;],)), states_modifiers = (NamedTuple(),), readout = NamedTuple()))
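When rebuilding the model, the Symbol stored in spec.state_modifiers is turned back into a function with getfield. If the name is not defined in the module (for example because of a typo in spec), the lookup fails with a generic error. A small hypothetical helper, resolve_symbol, can check with isdefined first and report which name is missing. The demo below uses Base in place of ReservoirComputing so that it runs with only the standard library.

```julia
# Hypothetical helper: turn a Symbol stored in a spec back into the function
# of that name in module `mod`, with a clear error if the name is missing.
function resolve_symbol(mod::Module, name::Symbol)
    isdefined(mod, name) || error("`$name` is not defined in module $mod")
    return getfield(mod, name)
end

# Stdlib-only demo: Base stands in for ReservoirComputing and :sum for :NLAT2
f = resolve_symbol(Base, :sum)
f([1, 2, 3])  # 6
```

In the rebuild step, this would read `state_modifiers = resolve_symbol(ReservoirComputing, spec.state_modifiers)`.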