Linear State Space Examples

This tutorial describes the support for linear and linear Gaussian state-space models.

At this point, the package only supports linear time-invariant models without a separate p vector. The canonical form of the linear model is

\[u_{n+1} = A u_n + B w_{n+1}\]

with

\[z_n = C u_n + v_n\]

and optionally $v_n \sim N(0, D)$ and $w_{n+1} \sim N(0,I)$. If you pass a noise sequence into the solver directly, it no longer needs to be Gaussian. More generally, support could be added for $u_{n+1} = A(p,n) u_n + B(p,n) w_{n+1}$, where $p$ is a vector of differentiable parameters and $A$ and $B$ are potentially matrix-free operators.
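
To make the canonical form concrete, here is a minimal sketch of the recursion (the function name simulate_canonical is purely illustrative and is not part of the package), treating D as the vector of observation noise standard deviations as in the examples below:

function simulate_canonical(A, B, C, D, u0, T)
    u = [copy(u0)]
    for n in 1:T
        w = randn(size(B, 2))             # w_{n+1} ~ N(0, I)
        push!(u, A * u[end] + B * w)      # u_{n+1} = A u_n + B w_{n+1}
    end
    z = [C * u_n + D .* randn(length(D)) for u_n in u]  # z_n = C u_n + v_n
    return u, z
end

The package performs this iteration (and provides likelihoods and gradients) through the SciML problem/solve interface shown below.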

Simulating a Linear (and Time-Invariant) State Space Model

We create a LinearStateSpaceProblem and simulate it for a simple linear system.

using DifferenceEquations, LinearAlgebra, Distributions, Random, Plots, DataFrames, Zygote
A = [0.95 6.2;
     0.0  0.2]
B = [0.0; 0.01;;] # matrix
C = [0.09 0.67;
     1.00 0.00]
D = [0.1, 0.1] # diagonal observation noise (standard deviations)
u0 = zeros(2)
T = 10

prob = LinearStateSpaceProblem(A, B, u0, (0, T); C, observables_noise = D, syms = [:a, :b])
sol = solve(prob)
retcode: Success
Interpolation: Piecewise constant interpolation
t: 0:10
u: 11-element Vector{Vector{Float64}}:
 [0.0, 0.0]
 [0.0, -0.0018955008044782915]
 [-0.011752104987765409, 0.001975062708655071]
 [0.0010808890552843015, 0.010838927462609196]
 [0.0682281948706971, -0.005045948985828016]
 [0.03353190141502854, 0.006904686748216834]
 [0.07466436418322149, -0.004910453395279099]
 [0.04048633492333, -0.011688871349785876]
 [-0.03400898419150893, 0.0026179280637263695]
 [-0.016077380986829995, -0.0053803072146355385]
 [-0.04863141666822884, 0.007840386028350653]

The u vector of the simulated solution can be plotted using the standard plot recipes, which use the optional syms as labels. See the SciML docs for more options.

plot(sol)

By default, the solution provides an interface to access the simulated u directly. That is, sol[i] is equivalent to sol.u[i],

sol[2]
2-element Vector{Float64}:
  0.0
 -0.0018955008044782915

Or to get the first element of the last step

sol[end][1] #first element of last step
-0.04863141666822884

Finally, to extract the time series of the second state variable across all time steps

@show sol[2,:];  # time series of the 2nd state variable
11-element Vector{Float64}:
  0.0
 -0.0018955008044782915
  0.001975062708655071
  0.010838927462609196
 -0.005045948985828016
  0.006904686748216834
 -0.004910453395279099
 -0.011688871349785876
  0.0026179280637263695
 -0.0053803072146355385
  0.007840386028350653

The results for all of sol.u can be loaded into a DataFrame, where the column names will be the (optionally) provided symbols.

df = DataFrame(sol)
11×3 DataFrame
 Row │ timestamp  a            b
     │ Int64      Float64      Float64
─────┼─────────────────────────────────────
   1 │         0   0.0          0.0
   2 │         1   0.0         -0.0018955
   3 │         2  -0.0117521    0.00197506
   4 │         3   0.00108089   0.0108389
   5 │         4   0.0682282   -0.00504595
   6 │         5   0.0335319    0.00690469
   7 │         6   0.0746644   -0.00491045
   8 │         7   0.0404863   -0.0116889
   9 │         8  -0.034009     0.00261793
  10 │         9  -0.0160774   -0.00538031
  11 │        10  -0.0486314    0.00784039

Other results, such as the simulated noise and observables, can be extracted from the solution

sol.z # observables
11-element Vector{Vector{Float64}}:
 [-0.42362758984593496, -0.18915896622033826]
 [0.056426818193605344, 0.0660606776797671]
 [0.35166309834204573, -0.39786622148073036]
 [0.02233069374320561, -0.0024599762677543144]
 [0.8245642137097527, -0.1323868799943248]
 [0.3049565742681066, 0.039783150826104736]
 [-0.27270196682461506, 0.7834822960784472]
 [-0.16482833346905001, 0.267826721269844]
 [-0.0035392487852325383, 0.1337287494647419]
 [-0.048826632584235835, -0.22552264014368534]
 [0.28127579932039015, -0.12684506075995572]
sol.W # Simulated Noise
1×10 Matrix{Float64}:
 -0.18955  0.235416  1.04439  -0.721373  …  0.49557  -0.590389  0.891645

We can also solve the model passing in fixed noise, which will be useful for joint likelihoods. First, let's extract the noise from the previous solution, then rerun the simulation with a different initial value

noise = sol.W
u0_2 = [0.1, 0.0]
prob2 = LinearStateSpaceProblem(A, B, u0_2, (0, T); C, observables_noise = D, syms = [:a, :b], noise)
sol2 = solve(prob2)
plot(sol2)

To construct an IRF we can take the model and perturb just the first element of the noise,

function irf(A, B, C, T = 20)
    noise = Matrix([1.0; zeros(T-1)]')
    problem = LinearStateSpaceProblem(A, B, zeros(2), (0, T); C, noise, syms = [:a, :b])
    return solve(problem)
end
plot(irf(A, B, C))

Let's find the 2nd observable at the end of the IRF.

function last_observable_irf(A, B, C)
    sol = irf(A, B, C)
    return sol.z[end][2]  # 2nd element of the last observable
end
last_observable_irf(A, B, C)
0.03119456447624772

But everything in this package is differentiable. Let's differentiate the observable of the IRF with respect to all of the parameters using Zygote.jl,

gradient(last_observable_irf, A, B, C)  # calculates gradient wrt all arguments
([0.5822985368900442 0.0050313813671367304; 4.469834483178462 0.041592752634585235], [0.37735360253530714; 3.119456447624773;;], [0.0 0.0; 0.03119456447624772 5.242880000000006e-16])

Gradients of other model elements (e.g., u) are also possible. With this in mind, let's find the gradient of the mean of the 1st state variable of a simulated solution with respect to a particular noise vector and the initial condition.

function mean_u_1(A, B, C, noise, u0, T)
    problem = LinearStateSpaceProblem(A, B, u0, (0, T); noise, syms = [:a, :b])
    sol = solve(problem)
    u = sol.u # see issue #75 workaround
    # can have nontrivial functions and even non-mutating loops
    return mean( u[i][1] for i in 1:T)
end
u0 = [0.0, 0.0]
noise = sol.W # from simulation above
mean_u_1(A, B, C, noise, u0, T)
# dropping a few arguments from derivative
gradient((noise, u0)-> mean_u_1(A, B, C, noise, u0, T), noise, u0)
([0.05079876954953124 0.045314515146875 … 0.0 0.0], [0.702526121523242, 5.600882710405467])

Simulating Ensembles and Fixing Noise

If you pass in a distribution for the initial condition, the solver will draw the initial condition from it. Below we will simulate a deterministic evolution equation (passing nothing for B) without any observation noise.

using Distributions, DiffEqBase
u0 = MvNormal([1.0 0.1; 0.1 1.0])  # mean zero initial conditions
prob = LinearStateSpaceProblem(A, nothing, u0, (0, T); C)
sol = solve(prob)
plot(sol)

With this, we can simulate an ensemble of solutions from different initial conditions (and we will turn the noise back on). The EnsembleSummary calculates a set of quantiles by default.

T = 10
trajectories = 50
prob = LinearStateSpaceProblem(A, B, u0, (0, T); C)
sol = solve(EnsembleProblem(prob), DirectIteration(), EnsembleThreads(); trajectories)
summ = EnsembleSummary(sol)  # calculate summary statistics from the ensemble
plot(summ)  # shows quantiles by default
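
The quantile bands can also be chosen directly. As a hedged example following the SciML ensemble interface (check the current docs if the keyword differs):

summ_90 = EnsembleSummary(sol; quantiles = [0.05, 0.95])  # 5%/95% bands instead of the defaults
plot(summ_90)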

Observables and Marginal Likelihood using a Kalman Filter

If you provide observables and a distribution for the observables_noise, then the model can calculate the likelihood.

The simplest case is a Gaussian prior with Gaussian observation noise. First, let's simulate some data with observation noise included. If passing in a matrix or vector, the observables_noise argument is interpreted as the Cholesky factor of the covariance matrix. At this point, only diagonal observation noise is supported.

u0 = MvNormal([1.0 0.1; 0.1 1.0])  # draw from mean zero initial conditions
T = 10
prob = LinearStateSpaceProblem(A, B, u0, (0, T); C, observables_noise = D, syms = [:a, :b])
sol = solve(prob)
sol.z # simulated observables with observation noise
11-element Vector{Vector{Float64}}:
 [0.06286778305703349, 1.100543591813311]
 [-0.14091782937357641, 0.13640316926939153]
 [0.08822906988583239, 0.24417811833626452]
 [0.42915744886974355, 0.21301632294425052]
 [0.004467793307069687, -0.2375589827639179]
 [0.3258963737530623, 0.16269605789768743]
 [-0.13577972310249, 0.5370189250728027]
 [-0.04900078748109346, -0.1953087952983826]
 [-0.24005982270536014, 0.32250436940321126]
 [0.30285303515782946, 0.15812695881086733]
 [0.13866311898691702, 0.344998891660888]

Next, we will find the log likelihood of these simulated observables, using the u0 distribution as a prior and the true parameters.

The new arguments we pass to the problem creation are u0_prior_var, u0_prior_mean, and observables. The u0 is ignored for the filtering problem but must have a matching size. The KalmanFilter() argument to solve is optional since it can be selected automatically given the priors and observables.

Note

The timing convention is such that observables are expected to match the predictions starting at the second time period. As the likelihood of the first element u0 comes from a prior, the observables start at the next element, and hence the observables and noise sequences should be one element shorter than the tspan.

observables = hcat(sol.z...)  # Observables required to be matrix.  Issue #55
observables = observables[:, 2:end] # see note above on likelihood and timing
noise = copy(sol.W) # save for later
u0_prior_mean = [0.0, 0.0]
# use covariance of distribution we drew from
u0_prior_var = cov(u0)

prob = LinearStateSpaceProblem(A, B, u0, (0, size(observables,2)); C, observables, observables_noise = D, syms = [:a, :b], u0_prior_var, u0_prior_mean)
sol = solve(prob, KalmanFilter())
# plot(sol) The `u` is the sequence of posterior means.
sol.logpdf
-4.906890193617707

Hence the logpdf provides the log likelihood marginalizing out the latent noise variables.
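
To see what that marginal likelihood is, here is a minimal, illustrative Kalman filter sketch (the name kalman_loglik_sketch is hypothetical and this is not the package's implementation; it treats D as the vector of observation noise standard deviations, consistent with the Cholesky convention above):

using LinearAlgebra
function kalman_loglik_sketch(A, B, C, D, u0_prior_mean, u0_prior_var, observables)
    u = copy(u0_prior_mean)      # filtered mean of the latent state
    P = copy(u0_prior_var)       # filtered covariance of the latent state
    R = Diagonal(D .^ 2)         # observation noise covariance
    ll = 0.0
    for n in 1:size(observables, 2)
        u = A * u                         # predict mean
        P = A * P * A' + B * B'           # predict covariance
        innov = observables[:, n] - C * u
        S = C * P * C' + R                # innovation covariance
        ll += -0.5 * (logdet(2π * S) + innov' * (S \ innov))  # Gaussian log density of the innovation
        K = P * C' / S                    # Kalman gain
        u += K * innov                    # update mean
        P = (I - K * C) * P               # update covariance
    end
    return ll
end
kalman_loglik_sketch(A, B, C, D, u0_prior_mean, u0_prior_var, observables)  # compare with sol.logpdf above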

As before, we can differentiate the Kalman filter itself.

function kalman_likelihood(A, B, C, D, u0_prior_mean, u0_prior_var, observables)
    prob = LinearStateSpaceProblem(A, B, u0, (0, size(observables,2)); C, observables, observables_noise = D, syms = [:a, :b], u0_prior_var, u0_prior_mean)
    return solve(prob).logpdf
end
kalman_likelihood(A, B, C, D, u0_prior_mean, u0_prior_var, observables)
# Find the gradient wrt the A, B, C and priors variance.
gradient((A, B, C, u0_prior_var) -> kalman_likelihood(A, B, C, D, u0_prior_mean, u0_prior_var, observables), A, B, C, u0_prior_var)
([-2.3077345470043165 -0.22838526931597572; 2.755561685008794 -2.5210142215578633], [-4.526390337304267; -44.3334206365972;;], [0.8576148941773951 -0.006722687682856221; -1.7397574214645317 0.029952313369386196], [-0.11965436036168108 4.663708588760336; -4.736455493824013 -0.4849419868171303])
Note

Some of the gradients, such as those for observables, have not been implemented, so test carefully. This is a general theme with Zygote.jl and reverse-mode AD. Your best friend in this process is the spectacular ChainRulesTestUtils.jl package. See the test_rrule usage in the linear unit tests.
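
For example, a hedged sketch of that workflow (the wrapper f_kalman_A is just for illustration, and this is not the package's actual test code; see the linear unit tests for the real usage):

using ChainRulesTestUtils, ChainRulesCore
f_kalman_A(A) = kalman_likelihood(A, B, C, D, u0_prior_mean, u0_prior_var, observables)
# Check Zygote's end-to-end gradient with respect to A against finite differences
test_rrule(Zygote.ZygoteRuleConfig(), f_kalman_A, A;
           rrule_f = ChainRulesCore.rrule_via_ad, check_inferred = false)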

Joint Likelihood with Noise

A key application of these methods is to find the joint likelihood of the latent variables (i.e., the noise) and the model definition.

The actual calculation of the likelihood is trivial in that case, and just requires iteration of the linear system while accumulating the likelihood given the observation noise.

Crucially, the differentiability with respect to the high-dimensional noise vector enables gradient-based sampling and estimation methods which would otherwise be infeasible.
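
Before using the package interface, here is an illustrative sketch of what that accumulation looks like (joint_likelihood_sketch is a hypothetical name, not the package implementation; D again holds the observation noise standard deviations):

function joint_likelihood_sketch(A, B, C, D, u0, noise, observables)
    u = copy(u0)
    ll = 0.0
    for n in 1:size(observables, 2)
        u = A * u + B * noise[:, n]      # iterate the latent state with the given noise
        ll += sum(logpdf.(Normal.(0.0, D), observables[:, n] - C * u))  # observation noise density
    end
    return ll
end

The package performs this style of accumulation (and provides the gradient rules) through the usual problem/solve interface: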

function joint_likelihood(A, B, C, D, u0, noise, observables)
    prob = LinearStateSpaceProblem(A, B, u0, (0, size(observables,2)); C, observables, observables_noise = D, noise)
    return solve(prob).logpdf
end
u0 = [0.0, 0.0]
joint_likelihood(A, B, C, D, u0, noise, observables)
-1.5905832731053504

And as always, this can be differentiated with respect to the state-space matrices and the noise. Choosing a few parameters,

gradient((A, u0, noise) -> joint_likelihood(A, B, C, D, u0, noise, observables), A, u0, noise)
([-0.1575097797315329 -0.011430729777281462; -0.5430202164004418 -0.11254051848736213], [12.459129692000685, 99.57317431961539], [0.9130479743805457 0.7362581150012636 … 0.18077883896095467 0.007951681013155024])

Composition of State Space Models and AD

While the above gradients have been with respect to the full state-space objects A, B, etc., those objects could themselves be generated by a separate procedure and the whole composition differentiated. For example, let's repeat the above examples, generating the A matrix from some sort of deep parameters.

First, we will generate some observations with a generate_model proxy, which could be replaced with something more complicated but still differentiable

function generate_model(β)
    A = [β 6.2;
        0.0  0.2]
    B = Matrix([0.0  0.001]') # [0.0; 0.001;;] gives a zygote bug
    C = [0.09 0.67;
        1.00 0.00]
    D = [0.01, 0.01]
    return (;A,B,C,D)
end

function simulate_model(β, u0;T = 200)
    mod = generate_model(β)
    prob = LinearStateSpaceProblem(mod.A, mod.B, u0, (0, T); mod.C, observables_noise = mod.D)
    sol = solve(prob) # simulates
    observables = hcat(sol.z...)
    observables = observables[:, 2:end] # see note above on likelihood and timing
    return observables, sol.W
end

# Fix a "pseudo-true" and generate noise and observables
β = 0.95
u0 = [0.0, 0.0]
observables, noise = simulate_model(β, u0)
([0.020863499840778826 -0.12830380826752583 … 0.04364188449063528 -0.01770982577302761; 0.24439498960513137 -0.08899728427211082 … -0.007301932128016523 -0.013726146053351866], [-1.465953984116615 -1.4572821173921953 … -0.02340217277169572 0.5167251189760337])

Next, we will evaluate the marginal likelihood using the Kalman filter for a particular β value,

function kalman_model_likelihood(β, u0_prior_mean, u0_prior_var, observables)
    mod = generate_model(β) # generate model from structural parameters
    prob = LinearStateSpaceProblem(mod.A, mod.B, u0, (0, size(observables, 2)); mod.C, observables, observables_noise = mod.D, u0_prior_var, u0_prior_mean)
    return solve(prob).logpdf
end
u0_prior_mean = [0.0, 0.0]
u0_prior_var = [1e-10 0.0;
                0.0 1e-10]  # starting with degenerate prior
kalman_model_likelihood(β, u0_prior_mean, u0_prior_var, observables)
359.2240604868709

Given the observation error, we would not expect the pseudo-true value to exactly maximize the log likelihood. To show this, we can optimize it with the Optimization.jl package and a gradient-based routine from Optim.jl

using Optimization, OptimizationOptimJL
# Create a function of β alone to minimize, using Zygote-based gradients
kalman_objective(β,p) = -kalman_model_likelihood(β, u0_prior_mean, u0_prior_var, observables)
kalman_objective(0.95, nothing)
gradient(β ->kalman_objective(β, nothing),β) # Verifying it can be differentiated


optf = OptimizationFunction(kalman_objective, Optimization.AutoZygote())
β0 = [0.91] # start off of the pseudotrue
optprob = OptimizationProblem(optf, β0)
optsol = solve(optprob,LBFGS())  # reverse-mode AD is overkill here
u: 1-element Vector{Float64}:
 -0.9362647715018005

In this way, this package composes with others such as DifferentiableStateSpaceModels.jl, which takes a set of structural parameters and an expectational difference equation and generates a state-space model.

Similarly, we can find the joint likelihood for a particular β value and noise. Here we will add in a prior. Some form of prior or regularization is generally necessary for these sorts of estimation problems.

function joint_model_posterior(β, u0, noise, observables, noise_prior, β_prior)
    mod = generate_model(β) # generate model from structural parameters
    prob = LinearStateSpaceProblem(mod.A, mod.B, u0, (0, size(observables, 2)); mod.C, observables, observables_noise = mod.D, noise)
    return solve(prob).logpdf + sum(logpdf.(noise_prior, noise)) + logpdf(β_prior, β) # posterior
end
u0 = [0.0, 0.0]
noise_prior = Normal(0.0, 1.0)
β_prior = Normal(β, 0.03) # prior local to the true value
joint_model_posterior(β, u0, noise, observables, noise_prior, β_prior)
61.53379388562644

We can turn this into a differentiable objective, including the prior on the noise

joint_model_objective(x, p) = -joint_model_posterior(x[1], u0, Matrix(x[2:end]'), observables, noise_prior, β_prior) # extract noise and parameter from vector
x0 = vcat([0.95], noise[1,:])  # starting at the true noise
joint_model_objective(x0, nothing)
gradient(x ->joint_model_objective(x, nothing),x0) # Verifying it can be differentiated

# optimize
optf = OptimizationFunction(joint_model_objective, Optimization.AutoZygote())
optprob = OptimizationProblem(optf, x0)
optsol = solve(optprob,LBFGS())
u: 201-element Vector{Float64}:
  0.9666999986154561
 -0.12597980003160084
 -0.06341423105188516
 -0.0007579332857906926
  0.012857378713502165
  0.0026598907296594082
  0.0694388917915022
  0.15357828484481068
  0.1027054893422976
  0.1694918041843642
  ⋮
  0.0803946695647846
  0.09062131641281987
  0.054881966401640545
  0.04765289970929892
  0.005625778916593516
 -0.051481142266637754
 -0.017707218473049232
 -0.009521627777133414
 -0.0012122081480495095

This "solves" the problem relatively quickly, despite the high dimensionality. However, from a statistics perspective, note that this last optimization process does not do especially well at recovering the pseudo-true value if you increase the prior variance on the β parameter. Maximizing the posterior is usually the wrong thing to do in high dimensions because the mode is not in the typical set.

Caveats on Gradients and Performance

A few notes on performance and gradients:

  1. As this uses reverse-mode AD, it will be efficient for fairly large systems as long as the ultimate value of your differentiable program is a scalar (e.g., a likelihood). With a little extra work and unit tests, it could support structured matrices, etc. as well.
  2. Getting to much higher scales, where A, B, C, D are so large that matrix-free operators are necessary, is feasible but will require generalizing the implementation to LinearOperators. This would be reasonably easy for the joint likelihood, and feasible but more involved for the Kalman filter.
  3. At this point, there is no support for forward-mode auto-differentiation. For smaller systems with a Kalman filter, this should dominate the alternatives, and efficient forward-mode AD rules for the Kalman filter exist (see the supplementary materials in the Differentiable State Space Models paper). However, it would be a significant amount of work to add end-to-end support and fulfill standard SciML interfaces, and perhaps waiting for Enzyme or similar AD systems that provide both forward/reverse/mixed mode makes sense.
  4. Forward-mode AD is likely inappropriate for the joint-likelihood based models since the dimensionality of the noise is always large.
  5. The gradient rules are written using ChainRules.jl so in theory they will work with any supporting AD. In practice, though, Zygote is the most tested and other systems have inconsistent support on Julia at this time.