Deep Echo State Networks

In this example we showcase how to build a deep echo state network (DeepESN), following Gallicchio and Micheli (2017). A DeepESN stacks reservoirs on top of each other, feeding the states of each reservoir as input to the next one. In the version implemented in ReservoirComputing.jl, the states of the final reservoir are the ones used for training the readout.
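Schematically (adapting the paper's notation, and ignoring leaky integration and biases, which the actual layers may include), the state of layer ``l`` at time ``t`` is driven by the states of layer ``l - 1``:

```math
\begin{aligned}
\mathbf{x}^{(1)}(t) &= \tanh\!\left(\mathbf{W}^{(1)}_{\text{in}}\,\mathbf{u}(t) + \hat{\mathbf{W}}^{(1)}\,\mathbf{x}^{(1)}(t-1)\right)\\
\mathbf{x}^{(l)}(t) &= \tanh\!\left(\mathbf{W}^{(l)}_{\text{in}}\,\mathbf{x}^{(l-1)}(t) + \hat{\mathbf{W}}^{(l)}\,\mathbf{x}^{(l)}(t-1)\right), \quad l > 1
\end{aligned}
```

Here ``\mathbf{u}(t)`` is the external input, ``\mathbf{W}^{(l)}_{\text{in}}`` the input weights of layer ``l``, and ``\hat{\mathbf{W}}^{(l)}`` its recurrent reservoir matrix.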

Lorenz Example

We are going to reuse the Lorenz data used in the Lorenz System Forecasting example.

using OrdinaryDiffEq

#define the Lorenz system (σ = 10, ρ = 28, β = 8/3)
function lorenz!(du, u, p, t)
    du[1] = 10.0 * (u[2] - u[1])
    du[2] = u[1] * (28.0 - u[3]) - u[2]
    du[3] = u[1] * u[2] - (8 / 3) * u[3]
end

#solve and take data
prob = ODEProblem(lorenz!, [1.0, 0.0, 0.0], (0.0, 200.0))
data = solve(prob, ABM54(); dt=0.02)
data = reduce(hcat, data.u)

#determine shift length, training length and prediction length
shift = 300
train_len = 5000
predict_len = 1250

#split the data accordingly
input_data = data[:, shift:(shift + train_len - 1)]
target_data = data[:, (shift + 1):(shift + train_len)]
test_data = data[:, (shift + train_len + 1):(shift + train_len + predict_len)]
3×1250 Matrix{Float64}:
  5.97702   6.52739   7.12578   7.77037  …  -1.11904  -1.33419  -1.5963
  8.6124    9.39855  10.2371   11.0985      -2.09217  -2.51955  -3.04071
 20.2622   20.3067   20.5579   21.0389       8.8609    8.45571   8.09576
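The slicing above builds a one-step-ahead dataset: each column of `target_data` is the state one step after the corresponding column of `input_data`. A toy illustration of the same pattern (hypothetical values):

```julia
#stack a vector of state vectors into a variables × time matrix,
#as reduce(hcat, data.u) does above
vs = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
toy = reduce(hcat, vs)
@assert size(toy) == (3, 3)

#one-step-ahead split: targets lead inputs by one time step
toy_input = toy[:, 1:2]
toy_target = toy[:, 2:3]
@assert toy_input[:, 2] == toy_target[:, 1]
```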

The constructor for the DeepESN works similarly to the ESN one. The only difference is that the reservoir sizes (and the corresponding kwargs) can be passed as arrays, with one entry per layer.

using ReservoirComputing
input_size = 3
res_size = 300
desn = DeepESN(input_size, [res_size, res_size], input_size;
    init_reservoir=rand_sparse(; radius=1.2, sparsity=6/300),
    state_modifiers=[NLAT2, ExtendedSquare]
)
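For intuition on the initializer arguments (assuming `sparsity` denotes the fraction of nonzero entries and `radius` the target spectral radius, as their names suggest), `sparsity=6/300` keeps roughly 2% of the reservoir weights:

```julia
#back-of-envelope check on the rand_sparse kwargs used above
res_size = 300
sparsity = 6 / 300                        #≈ 2% of entries are nonzero
nonzeros_per_row = sparsity * res_size
@assert nonzeros_per_row ≈ 6              #about 6 nonzero weights per reservoir row
```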

The training and prediction follow the usual framework:

using Random
rng = MersenneTwister(17)

ps, st = setup(rng, desn)
ps, st = train!(desn, input_data, target_data, ps, st)

output, st = predict(desn, predict_len, ps, st; initialdata=test_data[:, 1])
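During `predict`, the network runs in generative (closed-loop) mode: starting from `initialdata`, every prediction is fed back as the next input. A minimal sketch of the idea (not the library's internal code; `step` is a hypothetical stand-in for one forward pass of the trained model):

```julia
#closed-loop forecasting sketch: feed each prediction back as the next input
function generative_predict(step, u0, n)
    out = zeros(length(u0), n)
    u = u0
    for t in 1:n
        u = step(u)      #one-step-ahead prediction
        out[:, t] = u    #store, then reuse as the next input
    end
    return out
end

doubler(u) = 2 .* u                      #toy one-step map
generative_predict(doubler, [1.0], 3)    # → [2.0 4.0 8.0]
```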

Plotting the results:

using Plots

ts = 0.0:0.02:200.0
lorenz_maxlyap = 0.9056
predict_ts = ts[(shift + train_len + 1):(shift + train_len + predict_len)]
lyap_time = (predict_ts .- predict_ts[1]) .* lorenz_maxlyap #rescale to Lyapunov times, matching the x-axis label

p1 = plot(lyap_time, [test_data[1, :] output[1, :]]; label=["actual" "predicted"],
    ylabel="x(t)", linewidth=2.5, xticks=false, yticks=-15:15:15);
p2 = plot(lyap_time, [test_data[2, :] output[2, :]]; label=["actual" "predicted"],
    ylabel="y(t)", linewidth=2.5, xticks=false, yticks=-20:20:20);
p3 = plot(lyap_time, [test_data[3, :] output[3, :]]; label=["actual" "predicted"],
    ylabel="z(t)", linewidth=2.5, xlabel="max(λ)*t", yticks=10:15:40);

plot(p1, p2, p3; plot_title="Lorenz System Coordinates",
    layout=(3, 1), xtickfontsize=12, ytickfontsize=12, xguidefontsize=15,
    yguidefontsize=15,
    legendfontsize=12, titlefontsize=20)
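The x axis is expressed in Lyapunov times rather than raw time. With the Lorenz system's largest Lyapunov exponent ``\lambda_{\max} \approx 0.9056``, the rescaling is

```math
t_{\lambda} = \lambda_{\max}\,(t - t_0),
```

so one unit on the axis corresponds to one Lyapunov time ``1/\lambda_{\max} \approx 1.1`` time units, the horizon over which small errors grow by roughly a factor of ``e``.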

References

  • Gallicchio, C. and Micheli, A. (2017). Deep Echo State Network (DeepESN): A Brief Survey. arXiv preprint arXiv:1712.04323.