Generic Callbacks

The StateSpaceProblem type provides a fully generic interface for discrete-time state-space models. Instead of specifying matrices, you supply callback functions for the state transition and observation equations. This is useful for nonlinear models, time-varying dynamics, or any structure that does not fit the linear or quadratic templates.

Callback Signatures

The two callbacks follow the "bang-bang" convention used throughout SciML: for mutable arrays, mutate the output buffer in place and return it; for immutable arrays (e.g., SVector), ignore the buffer and return a new value.

Transition function: f!!(x_next, x, w, p, t) -> x_next

  • x_next: pre-allocated output buffer (mutate in place for mutable arrays)
  • x: current state
  • w: noise shock at this step (or nothing if n_shocks = 0)
  • p: parameters passed to the problem
  • t: integer time index (0-based)

Observation function: g!!(y, x, p, t) -> y

  • y: pre-allocated output buffer
  • x: current state
  • p: parameters
  • t: integer time index (0-based)

Pass nothing for the observation function if no observations are needed.
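
The two styles the convention allows can be sketched with plain Base arrays. This is an illustrative sketch, not library code: the names f_inplace!! and f_outofplace!!, the parameter values, and the single hand-written step are assumptions made for the example. The point is that a caller should always use the returned value rather than assume the buffer was filled.

```julia
using LinearAlgebra

# In-place style for mutable arrays: write into the buffer and return it.
function f_inplace!!(x_next, x, w, p, t)
    mul!(x_next, p.A, x)            # x_next = A * x
    mul!(x_next, p.B, w, 1.0, 1.0)  # x_next += B * w
    return x_next
end

# Out-of-place style (what an immutable array such as an SVector requires):
# ignore the buffer and return a freshly computed value.
f_outofplace!!(x_next, x, w, p, t) = p.A * x + p.B * w

# Hypothetical values chosen only for this sketch.
p = (; A = [0.5 0.0; 0.0 0.5], B = reshape([1.0, 0.0], 2, 1))
x = [1.0, 2.0]
w = [0.1]
buf = zeros(2)

# Always bind the return value; do not rely on `buf` having been updated.
x1 = f_inplace!!(buf, x, w, p, 0)      # x1 is the same object as buf
x2 = f_outofplace!!(buf, x, w, p, 0)   # x2 is a new array; this call leaves buf untouched
```

Both calls produce the same numbers; only the identity of the returned object differs, which is why the same calling code can serve mutable and immutable state types.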

Example: Linear Model via Callbacks

We can reproduce the behavior of LinearStateSpaceProblem using generic callbacks. This verifies the interface and demonstrates the pattern.

using DifferenceEquations, LinearAlgebra, Random

A = [0.95 6.2; 0.0 0.2]
B = [0.0; 0.01;;]
C = [0.09 0.67; 1.00 0.00]

linear_f!! = (x_next, x, w, p, t) -> begin
    mul!(x_next, p.A, x)
    mul!(x_next, p.B, w, 1.0, 1.0)
    return x_next
end
linear_g!! = (y, x, p, t) -> begin
    mul!(y, p.C, x)
    return y
end
p = (; A, B, C)
u0 = zeros(2)
T = 10

prob = StateSpaceProblem(linear_f!!, linear_g!!, u0, (0, T), p;
    n_shocks = 1, n_obs = 2, syms = (:a, :b))
sol = solve(prob)
retcode: Success
Interpolation: Piecewise constant interpolation
t: 0:1:10
u: 11-element Vector{Vector{Float64}}:
 [0.0, 0.0]
 [0.0, -0.001814209755708581]
 [-0.011248100485393202, 0.000370392834315229]
 [-0.008389259888369121, -0.008923007079156962]
 [-0.06329244078472383, 0.003927223157563366]
 [-0.03577903516859477, 0.001180694173184593]
 [-0.026669779536420555, 0.0009464212701223977]
 [-0.01946847868484066, -0.009589546259799365]
 [-0.07795024156135469, 0.0003416764506392804]
 [-0.07193433548932342, 0.0004483876493976216]
 [-0.065557615288592, 0.002125268964511902]

The solution has the same structure as the linear case:

sol.u  # state trajectory, Vector{Vector}
11-element Vector{Vector{Float64}}:
 [0.0, 0.0]
 [0.0, -0.001814209755708581]
 [-0.011248100485393202, 0.000370392834315229]
 [-0.008389259888369121, -0.008923007079156962]
 [-0.06329244078472383, 0.003927223157563366]
 [-0.03577903516859477, 0.001180694173184593]
 [-0.026669779536420555, 0.0009464212701223977]
 [-0.01946847868484066, -0.009589546259799365]
 [-0.07795024156135469, 0.0003416764506392804]
 [-0.07193433548932342, 0.0004483876493976216]
 [-0.065557615288592, 0.002125268964511902]
sol.z  # observations, Vector{Vector}
11-element Vector{Vector{Float64}}:
 [0.0, 0.0]
 [-0.0012155205363247493, 0.0]
 [-0.0007641658446941847, -0.011248100485393202]
 [-0.006733448132988386, -0.008389259888369121]
 [-0.0030650801550576898, -0.06329244078472383]
 [-0.002429048069139852, -0.03577903516859477]
 [-0.0017661779072958433, -0.026669779536420555]
 [-0.008177159075701235, -0.01946847868484066]
 [-0.006786598518593604, -0.07795024156135469]
 [-0.006173670468942701, -0.07193433548932342]
 [-0.004476255169750306, -0.065557615288592]

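We can verify this matches the matrix-based formulation. Seeding the random number generator identically before each solve makes both problems draw the same shocks: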

Random.seed!(123)
sol_generic = solve(StateSpaceProblem(linear_f!!, linear_g!!, u0, (0, T), p;
    n_shocks = 1, n_obs = 2))

Random.seed!(123)
sol_linear = solve(LinearStateSpaceProblem(A, B, u0, (0, T); C))

sol_generic.u ≈ sol_linear.u
true

Example: Nonlinear Growth Model

StateSpaceProblem handles arbitrary nonlinear dynamics. Here is a discrete-time logistic growth model with additive process noise, whose transition is nonlinear in the state:

# Nonlinear transition: logistic growth with stochastic shocks
logistic_f!! = (x_next, x, w, p, t) -> begin
    x_next[1] = p.r * x[1] * (1.0 - x[1] / p.K) + p.sigma * w[1]
    return x_next
end

# Observation: direct measurement of population (no observables_noise is set here)
logistic_g!! = (y, x, p, t) -> begin
    y[1] = x[1]
    return y
end

p_logistic = (; r = 1.5, K = 100.0, sigma = 2.0)
u0_logistic = [50.0]

prob_logistic = StateSpaceProblem(logistic_f!!, logistic_g!!, u0_logistic, (0, 50), p_logistic;
    n_shocks = 1, n_obs = 1, syms = (:population,), obs_syms = (:measured_pop,))
sol_logistic = solve(prob_logistic)
retcode: Success
Interpolation: Piecewise constant interpolation
t: 0:1:50
u: 51-element Vector{Vector{Float64}}:
 [50.0]
 [38.686394776241954]
 [34.0432172151517]
 [33.52838035645266]
 [32.120485047510726]
 [34.07896492234528]
 [33.45887981517351]
 [33.80802025748804]
 [31.529268443880817]
 [33.39310235169254]
 ⋮
 [33.53785474891614]
 [31.604680583108284]
 [31.938622321540354]
 [35.05238446828499]
 [28.03340852937682]
 [31.59732736159761]
 [29.57416416233275]
 [37.30545063651327]
 [34.35392195629977]
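
The simulated population settles near 33 rather than near K = 100. A quick way to see why is to iterate the transition with the shock switched off: the noise-free map x -> r * x * (1 - x / K) has a nonzero fixed point at K * (1 - 1 / r), which is about 33.33 for the parameters above. The snippet below is a standalone sketch of that check; logistic_step is a hypothetical helper introduced here and is not part of DifferenceEquations.

```julia
# Noise-free version of the transition used in logistic_f!! (hypothetical helper).
logistic_step(x, p) = p.r * x * (1.0 - x / p.K)

p_logistic = (; r = 1.5, K = 100.0, sigma = 2.0)

# Iterate the deterministic map from the same initial condition as above.
x_final = let x = 50.0
    for _ in 1:50
        x = logistic_step(x, p_logistic)
    end
    x
end

# Nonzero fixed point of x = r * x * (1 - x / K).
x_star = p_logistic.K * (1.0 - 1.0 / p_logistic.r)
```

With these parameters x_final and x_star agree closely, so the stochastic trajectory above can be read as fluctuations around that fixed point.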

Parametric Models and remake

The p argument holds all model parameters. When exploring different parameter values, use remake to create a modified copy of an existing problem instead of constructing a new one from scratch.

new_u0 = [0.1, 0.2]
new_p = (; A = A * 0.99, B, C)

prob2 = remake(prob; u0 = new_u0, p = new_p)
sol2 = solve(prob2)
sol2.u[1]  # new initial condition
2-element Vector{Float64}:
 0.1
 0.2

The remake function preserves every keyword argument you do not override (noise, observables, syms, etc.) from the original problem.
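
When only one or two fields of p change, restating every field by hand is error-prone. One option, sketched here with Base's merge for NamedTuples (no DifferenceEquations calls are involved), is to overwrite just the fields that differ; the result would then be passed as p to remake as above. The matrices repeat the values from the linear example, and base_p and sweep are names chosen for illustration.

```julia
# Same matrices as in the linear example.
A = [0.95 6.2; 0.0 0.2]
B = [0.0; 0.01;;]
C = [0.09 0.67; 1.00 0.00]
base_p = (; A, B, C)

# merge keeps B and C and replaces only A.
new_p = merge(base_p, (; A = A * 0.99))

# A small family of parameter sets for a sweep over the scaling of A.
sweep = [merge(base_p, (; A = A * s)) for s in (0.9, 0.95, 0.99)]
```

Each element of sweep has the same field names as base_p, so callbacks that read p.A, p.B, and p.C work unchanged.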

Symbolic Indexing

StateSpaceProblem supports the same symbolic indexing as the linear problem types. Pass syms for state variable names and obs_syms for observation names.

D = Diagonal([0.1, 0.1])
noise = sol.W  # reuse noise from earlier

prob_sym = StateSpaceProblem(linear_f!!, linear_g!!, u0, (0, T), p;
    n_shocks = 1, n_obs = 2,
    syms = (:capital, :productivity),
    obs_syms = (:output, :consumption),
    observables_noise = D, noise)
sol_sym = solve(prob_sym)

sol_sym[:capital]  # state time series by name
11-element Vector{Float64}:
  0.0
  0.0
 -0.011248100485393202
 -0.008389259888369121
 -0.06329244078472383
 -0.03577903516859477
 -0.026669779536420555
 -0.01946847868484066
 -0.07795024156135469
 -0.07193433548932342
 -0.065557615288592
sol_sym[:output]  # observation time series by name
11-element Vector{Float64}:
  0.04448442617226764
  0.16282558548263082
  0.31568334715595814
  0.18707061548885268
  0.2961660929826813
 -0.5082396610541259
 -0.07075018444674128
 -0.0258966136034754
  0.23748227725788143
  0.36760954852501876
  0.4689714881218356