Building a model from scratch

ReservoirComputing.jl provides utilities to build reservoir computing models from scratch. In this tutorial we build an echo state network (ESN) and show that this custom implementation is equivalent to the provided model (minus some convenience utilities).

Using provided layers: ReservoirChain, ESNCell, and LinearReadout

The library provides a ReservoirChain, which is virtually equivalent to Lux's Chain. Layers or functions passed to the chain are composed in order, so the input data flows through them one after another.

To build an ESN we also need an ESNCell to provide the ESN forward pass. However, the cell is stateless, so to keep a memory of the input we need to wrap it in a StatefulLayer, which saves the internal state in the model states st and feeds it back to the cell at the next step.
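
To make the distinction between a stateless cell and its stateful wrapper concrete, here is a minimal sketch in plain Julia. It does not use ReservoirComputing.jl: the names ToyCell, cell_step, and stateful_step are hypothetical, and the update rule is a simplified ESN step without bias or leak, assumed only for illustration.

```julia
# Plain-Julia sketch; not ReservoirComputing.jl code.
# ToyCell, cell_step and stateful_step are hypothetical names.

struct ToyCell
    W_in::Matrix{Float64}   # input weights, size (res_size, in_size)
    W::Matrix{Float64}      # recurrent weights, size (res_size, res_size)
end

# Stateless update: the caller must supply the previous hidden state h.
cell_step(cell::ToyCell, x, h) = tanh.(cell.W_in * x .+ cell.W * h)

# Stateful wrapper: reads the hidden state from `st`, runs the cell,
# and returns the output together with an updated `st`.
function stateful_step(cell::ToyCell, x, st)
    h_new = cell_step(cell, x, st.hidden)
    return h_new, (; hidden = h_new)
end

cell = ToyCell(0.1 .* ones(4, 2), 0.05 .* ones(4, 4))
st0 = (; hidden = zeros(4))

x = [1.0, -0.5]
h1, st1 = stateful_step(cell, x, st0)
h2, st2 = stateful_step(cell, x, st1)  # same input, but the carried state differs
```

Feeding the same input twice gives two different outputs, because the second call starts from the state saved by the first. This carried state is what gives the reservoir its memory.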

Finally, we need the trainable readout of the reservoir computing model. The library provides LinearReadout, a dense layer whose weights are trained using linear regression.
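
As an illustration of what training a readout by linear regression amounts to, the sketch below fits output weights with ridge regression in plain Julia. Ridge regression is one common choice and is assumed here; fit_readout and the reg keyword are hypothetical names, not the library's API.

```julia
using LinearAlgebra
using Random

# Hypothetical helper: fit W_out so that W_out * states ≈ targets.
# states:  (res_size, T) matrix, one column per time step
# targets: (out_size, T) matrix
function fit_readout(states, targets; reg = 1e-8)
    A = states * states' + reg * I       # regularized Gram matrix
    return (targets * states') / A       # solves W_out * A = targets * states'
end

# Synthetic check: targets generated by a known linear map.
rng = MersenneTwister(0)
states = randn(rng, 5, 200)
W_true = randn(rng, 1, 5)
targets = W_true * states

W_out = fit_readout(states, targets)
```

Since the synthetic targets are an exact linear function of the states, the fitted W_out recovers W_true up to the small bias introduced by the regularization term.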

Putting it all together, we get the following:

using ReservoirComputing

esn_scratch = ReservoirChain(
    StatefulLayer(
        ESNCell(3=>50)
    ),
    LinearReadout(50=>1)
)
ReservoirChain{@NamedTuple{layer_1::StatefulLayer{ESNCell{typeof(tanh), Int64, Int64, typeof(zeros32), typeof(rand_sparse), typeof(scaled_rand), typeof(randn32), Float64, Static.False}}, layer_2::Collect, layer_3::LinearReadout{typeof(identity), Int64, Int64, typeof(rand32), typeof(rand32), Static.False, Static.True}}, Nothing}((layer_1 = StatefulLayer(ESNCell(3 => 50, use_bias=false)), layer_2 = Collection point of states, layer_3 = LinearReadout(50 => 1, use_bias=false, include_collect=true)), nothing)

This implementation is, field naming aside, completely equivalent to the following:

esn = ESN(3, 50, 1)
ESN(
    reservoir = StatefulLayer(ESNCell(3 => 50, use_bias=false)),
    state_modifiers = (),
    readout = LinearReadout(50 => 1, use_bias=false, include_collect=true)
)

We can check this by initializing the two models with identically seeded random number generators and comparing, for instance, the weights of the input layer:

using Random
Random.seed!(43)

rng = MersenneTwister(17)
ps_s, st_s = setup(rng, esn_scratch)

rng = MersenneTwister(17)
ps, st = setup(rng, esn)

ps_s.layer_1.input_matrix == ps.reservoir.input_matrix
true
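
The comparison above only works because rng is re-created with the same seed before each call to setup. The following sketch, which uses only Julia's Random standard library, shows the difference between two fresh generators with the same seed and one generator reused for two draws.

```julia
using Random

# Two generators built from the same seed yield the same stream.
a = rand(MersenneTwister(17), 3, 3)
b = rand(MersenneTwister(17), 3, 3)

# Reusing one generator advances its stream, so consecutive draws differ.
rng = MersenneTwister(17)
c = rand(rng, 3, 3)
d = rand(rng, 3, 3)

same_seed_matches = (a == b)   # fresh generators, same seed
reused_rng_matches = (c == d)  # one generator, consecutive draws
```

Here same_seed_matches is true while reused_rng_matches is false, which is why a single shared rng passed to both setup calls would not produce matching weights.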

Both models can be trained using train!, and predictions can be obtained with predict. The internal states collected for linear regression are computed by traversing the ReservoirChain and stopping right before the LinearReadout.

Manual state collection with Collect

For more complicated models you usually want to control where the state collection happens. In a ReservoirChain, the collection of states is controlled by the Collect layer. The role of this layer is to tell the collectstates function where to stop for state collection. All the readout layers have an include_collect=true keyword, which forces a Collect layer before the readout. The model we wrote before can be written as

esn_scratch = ReservoirChain(
    StatefulLayer(
        ESNCell(3=>50)
    ),
    Collect(),
    LinearReadout(50=>1; include_collect=false)
)
ReservoirChain{@NamedTuple{layer_1::StatefulLayer{ESNCell{typeof(tanh), Int64, Int64, typeof(zeros32), typeof(rand_sparse), typeof(scaled_rand), typeof(randn32), Float64, Static.False}}, layer_2::Collect, layer_3::LinearReadout{typeof(identity), Int64, Int64, typeof(rand32), typeof(rand32), Static.False, Static.False}}, Nothing}((layer_1 = StatefulLayer(ESNCell(3 => 50, use_bias=false)), layer_2 = Collection point of states, layer_3 = LinearReadout(50 => 1, use_bias=false)), nothing)

to make the collection explicit. This layer is useful in case one needs to build more complicated models such as a DeepESN. We can build a deep model in multiple ways:

deepesn_scratch = ReservoirChain(
    StatefulLayer(
        ESNCell(3=>50)
    ),
    StatefulLayer(
        ESNCell(50=>50)
    ),
    StatefulLayer(
        ESNCell(50=>50)
    ),
    Collect(),
    LinearReadout(50=>1; include_collect=false)
)
ReservoirChain{@NamedTuple{layer_1::StatefulLayer{ESNCell{typeof(tanh), Int64, Int64, typeof(zeros32), typeof(rand_sparse), typeof(scaled_rand), typeof(randn32), Float64, Static.False}}, layer_2::StatefulLayer{ESNCell{typeof(tanh), Int64, Int64, typeof(zeros32), typeof(rand_sparse), typeof(scaled_rand), typeof(randn32), Float64, Static.False}}, layer_3::StatefulLayer{ESNCell{typeof(tanh), Int64, Int64, typeof(zeros32), typeof(rand_sparse), typeof(scaled_rand), typeof(randn32), Float64, Static.False}}, layer_4::Collect, layer_5::LinearReadout{typeof(identity), Int64, Int64, typeof(rand32), typeof(rand32), Static.False, Static.False}}, Nothing}((layer_1 = StatefulLayer(ESNCell(3 => 50, use_bias=false)), layer_2 = StatefulLayer(ESNCell(50 => 50, use_bias=false)), layer_3 = StatefulLayer(ESNCell(50 => 50, use_bias=false)), layer_4 = Collection point of states, layer_5 = LinearReadout(50 => 1, use_bias=false)), nothing)

This first approach is the one provided by default in the library through DeepESN. However, you may want the state collection to happen after each cell:

deepesn_scratch = ReservoirChain(
    StatefulLayer(
        ESNCell(3=>50)
    ),
    Collect(),
    StatefulLayer(
        ESNCell(50=>50)
    ),
    Collect(),
    StatefulLayer(
        ESNCell(50=>50)
    ),
    Collect(),
    LinearReadout(50=>1; include_collect=false)
)
ReservoirChain{@NamedTuple{layer_1::StatefulLayer{ESNCell{typeof(tanh), Int64, Int64, typeof(zeros32), typeof(rand_sparse), typeof(scaled_rand), typeof(randn32), Float64, Static.False}}, layer_2::Collect, layer_3::StatefulLayer{ESNCell{typeof(tanh), Int64, Int64, typeof(zeros32), typeof(rand_sparse), typeof(scaled_rand), typeof(randn32), Float64, Static.False}}, layer_4::Collect, layer_5::StatefulLayer{ESNCell{typeof(tanh), Int64, Int64, typeof(zeros32), typeof(rand_sparse), typeof(scaled_rand), typeof(randn32), Float64, Static.False}}, layer_6::Collect, layer_7::LinearReadout{typeof(identity), Int64, Int64, typeof(rand32), typeof(rand32), Static.False, Static.False}}, Nothing}((layer_1 = StatefulLayer(ESNCell(3 => 50, use_bias=false)), layer_2 = Collection point of states, layer_3 = StatefulLayer(ESNCell(50 => 50, use_bias=false)), layer_4 = Collection point of states, layer_5 = StatefulLayer(ESNCell(50 => 50, use_bias=false)), layer_6 = Collection point of states, layer_7 = LinearReadout(50 => 1, use_bias=false)), nothing)

With this approach, the resulting state is the concatenation of the states at each Collect point. With three Collect points, each after a cell of size 50, the resulting state for this architecture is a vector of size 150.

ps, st = setup(rng, deepesn_scratch)
states, st = collectstates(deepesn_scratch, rand(3, 300), ps, st)
size(states[:,1])
(150,)

This allows for even more complex constructions, where the state collection follows specific patterns:

deepesn_scratch = ReservoirChain(
    StatefulLayer(
        ESNCell(3=>50)
    ),
    StatefulLayer(
        ESNCell(50=>50)
    ),
    Collect(),
    StatefulLayer(
        ESNCell(50=>50)
    ),
    Collect(),
    LinearReadout(50=>1; include_collect=false)
)
ReservoirChain{@NamedTuple{layer_1::StatefulLayer{ESNCell{typeof(tanh), Int64, Int64, typeof(zeros32), typeof(rand_sparse), typeof(scaled_rand), typeof(randn32), Float64, Static.False}}, layer_2::StatefulLayer{ESNCell{typeof(tanh), Int64, Int64, typeof(zeros32), typeof(rand_sparse), typeof(scaled_rand), typeof(randn32), Float64, Static.False}}, layer_3::Collect, layer_4::StatefulLayer{ESNCell{typeof(tanh), Int64, Int64, typeof(zeros32), typeof(rand_sparse), typeof(scaled_rand), typeof(randn32), Float64, Static.False}}, layer_5::Collect, layer_6::LinearReadout{typeof(identity), Int64, Int64, typeof(rand32), typeof(rand32), Static.False, Static.False}}, Nothing}((layer_1 = StatefulLayer(ESNCell(3 => 50, use_bias=false)), layer_2 = StatefulLayer(ESNCell(50 => 50, use_bias=false)), layer_3 = Collection point of states, layer_4 = StatefulLayer(ESNCell(50 => 50, use_bias=false)), layer_5 = Collection point of states, layer_6 = LinearReadout(50 => 1, use_bias=false)), nothing)

Here, for instance, we have a Collect after the first two cells and then one at the very end. You can see how the size of the states is now 100:

ps, st = setup(rng, deepesn_scratch)
states, st = collectstates(deepesn_scratch, rand(3, 300), ps, st)
size(states[:,1])
(100,)
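
The state sizes reported above follow directly from concatenating a snapshot at every Collect point. The sketch below mimics that behaviour in plain Julia. It is not the library's implementation: collect_features, make_cell, and the :collect marker are hypothetical stand-ins, and the cells here are memoryless random maps used only to track sizes.

```julia
# Plain-Julia sketch; not ReservoirComputing.jl code.
# collect_features, make_cell and the :collect marker are hypothetical.

# Walk through the layers; at each :collect marker, snapshot the current
# features. The result is the concatenation of all snapshots.
function collect_features(layers, x)
    snapshots = Vector{Vector{Float64}}()
    current = x
    for layer in layers
        if layer === :collect
            push!(snapshots, copy(current))
        else
            current = layer(current)
        end
    end
    return reduce(vcat, snapshots)
end

# Stand-in for a reservoir cell: a fixed random map followed by tanh.
function make_cell(in_size, out_size)
    W = randn(out_size, in_size)
    return x -> tanh.(W * x)
end

# Collect after every cell: 3 snapshots of 50 features each.
every_cell = [make_cell(3, 50), :collect,
              make_cell(50, 50), :collect,
              make_cell(50, 50), :collect]

# Collect after the second and third cells only: 2 snapshots of 50 features.
patterned = [make_cell(3, 50), make_cell(50, 50), :collect,
             make_cell(50, 50), :collect]

x = rand(3)
size(collect_features(every_cell, x))   # (150,)
size(collect_features(patterned, x))    # (100,)
```

The number of collected features is the sum of the output sizes of the layers that immediately precede each collection point, regardless of how many layers sit between them.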

Similar approaches could be leveraged, for instance, when the data show multiscale dynamics that require specific modeling approaches.