Enforcing Physical Constraints via Universal Differential-Algebraic Equations

As shown in DiffEqDocs, differential-algebraic equations (DAEs) can be used to impose physical constraints. One way to define a DAE is through an ODE with a singular mass matrix. For example, if we write the system as Mu' = f(u) and the last row of M is all zeros, then the last component of f(u) must equal zero, so the right-hand side defines an algebraic constraint. Using NeuralODEMM, we can use this to define a neural ODE whose three states must sum to one. An example of this is as follows:

using SciMLSensitivity
using Lux, ComponentArrays, DiffEqFlux, Optimization, OptimizationNLopt,
    OrdinaryDiffEq, Plots

using Random
rng = Random.default_rng()

function f!(du, u, p, t)
    y₁, y₂, y₃ = u
    k₁, k₂, k₃ = p
    du[1] = -k₁ * y₁ + k₃ * y₂ * y₃
    du[2] = k₁ * y₁ - k₃ * y₂ * y₃ - k₂ * y₂^2
    du[3] = y₁ + y₂ + y₃ - 1
    return nothing
end

u₀ = [1.0, 0, 0]
M = [1.0 0 0
    0 1.0 0
    0 0 0]

tspan = (0.0, 1.0)
p = [0.04, 3e7, 1e4]

stiff_func = ODEFunction(f!, mass_matrix = M)
prob_stiff = ODEProblem(stiff_func, u₀, tspan, p)
sol_stiff = solve(prob_stiff, Rodas5(), saveat = 0.1)

nn_dudt2 = Lux.Chain(Lux.Dense(3, 64, tanh),
    Lux.Dense(64, 2))

pinit, st = Lux.setup(rng, nn_dudt2)

model_stiff_ndae = NeuralODEMM(nn_dudt2, (u, p, t) -> [u[1] + u[2] + u[3] - 1],
    tspan, M, Rodas5(autodiff = false), saveat = 0.1)
model_stiff_ndae(u₀, ComponentArray(pinit), st)

function predict_stiff_ndae(p)
    return model_stiff_ndae(u₀, p, st)[1]
end

function loss_stiff_ndae(p)
    pred = predict_stiff_ndae(p)
    loss = sum(abs2, Array(sol_stiff) .- pred)
    return loss, pred
end

# callback = function (p, l, pred) #callback function to observe training
#   display(l)
#   return false
# end

l1 = first(loss_stiff_ndae(ComponentArray(pinit)))

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_stiff_ndae(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ComponentArray(pinit))
result_stiff = Optimization.solve(optprob, NLopt.LD_LBFGS(), maxiters = 100)
retcode: Failure
u: 386-element Vector{Float64}:
  0.2234731763601303
  0.0010877986205741763
  0.1282319873571396
  0.007766488939523697
  0.07420586049556732
 -0.2788427770137787
  0.27994680404663086
  0.04406513646245003
 -0.22862255573272705
  0.0763501301407814
  ⋮
 -0.1377258449792862
 -0.08366312086582184
 -0.13129134476184845
  0.02216159924864769
 -0.053207073360681534
 -0.2486088126897812
 -0.275028795003891
  0.0
  0.0

Step-by-Step Description

Load Packages

using SciMLSensitivity
using Lux, ComponentArrays, DiffEqFlux, Optimization, OptimizationNLopt,
    OrdinaryDiffEq, Plots

using Random
rng = Random.default_rng()
Random.TaskLocalRNG()

Differential Equation

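First, we define our differential equations. This is the Robertson chemical kinetics problem written in mass-matrix form. Its rate constants span many orders of magnitude, which makes the problem highly stiff and the fitting difficult.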

function f!(du, u, p, t)
    y₁, y₂, y₃ = u
    k₁, k₂, k₃ = p
    du[1] = -k₁ * y₁ + k₃ * y₂ * y₃
    du[2] = k₁ * y₁ - k₃ * y₂ * y₃ - k₂ * y₂^2
    du[3] = y₁ + y₂ + y₃ - 1
    return nothing
end
f! (generic function with 1 method)

Parameters

u₀ = [1.0, 0, 0]

M = [1.0 0 0
    0 1.0 0
    0 0 0]

tspan = (0.0, 1.0)

p = [0.04, 3e7, 1e4]
3-element Vector{Float64}:
     0.04
     3.0e7
 10000.0
  • u₀ = initial conditions
  • M = semi-explicit mass matrix (the last row corresponds to the constraint equation and is therefore all zeros)
  • tspan = time span over which to evaluate
  • p = parameters k₁, k₂ and k₃ of the differential equation above
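Written out with u = (y₁, y₂, y₃) and the mass matrix M defined above, the system Mu' = f(u, p, t) reads as follows. This is only a restatement of f! in equation form, with primes denoting time derivatives; the zero last row of M turns the third equation into an algebraic constraint.

```latex
\begin{aligned}
y_1' &= -k_1 y_1 + k_3 y_2 y_3, \\
y_2' &= k_1 y_1 - k_3 y_2 y_3 - k_2 y_2^2, \\
0 &= y_1 + y_2 + y_3 - 1.
\end{aligned}
```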

ODE Function, Problem and Solution

We define and solve our ODE problem to generate the “labeled” data that will be used to train our neural network.

stiff_func = ODEFunction(f!, mass_matrix = M)
prob_stiff = ODEProblem(stiff_func, u₀, tspan, p)
sol_stiff = solve(prob_stiff, Rodas5(), saveat = 0.1)
retcode: Success
Interpolation: 1st order linear
t: 11-element Vector{Float64}:
 0.0
 0.1
 0.2
 0.3
 0.4
 0.5
 0.6
 0.7
 0.8
 0.9
 1.0
u: 11-element Vector{Vector{Float64}}:
 [1.0, 0.0, 0.0]
 [0.9960777474341889, 3.5804372328739174e-5, 0.003886448193482536]
 [0.9923059457218133, 3.512303015079638e-5, 0.007658931248036001]
 [0.9886739385487276, 3.4477160464978214e-5, 0.011291584290807323]
 [0.9851721109941391, 3.386396553552364e-5, 0.01479402504032534]
 [0.9817917747099651, 3.328089042275513e-5, 0.018174944399613487]
 [0.9785250342445795, 3.2725768110280034e-5, 0.021442239987310215]
 [0.9753647131269014, 3.2196529785412034e-5, 0.02460309034331384]
 [0.9723042979019034, 3.169123899638582e-5, 0.027664010859099343]
 [0.9693377993879712, 3.120829683456607e-5, 0.030630992315192573]
 [0.966459738805013, 3.0746266110151764e-5, 0.033509514928876834]

Because this is a DAE, we need a solver that supports singular mass matrices. Rodas5 works well for this example.
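One way to sanity-check the reference data is to confirm that every saved state satisfies the constraint y₁ + y₂ + y₃ = 1. The following is a minimal sketch in plain Julia; `constraint_residual` is a hypothetical helper name introduced here for illustration, and the states are copied from the solution printed above.

```julia
# Maximum violation of y₁ + y₂ + y₃ = 1 over a collection of states.
# `constraint_residual` is a hypothetical helper, not part of any package.
constraint_residual(states) = maximum(abs(sum(u) - 1) for u in states)

# A few states copied from the reference solution printed above.
states = [
    [1.0, 0.0, 0.0],
    [0.9960777474341889, 3.5804372328739174e-5, 0.003886448193482536],
    [0.966459738805013, 3.0746266110151764e-5, 0.033509514928876834],
]

residual = constraint_residual(states)
println("max |y₁ + y₂ + y₃ - 1| = ", residual)
```

With the objects from this tutorial in scope, `constraint_residual(sol_stiff.u)` would apply the same check to all eleven saved points. The residual should be close to zero, up to floating-point and solver error.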

Neural Network Layers

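Next, we create our layers using Lux.Chain. We use this instead of Flux.Chain because it is better suited to SciML applications (similarly for Lux.Dense). The network takes the three-component state as input and returns the derivatives of the two differential states; the algebraic constraint is passed to NeuralODEMM separately. The model itself is then called with the initial conditions u₀.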
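To make the shapes concrete, here is a rough sketch of how a two-output network and a one-output constraint can be stacked into a three-component right-hand side that matches the 3×3 mass matrix. It assumes the network output and the constraint residual are simply concatenated; `toy_net`, `constraint`, and `rhs` are hypothetical names, and the single random layer is only a stand-in for the Lux network defined next.

```julia
using Random

# Stand-in for the network: one random dense layer mapping 3 inputs to 2 outputs.
# (Hypothetical; the tutorial's actual network is a Lux.Chain.)
rng = Random.MersenneTwister(0)
W = 0.1 .* randn(rng, 2, 3)
b = zeros(2)
toy_net(u) = tanh.(W * u .+ b)

# Algebraic constraint, in the same (u, p, t) form that is passed to NeuralODEMM.
constraint(u, p, t) = [u[1] + u[2] + u[3] - 1]

# Assumed assembly: two learned derivatives followed by one constraint residual.
rhs(u, p, t) = vcat(toy_net(u), constraint(u, p, t))

u₀ = [1.0, 0.0, 0.0]
out = rhs(u₀, nothing, 0.0)
println(length(out))  # number of components, one per row of M
println(out[3])       # constraint residual at u₀
```

The first two rows of M are identity rows, so the first two components act as derivatives. The last row of M is zero, so the third component is driven to zero by the solver rather than integrated.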

nn_dudt2 = Lux.Chain(Lux.Dense(3, 64, tanh),
    Lux.Dense(64, 2))

pinit, st = Lux.setup(rng, nn_dudt2)

model_stiff_ndae = NeuralODEMM(nn_dudt2, (u, p, t) -> [u[1] + u[2] + u[3] - 1],
    tspan, M, Rodas5(autodiff = false), saveat = 0.1)
model_stiff_ndae(u₀, ComponentArray(pinit), st)
(Output condensed. The call returns a tuple whose first element is an ODESolution; for the untrained network, its saved times and states are:)
t: [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
u:
 [1.0, 0.0, 0.0]
 [1.0335651290523673, 0.03533774770305834, -0.06890287675612541]
 [1.0659811413242448, 0.07439233663296213, -0.1403734779737684]
 [1.0970115120792479, 0.11724869219366867, -0.21426020423658115]
 [1.1264104022363028, 0.16397237110531543, -0.2903827733090034]
 [1.1539240844568472, 0.2146074273007541, -0.36853151183210553]
 [1.1792927695911006, 0.2691742386950685, -0.44846700815562596]
 [1.2022531551828424, 0.3276678254243086, -0.5299209806801537]
 [1.222539978906791, 0.39005717256131145, -0.6125971513714854]
 [1.2398889275534417, 0.45628393676827905, -0.6961728644148065]
 [1.2540396363438882, 0.5262632568261779, -0.7803028931700663]
0.26228684 -0.18568242], bias = Float32[0.0; 0.0;;]))), SciMLBase.UDerivativeWrapper{false, SciMLBase.ODEFunction{false, SciMLBase.FullSpecialize, DiffEqFlux.var"#f#44"{DiffEqFlux.NeuralODEMM{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{true, typeof(NNlib.tanh_fast), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_2::Lux.Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}, Nothing}, Main.var"#1#2", Tuple{Float64, Float64}, Matrix{Float64}, Tuple{OrdinaryDiffEq.Rodas5{0, false, Nothing, typeof(OrdinaryDiffEq.DEFAULT_PRECS), Val{:forward}, true, nothing}}, Base.Pairs{Symbol, Float64, Tuple{Symbol}, @NamedTuple{saveat::Float64}}}, Lux.Experimental.StatefulLuxLayer{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{true, typeof(NNlib.tanh_fast), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_2::Lux.Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}, Nothing}, Nothing, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}}}, Matrix{Float64}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Float64, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:256, Axis(weight = ViewAxis(1:192, ShapedAxis((64, 3), NamedTuple())), bias = ViewAxis(193:256, ShapedAxis((64, 1), NamedTuple())))), layer_2 = ViewAxis(257:386, Axis(weight = ViewAxis(1:128, ShapedAxis((2, 64), NamedTuple())), bias = ViewAxis(129:130, ShapedAxis((2, 1), NamedTuple())))))}}}}(SciMLBase.ODEFunction{false, SciMLBase.FullSpecialize, DiffEqFlux.var"#f#44"{DiffEqFlux.NeuralODEMM{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{true, typeof(NNlib.tanh_fast), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, 
layer_2::Lux.Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}, Nothing}, Main.var"#1#2", Tuple{Float64, Float64}, Matrix{Float64}, Tuple{OrdinaryDiffEq.Rodas5{0, false, Nothing, typeof(OrdinaryDiffEq.DEFAULT_PRECS), Val{:forward}, true, nothing}}, Base.Pairs{Symbol, Float64, Tuple{Symbol}, @NamedTuple{saveat::Float64}}}, Lux.Experimental.StatefulLuxLayer{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{true, typeof(NNlib.tanh_fast), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_2::Lux.Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}, Nothing}, Nothing, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}}}, Matrix{Float64}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}(DiffEqFlux.var"#f#44"{DiffEqFlux.NeuralODEMM{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{true, typeof(NNlib.tanh_fast), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_2::Lux.Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}, Nothing}, Main.var"#1#2", Tuple{Float64, Float64}, Matrix{Float64}, Tuple{OrdinaryDiffEq.Rodas5{0, false, Nothing, typeof(OrdinaryDiffEq.DEFAULT_PRECS), Val{:forward}, true, nothing}}, Base.Pairs{Symbol, Float64, Tuple{Symbol}, @NamedTuple{saveat::Float64}}}, Lux.Experimental.StatefulLuxLayer{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{true, typeof(NNlib.tanh_fast), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_2::Lux.Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}, Nothing}, Nothing, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}}}(NeuralODEMM(), 
Lux.Experimental.StatefulLuxLayer{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{true, typeof(NNlib.tanh_fast), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_2::Lux.Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}, Nothing}, Nothing, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}}(Chain(), nothing, (layer_1 = NamedTuple(), layer_2 = NamedTuple()))), [1.0 0.0 0.0; 0.0 1.0 0.0; 0.0 0.0 0.0], nothing, DiffEqFlux.basic_tgrad, nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, SciMLBase.DEFAULT_OBSERVED, nothing, nothing), 0.6855933457172683, (layer_1 = (weight = Float32[-0.22851282 0.28590652 -0.23959994; 0.118569046 -0.11944113 -0.29642808; … ; -0.21028031 0.14263707 -0.15413335; 0.06913819 0.2765988 -0.04022657], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_2 = (weight = Float32[-0.2288007 -0.0006805439 … 0.28057346 -0.17699033; -0.12116632 -0.12562795 … 0.26228684 -0.18568242], bias = Float32[0.0; 0.0;;]))), OrdinaryDiffEq.Rodas5Tableau{Float64, Float64}(2.0, 3.040894194418781, 1.041747909077569, 2.576417536461461, 1.62208306077664, -0.9089668560264532, 2.760842080225597, 1.446624659844071, -0.3036980084553738, 0.2877498600325443, -14.09640773051259, 6.925207756232704, -41.47510893210728, 2.343771018586405, 24.13215229196062, -10.31323885133993, -21.04823117650003, -7.234992135176716, 32.22751541853323, -4.943732386540191, 19.44922031041879, -20.69865579590063, -8.816374604402768, 1.260436877740897, -0.7495647613787146, -46.22004352711257, -17.49534862857472, -289.6389582892057, 93.60855400400906, 318.3822534212147, 34.20013733472935, -14.1553540271769, 57.823356409884, 25.83362985412365, 1.408950972071624, -6.551835421242162, 42.57076742291101, -13.80770672017997, 93.98938432427124, 18.77919633714503, -31.5835918722337, -6.685968952921985, -5.810979938412932, 0.19, 0.19, -0.18230792253337147, -0.3192318321868749, 
0.3449828624725343, -0.37741756439208984, 0.38, 0.3878509998321533, 0.483971893787384, 0.457047700881958, 27.354592673333357, -6.925207756232857, 26.40037733258859, 0.5635230501052979, -4.699151156849391, -1.6008677469422725, -1.5306074446748028, -1.3929872940716344, 44.19024239501722, 1.3677947663381929e-13, 202.93261852171622, -35.5669339789154, -181.91095152160645, 3.4116351403665033, 2.5793540257308067, 2.2435122582734066, -44.0988150021747, -5.755396159656812e-13, -181.26175034586677, 56.99302194811676, 183.21182741427398, -7.480257918273637, -5.792426076169686, -5.32503859794143), [6.94231954240387e-310 6.942345295963e-310 6.942319542411e-310; 6.942334175369e-310 6.94231954240703e-310 6.94231954241177e-310; 6.9423354596551e-310 6.94234529597723e-310 0.0], LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}(Matrix{Float64}(undef, 0, 0), Int64[], 0), nothing), Bool[1, 1, 0], false), false, 0, SciMLBase.DEStats(110, 0, 9, 72, 9, 0, 0, 0, 9, 0, 0.0), nothing, SciMLBase.ReturnCode.Success), (layer_1 = NamedTuple(), layer_2 = NamedTuple()))

Because this is a stiff problem, we impose the sum constraint manually via (u,p,t) -> [u[1] + u[2] + u[3] - 1]. The network then only has to learn the two differential equations, which makes the fitting easier.
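As a quick sanity check of what the algebraic row enforces, the constraint residual can be evaluated on every saved state. The following is a minimal sketch in plain Julia, not part of the original tutorial: the helper names constraint_residual and max_violation are hypothetical, and the matrix states holds made-up values standing in for a solution array with one column per saved time point.

```julia
# Residual of the algebraic constraint u[1] + u[2] + u[3] - 1 = 0 for one state.
constraint_residual(u) = u[1] + u[2] + u[3] - 1

# Largest absolute violation over all columns (one column per saved time point).
max_violation(states) = maximum(abs(constraint_residual(col)) for col in eachcol(states))

# Made-up states for illustration only; not output from the model above.
states = [1.0  0.9   0.8
          0.0  0.05  0.1
          0.0  0.05  0.1]

println(max_violation(states))
```

With the tutorial's objects, the same helper could presumably be applied to the array form of a solution or prediction; a value near zero indicates that the states sum to one up to floating-point error.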

Prediction Function

For simplicity, we define a wrapper function that takes only the model's parameters and returns the prediction.

function predict_stiff_ndae(p)
    return model_stiff_ndae(u₀, p, st)[1]
end
predict_stiff_ndae (generic function with 1 method)

Train Parameters

Training our network requires a loss function, an optimizer, and a callback function to display the progress.

Loss

We first make our predictions based on the current parameters, then calculate the loss from these predictions. In this case, the loss is the sum of squared errors between the reference solution sol_stiff and the prediction.

function loss_stiff_ndae(p)
    pred = predict_stiff_ndae(p)
    loss = sum(abs2, Array(sol_stiff) .- pred)
    return loss, pred
end

l1 = first(loss_stiff_ndae(ComponentArray(pinit)))
3.7384145813311727
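To make the loss formula concrete, here is a minimal sketch of the same sum-of-squared-errors computation on small made-up arrays, using only Base Julia. The helper name sse and the values of target and pred are hypothetical and unrelated to the solution above.

```julia
# Sum of squared errors between a target array and a prediction of the same shape.
sse(target, pred) = sum(abs2, target .- pred)

# Made-up data for illustration only.
target = [1.0 2.0; 3.0 4.0]
pred = ones(2, 2)

# Elementwise differences are 0, 1, 2, 3, so the squares sum to 14.
println(sse(target, pred))
```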

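Notice that we are feeding the parameters of model_stiff_ndae to the loss_stiff_ndae function. These parameters, pinit wrapped in a ComponentArray, are the weights and biases of our NN. There are 386 of them: 4 * 64 for the first layer (3 * 64 weights plus 64 biases) and 65 * 2 for the second (64 * 2 weights plus 2 biases).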
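The count of 386 can be checked by hand. A dense layer with n_in inputs and n_out outputs has n_in * n_out weights plus n_out biases. The sketch below does that arithmetic in plain Julia; the helper dense_param_count is a hypothetical name, not a Lux function.

```julia
# Parameters of one dense layer: a weight matrix (n_out x n_in) plus a bias vector (n_out).
dense_param_count(n_in, n_out) = (n_in + 1) * n_out

# Layer sizes taken from the network in this tutorial: Dense(3, 64) followed by Dense(64, 2).
layers = [(3, 64), (64, 2)]

total = sum(dense_param_count(n_in, n_out) for (n_in, n_out) in layers)
println(total)  # 256 + 130 = 386
```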

Optimizer

The optimizer is L-BFGS, provided by NLopt as NLopt.LD_LBFGS() (see below).

Callback

The callback function displays the loss during training.

callback = function (p, l, pred) #callback function to observe training
    display(l)
    return false
end
#3 (generic function with 1 method)
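A callback can also record the loss instead of only displaying it, which is handy for plotting convergence afterwards. The following is a minimal sketch in plain Julia: loss_history and record_callback are hypothetical names, and the three-argument shape is copied from the callback above. Whether an installed Optimization.jl version calls the callback with exactly these arguments is an assumption to check against its documentation. Here the callback is driven by hand with made-up loss values.

```julia
# Storage for the loss values seen during training.
loss_history = Float64[]

# Same three-argument shape as the callback above: parameters, loss, prediction.
record_callback = function (p, l, pred)
    push!(loss_history, l)
    return false  # returning false lets the optimization continue
end

# Stand-in for an optimizer: call the callback with made-up losses.
for l in (3.7, 2.1, 0.9)
    record_callback(nothing, l, nothing)
end

println(loss_history)
```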

Train

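Finally, we train with Optimization.solve, passing the optimization problem (which wraps the loss function and the initial model parameters), the optimizer, and the maximum number of iterations. The callback defined above is not passed in the call below; Optimization.solve accepts one through the callback keyword argument.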

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_stiff_ndae(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ComponentArray(pinit))
result_stiff = Optimization.solve(optprob, NLopt.LD_LBFGS(), maxiters = 100)
retcode: Failure
u: 386-element Vector{Float64}:
 -0.22851282358169556
  0.11856904625892639
 -0.0390140600502491
  0.20039944350719452
  0.23495440185070038
  0.2336994707584381
 -0.08821128308773041
  0.1383007913827896
  0.22397884726524353
  0.03180921822786331
  ⋮
 -0.2328682541847229
 -0.10183526575565338
  0.1404077708721161
  0.28057345747947693
  0.26228684186935425
 -0.1769903302192688
 -0.18568241596221924
  0.0
  0.0