Deep Echo State Networks

In this example we showcase how to build a deep echo state network (DeepESN) following the work of Gallicchio and Micheli (2017). A DeepESN stacks reservoirs on top of each other, feeding the states of each layer as input to the next. In the version implemented in ReservoirComputing.jl, the states of the last reservoir in the stack are the ones used to train the readout.
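Conceptually, each layer updates like a single ESN cell, with the previous layer's state acting as its external input. Below is a minimal sketch of the recurrence, with hypothetical per-layer matrices Win[l] and W[l] (the library's actual cells also support leak rates, bias, and custom initializers):

# Sketch of one deep reservoir update step.
# Win, W: vectors of hypothetical per-layer input and recurrent matrices;
# xs: vector of per-layer state vectors; u: the external input.
function deep_step(Win, W, xs, u)
    input = u                       # the first layer sees the external input
    for l in eachindex(xs)
        xs[l] = tanh.(Win[l] * input .+ W[l] * xs[l])
        input = xs[l]               # deeper layers see the previous layer's state
    end
    return xs                       # xs[end] feeds the trained readout
end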

Lorenz Example

We are going to reuse the Lorenz data from the Lorenz System Forecasting example.

using OrdinaryDiffEq

#define lorenz system
function lorenz!(du, u, p, t)
    du[1] = 10.0 * (u[2] - u[1])
    du[2] = u[1] * (28.0 - u[3]) - u[2]
    du[3] = u[1] * u[2] - (8 / 3) * u[3]
end

#solve and take data
prob = ODEProblem(lorenz!, [1.0, 0.0, 0.0], (0.0, 200.0))
data = solve(prob, ABM54(); dt=0.02)
data = reduce(hcat, data.u)
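Each column of data is now one sampled state of the system; with a fixed step of dt = 0.02 over t ∈ [0, 200], the trajectory should hold 10001 columns:

size(data)  # expected (3, 10001): one column per saved time step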

#determine shift length, training length and prediction length
shift = 300
train_len = 5000
predict_len = 1250

#split the data accordingly
input_data = data[:, shift:(shift + train_len - 1)]
target_data = data[:, (shift + 1):(shift + train_len)]
test_data = data[:, (shift + train_len + 1):(shift + train_len + predict_len)]
3×1250 Matrix{Float64}:
  4.2235    3.83533   3.55846   3.38262  …   9.44212  10.0034  10.5085
  1.98782   2.18137   2.43519   2.7396      12.3382   12.6956  12.8396
 25.6556   24.4858   23.3794   22.3394      24.3255   25.4374  26.6729
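target_data is input_data advanced by one step, so the readout is trained on a one-step-ahead map. A quick sanity check on the alignment (assertions added here for illustration):

@assert size(input_data) == size(target_data) == (3, train_len)
@assert input_data[:, 2] == target_data[:, 1]  # both equal data[:, shift + 1]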

Constructing a DeepESN works similarly to constructing an ESN. The only difference is that the reservoir sizes (and the corresponding kwargs, such as state_modifiers) can be fed as arrays, with one entry per layer.

using ReservoirComputing
input_size = 3
res_size = 300
desn = DeepESN(input_size, [res_size, res_size], input_size;
    init_reservoir=rand_sparse(; radius=1.2, sparsity=6/300),
    state_modifiers=[NLAT2, ExtendedSquare]
)
The printed model summarizes the stack: two ESNCells (3 => 300 and 300 => 300), the per-layer state modifiers (NLAT2, then ExtendedSquare), and a LinearReadout(300 => 3, use_bias=false, include_collect=true).
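Because the settings are given per layer, the stack does not have to be uniform. A hypothetical variant with differently sized reservoirs (a sketch, assuming the array form accepts per-layer sizes; not run here):

# narrower first layer feeding a wider second one
desn_mixed = DeepESN(input_size, [200, 300], input_size;
    init_reservoir=rand_sparse(; radius=1.2, sparsity=6/300))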

The training and prediction follow the usual framework:

using Random
rng = MersenneTwister(17)  # fixed seed for reproducible initialization

ps, st = setup(rng, desn)
ps, st = train!(desn, input_data, target_data, ps, st)

output, st = predict(desn, predict_len, ps, st; initialdata=test_data[:, 1])
predict returns the 3×1250 matrix of forecasted coordinates together with the updated layer states.
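To quantify how long the forecast tracks the true trajectory, one can compare it against the held-out data, for instance with a per-step root mean squared error (a small sketch; rmse is a local variable defined here, not a library function):

using Statistics
rmse = vec(sqrt.(mean(abs2, output .- test_data; dims=1)))  # one value per time step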

Plotting the results:

using Plots

ts = 0.0:0.02:200.0
lorenz_maxlyap = 0.9056
predict_ts = ts[(shift + train_len + 1):(shift + train_len + predict_len)]
# express elapsed time in Lyapunov times, max(λ)*t, with λ_max ≈ 0.9056 for Lorenz
lyap_time = (predict_ts .- predict_ts[1]) .* lorenz_maxlyap
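With predict_len = 1250 steps at dt = 0.02, the forecast window spans 24.98 time units, i.e. roughly 24.98 * 0.9056 ≈ 22.6 Lyapunov times:

last(lyap_time)  # ≈ 22.6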

p1 = plot(lyap_time, [test_data[1, :] output[1, :]]; label=["actual" "predicted"],
    ylabel="x(t)", linewidth=2.5, xticks=false, yticks=-15:15:15);
p2 = plot(lyap_time, [test_data[2, :] output[2, :]]; label=["actual" "predicted"],
    ylabel="y(t)", linewidth=2.5, xticks=false, yticks=-20:20:20);
p3 = plot(lyap_time, [test_data[3, :] output[3, :]]; label=["actual" "predicted"],
    ylabel="z(t)", linewidth=2.5, xlabel="max(λ)*t", yticks=10:15:40);

plot(p1, p2, p3; plot_title="Lorenz System Coordinates",
    layout=(3, 1), xtickfontsize=12, ytickfontsize=12, xguidefontsize=15,
    yguidefontsize=15,
    legendfontsize=12, titlefontsize=20)
(Figure: actual vs. predicted x(t), y(t), z(t) of the Lorenz system, plotted against Lyapunov time.)

References

  • Gallicchio, C. and Micheli, A. (2017). Deep Echo State Network (DeepESN): A Brief Survey. arXiv preprint arXiv:1712.04323.