# Deep Echo State Networks
In this example we showcase how to build a deep echo state network (DeepESN), following Gallicchio and Micheli (2017). A DeepESN stacks reservoirs on top of each other, feeding the states of each reservoir as input to the next. In the version implemented in ReservoirComputing.jl, the states of the final reservoir in the stack are the ones used to train the readout.
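As a rough sketch of the stacking idea, the following illustrative Julia snippet runs an input series through a chain of simple tanh reservoirs, where each layer's state drives the next layer. This is not the library's implementation; the function name, weight scalings, and layer sizes are arbitrary choices for the example.

```julia
using Random

# Illustrative deep reservoir (NOT ReservoirComputing.jl internals):
# layer 1 is driven by the input, layer l+1 by the state of layer l.
function deep_reservoir_states(input; layer_sizes=(30, 30), rng=MersenneTwister(0))
    dims = (size(input, 1), layer_sizes...)
    # random input and recurrent weights per layer (no spectral scaling here)
    Wins = [0.1 .* randn(rng, dims[l + 1], dims[l]) for l in 1:length(layer_sizes)]
    Ws = [0.1 .* randn(rng, dims[l + 1], dims[l + 1]) for l in 1:length(layer_sizes)]
    states = [zeros(n) for n in layer_sizes]
    for t in axes(input, 2)
        layer_in = input[:, t]
        for l in 1:length(states)
            states[l] = tanh.(Wins[l] * layer_in .+ Ws[l] * states[l])
            layer_in = states[l]   # stacking: this state drives layer l + 1
        end
    end
    return states[end]             # the last layer's state feeds the readout
end
```

For instance, `deep_reservoir_states(randn(3, 100))` returns a 30-dimensional state vector from the second (final) layer.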
## Lorenz Example
We are going to reuse the Lorenz data used in the Lorenz System Forecasting example.
```julia
using OrdinaryDiffEq

# define the Lorenz system
function lorenz!(du, u, p, t)
    du[1] = 10.0 * (u[2] - u[1])
    du[2] = u[1] * (28.0 - u[3]) - u[2]
    du[3] = u[1] * u[2] - (8 / 3) * u[3]
end

# solve and collect the data
prob = ODEProblem(lorenz!, [1.0, 0.0, 0.0], (0.0, 200.0))
data = solve(prob, ABM54(); dt=0.02)
data = reduce(hcat, data.u)
```
```julia
# determine shift length, training length and prediction length
shift = 300
train_len = 5000
predict_len = 1250

# split the data accordingly
input_data = data[:, shift:(shift + train_len - 1)]
target_data = data[:, (shift + 1):(shift + train_len)]
test_data = data[:, (shift + train_len + 1):(shift + train_len + predict_len)]
```

```
3×1250 Matrix{Float64}:
  5.97702   6.52739   7.12578   7.77037  …  -1.11904  -1.33419  -1.5963
  8.6124    9.39855  10.2371   11.0985      -2.09217  -2.51955  -3.04071
 20.2622   20.3067   20.5579   21.0389       8.8609    8.45571   8.09576
```

The call for the DeepESN works similarly to the one for the ESN. The only difference is that the reservoir (and the corresponding kwargs) can be fed as an array, with one entry per layer.
```julia
using ReservoirComputing

input_size = 3
res_size = 300
desn = DeepESN(input_size, [res_size, res_size], input_size;
    init_reservoir=rand_sparse(; radius=1.2, sparsity=6 / 300),
    state_modifiers=[NLAT2, ExtendedSquare])
```

The training and prediction follow the usual framework:
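Since the reservoir initializer can be fed as an array, each layer can also be configured individually. A hedged sketch (the pairing of array entries to layers follows the array form described above; the sizes, radii, and sparsities here are arbitrary illustrative values):

```julia
using ReservoirComputing

# One reservoir initializer per layer: a larger first layer with a higher
# spectral radius, and a smaller, more contractive second layer.
input_size = 3
desn_mixed = DeepESN(input_size, [300, 200], input_size;
    init_reservoir=[rand_sparse(; radius=1.2, sparsity=6 / 300),
        rand_sparse(; radius=0.9, sparsity=6 / 200)],
    state_modifiers=[NLAT2, ExtendedSquare])
```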
```julia
using Random

Random.seed!(42)
rng = MersenneTwister(17)
ps, st = setup(rng, desn)
ps, st = train!(desn, input_data, target_data, ps, st)
output, st = predict(desn, predict_len, ps, st; initialdata=test_data[:, 1])
```

Plotting the results:
```julia
using Plots

ts = 0.0:0.02:200.0
lorenz_maxlyap = 0.9056
predict_ts = ts[(shift + train_len + 1):(shift + train_len + predict_len)]
lyap_time = (predict_ts .- predict_ts[1]) * (1 / lorenz_maxlyap)

p1 = plot(lyap_time, [test_data[1, :] output[1, :]]; label=["actual" "predicted"],
    ylabel="x(t)", linewidth=2.5, xticks=false, yticks=-15:15:15);
p2 = plot(lyap_time, [test_data[2, :] output[2, :]]; label=["actual" "predicted"],
    ylabel="y(t)", linewidth=2.5, xticks=false, yticks=-20:20:20);
p3 = plot(lyap_time, [test_data[3, :] output[3, :]]; label=["actual" "predicted"],
    ylabel="z(t)", linewidth=2.5, xlabel="max(λ)*t", yticks=10:15:40);
plot(p1, p2, p3; plot_title="Lorenz System Coordinates",
    layout=(3, 1), xtickfontsize=12, ytickfontsize=12, xguidefontsize=15,
    yguidefontsize=15, legendfontsize=12, titlefontsize=20)
```

## References
- Gallicchio, C. and Micheli, A. (2017). Deep echo state network (DeepESN): A brief survey. arXiv preprint arXiv:1712.04323.