Enforcing Physical Constraints via Universal Differential-Algebraic Equations

As shown in the stiff ODE tutorial, differential-algebraic equations (DAEs) can be used to impose physical constraints. One way to define a DAE is through an ODE with a singular mass matrix. For example, if we write Mu' = f(u) and the last row of M is all zeros, then the last component of f(u) must equal zero, so that row of the right-hand side defines an algebraic constraint. Using NeuralODEMM, we can use this to define a neural ODE whose three states must sum to one. An example of this is as follows:

using Lux, ComponentArrays, DiffEqFlux, Optimization, OptimizationNLopt,
    DifferentialEquations, Plots

using Random
rng = Random.default_rng()

function f!(du, u, p, t)
    y₁, y₂, y₃ = u
    k₁, k₂, k₃ = p
    du[1] = -k₁ * y₁ + k₃ * y₂ * y₃
    du[2] = k₁ * y₁ - k₃ * y₂ * y₃ - k₂ * y₂^2
    du[3] = y₁ + y₂ + y₃ - 1
    return nothing
end

u₀ = [1.0, 0, 0]
M = [1.0 0 0
    0 1.0 0
    0 0 0]

tspan = (0.0, 1.0)
p = [0.04, 3e7, 1e4]

stiff_func = ODEFunction(f!, mass_matrix = M)
prob_stiff = ODEProblem(stiff_func, u₀, tspan, p)
sol_stiff = solve(prob_stiff, Rodas5(), saveat = 0.1)

nn_dudt2 = Lux.Chain(Lux.Dense(3, 64, tanh),
    Lux.Dense(64, 2))

pinit, st = Lux.setup(rng, nn_dudt2)

model_stiff_ndae = NeuralODEMM(nn_dudt2, (u, p, t) -> [u[1] + u[2] + u[3] - 1],
    tspan, M, Rodas5(autodiff = false), saveat = 0.1)
model_stiff_ndae(u₀, ComponentArray(pinit), st)

function predict_stiff_ndae(p)
    return model_stiff_ndae(u₀, p, st)[1]
end

function loss_stiff_ndae(p)
    pred = predict_stiff_ndae(p)
    loss = sum(abs2, Array(sol_stiff) .- pred)
    return loss, pred
end

# callback = function (p, l, pred) #callback function to observe training
#   display(l)
#   return false
# end

l1 = first(loss_stiff_ndae(ComponentArray(pinit)))

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_stiff_ndae(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ComponentArray(pinit))
result_stiff = Optimization.solve(optprob, NLopt.LD_LBFGS(), maxiters = 100)
u: 386-element Vector{Float64}:
 NaN
 NaN
 NaN
 NaN
 NaN
 NaN
 NaN
 NaN
 NaN
 NaN
   ⋮
 NaN
 NaN
 NaN
 NaN
 NaN
 NaN
 NaN
 NaN
 NaN

Step-by-Step Description

Load Packages

using Lux, ComponentArrays, DiffEqFlux, Optimization, OptimizationNLopt,
    DifferentialEquations, Plots

using Random
rng = Random.default_rng()
Random.TaskLocalRNG()

Differential Equation

First, we define our differential equations. This is the Robertson chemical kinetics problem written as a DAE: a highly stiff system, which makes the fitting difficult.

function f!(du, u, p, t)
    y₁, y₂, y₃ = u
    k₁, k₂, k₃ = p
    du[1] = -k₁ * y₁ + k₃ * y₂ * y₃
    du[2] = k₁ * y₁ - k₃ * y₂ * y₃ - k₂ * y₂^2
    du[3] = y₁ + y₂ + y₃ - 1
    return nothing
end
f! (generic function with 1 method)

Parameters

u₀ = [1.0, 0, 0]

M = [1.0 0 0
    0 1.0 0
    0 0 0]

tspan = (0.0, 1.0)

p = [0.04, 3e7, 1e4]
3-element Vector{Float64}:
     0.04
     3.0e7
 10000.0
  • u₀ = initial conditions
  • M = semi-explicit mass matrix (the last row corresponds to the constraint equation and is therefore all zeros)
  • tspan = time span over which to evaluate
  • p = parameters k₁, k₂ and k₃ of the differential equation above
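
Written out for this system, the mass-matrix form M u' = f(u) reads as follows. This is only a restatement of f! and M above, using y' for the time derivative; no new quantities are introduced:

```latex
\[
\begin{bmatrix} 1 & 0 & 0 \\ 0 & 1 & 0 \\ 0 & 0 & 0 \end{bmatrix}
\begin{bmatrix} y_1' \\ y_2' \\ y_3' \end{bmatrix}
=
\begin{bmatrix}
-k_1 y_1 + k_3 y_2 y_3 \\
k_1 y_1 - k_3 y_2 y_3 - k_2 y_2^2 \\
y_1 + y_2 + y_3 - 1
\end{bmatrix}
\]
```

The zero last row of M removes y₃' from the third equation, leaving the purely algebraic condition 0 = y₁ + y₂ + y₃ - 1.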

ODE Function, Problem and Solution

We define and solve our ODE problem to generate the “labeled” data that will be used to train our neural network.

stiff_func = ODEFunction(f!, mass_matrix = M)
prob_stiff = ODEProblem(stiff_func, u₀, tspan, p)
sol_stiff = solve(prob_stiff, Rodas5(), saveat = 0.1)
retcode: Success
Interpolation: 1st order linear
t: 11-element Vector{Float64}:
 0.0
 0.1
 0.2
 0.3
 0.4
 0.5
 0.6
 0.7
 0.8
 0.9
 1.0
u: 11-element Vector{Vector{Float64}}:
 [1.0, 0.0, 0.0]
 [0.9960777474341889, 3.5804372328739174e-5, 0.003886448193482536]
 [0.9923059457218133, 3.512303015079638e-5, 0.007658931248036001]
 [0.9886739385487276, 3.4477160464978214e-5, 0.011291584290807323]
 [0.9851721109941391, 3.386396553552364e-5, 0.01479402504032534]
 [0.9817917747099651, 3.328089042275513e-5, 0.018174944399613487]
 [0.9785250342445795, 3.2725768110280034e-5, 0.021442239987310215]
 [0.9753647131269014, 3.2196529785412034e-5, 0.02460309034331384]
 [0.9723042979019034, 3.169123899638582e-5, 0.027664010859099343]
 [0.9693377993879712, 3.120829683456607e-5, 0.030630992315192573]
 [0.966459738805013, 3.0746266110151764e-5, 0.033509514928876834]

Because this is a DAE written in mass-matrix form, we need a solver that supports singular mass matrices. Rodas5 works well for this example.
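
As a quick sanity check on the generated data, the constraint residual y₁ + y₂ + y₃ - 1 can be evaluated on a few of the saved states. The sketch below is self-contained: it hard-codes three rows copied from the output above (t = 0.0, 0.1 and 1.0) rather than reusing sol_stiff, and the names samples and residuals are illustrative only.

```julia
# Three saved states copied from the sol_stiff output (t = 0.0, 0.1, 1.0).
samples = [
    [1.0, 0.0, 0.0],
    [0.9960777474341889, 3.5804372328739174e-5, 0.003886448193482536],
    [0.966459738805013, 3.0746266110151764e-5, 0.033509514928876834],
]

# Residual of the algebraic constraint y₁ + y₂ + y₃ - 1 for each state.
residuals = [sum(u) - 1 for u in samples]

# The largest absolute residual should be close to zero.
println(maximum(abs, residuals))
```

With the full solution object in scope, the same check could be applied to every saved state by iterating over sol_stiff.u instead of the hard-coded rows.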

Neural Network Layers

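Next, we create our layers using Lux.Chain. We use this instead of Flux.Chain because it is more suited to SciML applications (similarly for Lux.Dense). The network takes the three-component state u as input, starting from the initial condition u₀, and returns two outputs, one for each differential equation. The third row of the system is supplied by the constraint function passed to NeuralODEMM.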

nn_dudt2 = Lux.Chain(Lux.Dense(3, 64, tanh),
    Lux.Dense(64, 2))

pinit, st = Lux.setup(rng, nn_dudt2)
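
The size of the parameter vector follows from the layer shapes. As a small arithmetic sketch that does not call Lux (the helper dense_params is hypothetical, introduced only for illustration), a dense layer with bias has in × out weights plus out biases:

```julia
# Hypothetical helper: parameter count of a dense layer with bias.
dense_params(in_dims, out_dims) = in_dims * out_dims + out_dims

layer_1 = dense_params(3, 64)   # 3 * 64 weights + 64 biases
layer_2 = dense_params(64, 2)   # 64 * 2 weights + 2 biases
total = layer_1 + layer_2

println((layer_1, layer_2, total))
```

The total of 386 matches the length of the parameter vector reported in the optimizer output earlier.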

model_stiff_ndae = NeuralODEMM(nn_dudt2, (u, p, t) -> [u[1] + u[2] + u[3] - 1],
    tspan, M, Rodas5(autodiff = false), saveat = 0.1)
model_stiff_ndae(u₀, ComponentArray(pinit), st)
(The full printout of the returned tuple is omitted here, since it consists almost entirely of nested type parameters. Its first element is an ODESolution holding the untrained model's prediction, reproduced below.)
t: [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
u:
 [1.0, 0.0, 0.0]
 [1.0452634611088794, -0.004776114956398698, -0.04048734617047308]
 [1.0948048371877261, -0.009914790669844211, -0.08489004656506638]
 [1.1489811032551278, -0.015443335061148873, -0.13353776815452406]
 [1.2081656781870873, -0.02139046185636857, -0.18677521646106304]
 [1.2727450222828023, -0.02778611070809808, -0.2449589115287577]
 [1.3431148320985087, -0.034661218990736535, -0.30845361315625147]
 [1.4196737010897262, -0.04204728765787751, -0.37762641326226226]
 [1.5028163508288979, -0.049975854273937075, -0.4528404965518792]
 [1.5929257370505159, -0.05847779471383109, -0.5344479423396598]
 [1.690361097052936, -0.06758214430046436, -0.622778952752471]
1.446624659844071, -0.3036980084553738, 0.2877498600325443, -14.09640773051259, 6.925207756232704, -41.47510893210728, 2.343771018586405, 24.13215229196062, -10.31323885133993, -21.04823117650003, -7.234992135176716, 32.22751541853323, -4.943732386540191, 19.44922031041879, -20.69865579590063, -8.816374604402768, 1.260436877740897, -0.7495647613787146, -46.22004352711257, -17.49534862857472, -289.6389582892057, 93.60855400400906, 318.3822534212147, 34.20013733472935, -14.1553540271769, 57.823356409884, 25.83362985412365, 1.408950972071624, -6.551835421242162, 42.57076742291101, -13.80770672017997, 93.98938432427124, 18.77919633714503, -31.5835918722337, -6.685968952921985, -5.810979938412932, 0.19, 0.19, -0.18230792253337147, -0.3192318321868749, 0.3449828624725343, -0.37741756439208984, 0.38, 0.3878509998321533, 0.483971893787384, 0.457047700881958, 27.354592673333357, -6.925207756232857, 26.40037733258859, 0.5635230501052979, -4.699151156849391, -1.6008677469422725, -1.5306074446748028, -1.3929872940716344, 44.19024239501722, 1.3677947663381929e-13, 202.93261852171622, -35.5669339789154, -181.91095152160645, 3.4116351403665033, 2.5793540257308067, 2.2435122582734066, -44.0988150021747, -5.755396159656812e-13, -181.26175034586677, 56.99302194811676, 183.21182741427398, -7.480257918273637, -5.792426076169686, -5.32503859794143), [2.5463949493e-313 3.8195924246e-313 4.03179200377e-313; 2.970794108e-313 4.03179200377e-313 5.09278989926e-313; 3.3951932663e-313 4.03179200377e-313 6.9408360326696e-310], LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}(Matrix{Float64}(undef, 0, 0), Int64[], 0), nothing)), false, 0, DiffEqBase.Stats
Number of function 1 evaluations:                  122
Number of function 2 evaluations:                  0
Number of W matrix evaluations:                    10
Number of linear solves:                           80
Number of Jacobians created:                       10
Number of nonlinear solver iterations:             0
Number of nonlinear solver convergence failures:   0
Number of rootfind condition calls:                0
Number of accepted steps:                          10
Number of rejected steps:                          0, nothing, SciMLBase.ReturnCode.Success), (layer_1 = NamedTuple(), layer_2 = NamedTuple()))

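Because this is a stiff problem, we impose the sum constraint manually via (u,p,t) -> [u[1] + u[2] + u[3] - 1]. The network then only has to learn the two differential equations, while the zero last row of M turns this function into an algebraic constraint, making the fitting easier.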
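
As a minimal sketch of why the zero row enforces the constraint, the snippet below uses plain Julia arrays only. The helper name `residual` and the values in `du` are illustrative assumptions, not part of DiffEqFlux or the tutorial. For any derivative vector du, the last entry of M * du is zero, so the third equation reduces to 0 = u[1] + u[2] + u[3] - 1.

```julia
# Mass matrix from the tutorial: the zero last row makes the third equation algebraic.
M = [1.0 0 0
     0 1.0 0
     0 0 0]

# Constraint function as passed to NeuralODEMM in the tutorial.
constraint(u, p, t) = [u[1] + u[2] + u[3] - 1]

# Whatever du is, the last component of M * du vanishes ...
du = [0.3, -0.7, 12.0]   # arbitrary illustrative values
lhs = M * du

# ... so a state is consistent only if the constraint residual is zero.
residual(u) = only(constraint(u, nothing, 0.0))   # hypothetical helper

u₀ = [1.0, 0.0, 0.0]     # initial condition from the tutorial
println(residual(u₀))
```

The initial condition u₀ satisfies the constraint exactly, which is what a mass-matrix DAE solver expects of a consistent starting point.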

Prediction Function

For simplicity, we define a wrapper function that only takes in the model's parameters to make predictions.

function predict_stiff_ndae(p)
    return model_stiff_ndae(u₀, p, st)[1]
end
predict_stiff_ndae (generic function with 1 method)

Train Parameters

Training our network requires a loss function, an optimizer, and a callback function to display the progress.

Loss

We first make our predictions based on the current parameters, then calculate the loss from these predictions. In this case, we use least squares as our loss.

function loss_stiff_ndae(p)
    pred = predict_stiff_ndae(p)
    loss = sum(abs2, Array(sol_stiff) .- pred)
    return loss, pred
end

l1 = first(loss_stiff_ndae(ComponentArray(pinit)))
3.105655105730176
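
To make the least-squares loss concrete, here is a small sketch on hand-written matrices standing in for the reference solution and the prediction. The values and the name `sse` are illustrative assumptions, not taken from the tutorial; each column plays the role of one saved time point.

```julia
# Stand-ins for Array(sol_stiff) and the model prediction (illustrative values).
target = [1.0 2.0
          3.0 4.0]
pred = [1.5 2.0
        3.0 3.0]

# Sum of squared element-wise differences, the same form as in loss_stiff_ndae.
sse(a, b) = sum(abs2, a .- b)

l = sse(target, pred)   # 0.5^2 + 1.0^2
println(l)
```

Only the two entries that differ contribute, so the loss is zero exactly when the prediction matches the reference at every saved point.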

Notice that we are feeding the network parameters, ComponentArray(pinit), to the loss_stiff_ndae function. These are the weights and biases of our NN, stored as a vector of length 386 (4 * 64 + 65 * 2).
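
The count of 386 can be checked by hand. The sketch below assumes the two dense layers from the tutorial, Dense(3, 64) and Dense(64, 2); `dense_params` is a hypothetical helper, not a Lux function.

```julia
# A dense layer has one weight per (input, output) pair plus one bias per output.
dense_params(n_in, n_out) = n_in * n_out + n_out

layer_1 = dense_params(3, 64)   # 3*64 weights + 64 biases
layer_2 = dense_params(64, 2)   # 64*2 weights + 2 biases
total = layer_1 + layer_2

println((layer_1, layer_2, total))
```

These sizes agree with the index ranges 1:256 and 257:386 visible in the printed ComponentVector axes.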

Optimizer

The optimizer is L-BFGS, provided by NLopt as NLopt.LD_LBFGS() (see below).

Callback

The callback function displays the loss during training.

callback = function (p, l, pred) #callback function to observe training
    display(l)
    return false
end
#3 (generic function with 1 method)
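
As a possible variation (a sketch, not part of the original tutorial), the callback can also stop training early, since Optimization.jl halts when the callback returns true. The name `callback_halt` is hypothetical, and the three-argument signature simply mirrors the callback above; other Optimization.jl versions may pass different arguments.

```julia
# Hypothetical variant of the callback above: print the loss and
# request a stop as soon as it is no longer finite (NaN or Inf).
callback_halt = function (p, l, pred)
    println(l)
    return !isfinite(l)   # true asks the optimizer to stop
end
```

Stopping on a non-finite loss avoids spending the remaining iterations on parameters that can no longer improve.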

Train

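Finally, we train with Optimization.solve, passing the optimization problem (which wraps the loss function and the initial model parameters), the optimizer, and the maximum number of iterations. Note that the call below does not pass the callback; to display the loss during training, it can be supplied through the callback keyword argument.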

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_stiff_ndae(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ComponentArray(pinit))
result_stiff = Optimization.solve(optprob, NLopt.LD_LBFGS(), maxiters = 100)
u: 386-element Vector{Float64}:
 NaN
 NaN
 NaN
 NaN
 NaN
 NaN
 NaN
 NaN
 NaN
 NaN
   ⋮
 NaN
 NaN
 NaN
 NaN
 NaN
 NaN
 NaN
 NaN
 NaN
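
The printed u above consists entirely of NaN entries, so this run did not return usable parameters. A simple guard, sketched here on stand-in vectors rather than the real result_stiff.u, is to check that every entry is finite before using the fit; `is_usable` is a hypothetical helper name.

```julia
# Hypothetical guard: a parameter vector is usable only if all entries are finite.
is_usable(params) = all(isfinite, params)

good = [0.1, -0.2, 0.3]   # illustrative finite parameters
bad = fill(NaN, 386)      # same shape as the printed result

println(is_usable(good))
println(is_usable(bad))
```

In the tutorial's setting the same check would be applied as is_usable(result_stiff.u).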