Deep Echo State Networks
In this example we showcase how to build a deep echo state network (DeepESN) following Gallicchio and Micheli (2017). The DeepESN stacks reservoirs on top of each other, feeding the states of each reservoir as input to the next. In the version implemented in ReservoirComputing.jl, the states of the last reservoir in the stack are the ones used to train the readout.
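Conceptually, at each time step the first reservoir reads the external input, while every deeper reservoir reads the freshly updated state of the layer below it. A minimal sketch of this cascade (with illustrative names Wins and Ws for the per-layer input and reservoir matrices; this is not the library's internal implementation):
function deep_esn_step!(states, input, Wins, Ws)
    layer_input = input
    for l in eachindex(states)
        # each layer mixes its incoming signal with its own previous state
        states[l] = tanh.(Wins[l] * layer_input .+ Ws[l] * states[l])
        # the updated state becomes the input of the next layer
        layer_input = states[l]
    end
    return states  # states[end] feeds the trainable readout
end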
Lorenz Example
We are going to reuse the Lorenz data used in the Lorenz System Forecasting example.
using OrdinaryDiffEq
# define the Lorenz system
function lorenz!(du, u, p, t)
    du[1] = 10.0 * (u[2] - u[1])
    du[2] = u[1] * (28.0 - u[3]) - u[2]
    du[3] = u[1] * u[2] - (8 / 3) * u[3]
end
# solve and collect the data
prob = ODEProblem(lorenz!, [1.0, 0.0, 0.0], (0.0, 200.0))
data = solve(prob, ABM54(); dt=0.02)
data = reduce(hcat, data.u)
# determine shift length, training length and prediction length
shift = 300
train_len = 5000
predict_len = 1250
# split the data accordingly
input_data = data[:, shift:(shift + train_len - 1)]
target_data = data[:, (shift + 1):(shift + train_len)]
test_data = data[:, (shift + train_len + 1):(shift + train_len + predict_len)]
The call for the DeepESN works similarly to the ESN one. The only difference is that the reservoir sizes (and the corresponding kwargs) can be fed as an array, with one entry per layer.
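For instance, a hypothetical three-layer stack would pass one reservoir size per entry (illustrative sizes; same constructor pattern as the two-layer network built below):
using ReservoirComputing
# hypothetical three-layer DeepESN with widening reservoirs
desn3 = DeepESN(3, [100, 200, 300], 3)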
using ReservoirComputing
input_size = 3
res_size = 300
desn = DeepESN(input_size, [res_size, res_size], input_size;
    init_reservoir=rand_sparse(; radius=1.2, sparsity=6/300),
    state_modifiers=[NLAT2, ExtendedSquare])
Here two reservoirs of 300 units each are stacked, with one state modifier per layer: NLAT2 on the first reservoir and ExtendedSquare on the second. Training and prediction then follow the usual framework:
using Random
Random.seed!(42)
rng = MersenneTwister(17)
ps, st = setup(rng, desn)
ps, st = train!(desn, input_data, target_data, ps, st)
output, st = predict(desn, predict_len, ps, st; initialdata=test_data[:, 1])
Plotting the results:
using Plots
ts = 0.0:0.02:200.0
lorenz_maxlyap = 0.9056
predict_ts = ts[(shift + train_len + 1):(shift + train_len + predict_len)]
# rescale the time axis to units of Lyapunov time (λ_max ≈ 0.9056 for the Lorenz system)
lyap_time = (predict_ts .- predict_ts[1]) * (1 / lorenz_maxlyap)
p1 = plot(lyap_time, [test_data[1, :] output[1, :]]; label=["actual" "predicted"],
    ylabel="x(t)", linewidth=2.5, xticks=false, yticks=-15:15:15);
p2 = plot(lyap_time, [test_data[2, :] output[2, :]]; label=["actual" "predicted"],
    ylabel="y(t)", linewidth=2.5, xticks=false, yticks=-20:20:20);
p3 = plot(lyap_time, [test_data[3, :] output[3, :]]; label=["actual" "predicted"],
    ylabel="z(t)", linewidth=2.5, xlabel="max(λ)*t", yticks=10:15:40);
plot(p1, p2, p3; plot_title="Lorenz System Coordinates",
    layout=(3, 1), xtickfontsize=12, ytickfontsize=12, xguidefontsize=15,
    yguidefontsize=15, legendfontsize=12, titlefontsize=20)
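Beyond the visual comparison, a quick quantitative check is the per-coordinate root mean squared error over the prediction window (a minimal sketch in plain Julia; for a chaotic system only the short-term error is meaningful, since predicted and actual trajectories eventually diverge regardless of model quality):
using Statistics
# per-coordinate RMSE between the predicted and actual trajectories
rmse = sqrt.(mean(abs2, output .- test_data; dims=2))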
References
- Gallicchio, C. and Micheli, A. (2017). Deep Echo State Network (DeepESN): A Brief Survey. arXiv preprint arXiv:1712.04323.