Enforcing Physical Constraints via Universal Differential-Algebraic Equations
As shown in the stiff ODE tutorial, differential-algebraic equations (DAEs) can be used to impose physical constraints. One way to define a DAE is through an ODE with a singular mass matrix. For example, if we write Mu' = f(u) where the last row of M is all zeros, then the last component of f(u) must equal zero, so that row of the right-hand side defines an algebraic constraint. Using NeuralODEMM, we can use this to define a neural ODE in which the three state variables must sum to one. An example of this is as follows:
using Lux, ComponentArrays, DiffEqFlux, Optimization, OptimizationNLopt,
      DifferentialEquations, Plots
using Random
rng = Random.default_rng()
function f!(du, u, p, t)
    y₁, y₂, y₃ = u
    k₁, k₂, k₃ = p
    du[1] = -k₁ * y₁ + k₃ * y₂ * y₃
    du[2] = k₁ * y₁ - k₃ * y₂ * y₃ - k₂ * y₂^2
    # Constraint row: the zero last row of M turns this into 0 = y₁ + y₂ + y₃ - 1
    du[3] = y₁ + y₂ + y₃ - 1
    return nothing
end
u₀ = [1.0, 0, 0]
M = [1.0 0 0
     0 1.0 0
     0 0 0]
tspan = (0.0, 1.0)
p = [0.04, 3e7, 1e4]
stiff_func = ODEFunction(f!, mass_matrix = M)
prob_stiff = ODEProblem(stiff_func, u₀, tspan, p)
sol_stiff = solve(prob_stiff, Rodas5(), saveat = 0.1)
nn_dudt2 = Lux.Chain(Lux.Dense(3, 64, tanh),
    Lux.Dense(64, 2))
pinit, st = Lux.setup(rng, nn_dudt2)
model_stiff_ndae = NeuralODEMM(nn_dudt2, (u, p, t) -> [u[1] + u[2] + u[3] - 1],
    tspan, M, Rodas5(autodiff = false), saveat = 0.1)
model_stiff_ndae(u₀, ComponentArray(pinit), st)
function predict_stiff_ndae(p)
    return model_stiff_ndae(u₀, p, st)[1]
end
function loss_stiff_ndae(p)
    pred = predict_stiff_ndae(p)
    loss = sum(abs2, Array(sol_stiff) .- pred)
    return loss, pred
end
# callback = function (p, l, pred) #callback function to observe training
# display(l)
# return false
# end
l1 = first(loss_stiff_ndae(ComponentArray(pinit)))
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_stiff_ndae(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ComponentArray(pinit))
result_stiff = Optimization.solve(optprob, NLopt.LD_LBFGS(), maxiters = 100)
u: 386-element Vector{Float64}:
NaN
NaN
NaN
NaN
NaN
NaN
NaN
NaN
NaN
NaN
⋮
NaN
NaN
NaN
NaN
NaN
NaN
NaN
NaN
NaN
Step-by-Step Description
Load Packages
using Lux, ComponentArrays, DiffEqFlux, Optimization, OptimizationNLopt,
      DifferentialEquations, Plots
using Random
rng = Random.default_rng()
Random.TaskLocalRNG()
Differential Equation
First, we define our differential equations. The rate constants used in this example range from 0.04 to 3e7, which makes the problem highly stiff and the fitting difficult.
function f!(du, u, p, t)
    y₁, y₂, y₃ = u
    k₁, k₂, k₃ = p
    du[1] = -k₁ * y₁ + k₃ * y₂ * y₃
    du[2] = k₁ * y₁ - k₃ * y₂ * y₃ - k₂ * y₂^2
    # Constraint row: the zero last row of M turns this into 0 = y₁ + y₂ + y₃ - 1
    du[3] = y₁ + y₂ + y₃ - 1
    return nothing
end
f! (generic function with 1 method)
Parameters
u₀ = [1.0, 0, 0]
M = [1.0 0 0
     0 1.0 0
     0 0 0]
tspan = (0.0, 1.0)
p = [0.04, 3e7, 1e4]
3-element Vector{Float64}:
0.04
3.0e7
10000.0
u₀ = initial conditions
M = semi-explicit mass matrix (the last row corresponds to the constraint equation and is therefore all zeros)
tspan = time span over which to evaluate
p = parameters k₁, k₂ and k₃ of the differential equation above
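For reference, the system that f! and M encode together can be written out row by row. This is only a transcription of the code above into mathematical notation, with u = (y₁, y₂, y₃) and p = (k₁, k₂, k₃); no terms beyond those in f! are assumed.

```latex
M u' = f(u, p), \qquad
M = \begin{pmatrix} 1 & 0 & 0 \\ 0 & 1 & 0 \\ 0 & 0 & 0 \end{pmatrix}
\quad\Longleftrightarrow\quad
\begin{aligned}
y_1' &= -k_1 y_1 + k_3 y_2 y_3, \\
y_2' &= k_1 y_1 - k_3 y_2 y_3 - k_2 y_2^2, \\
0 &= y_1 + y_2 + y_3 - 1 .
\end{aligned}
```

The first two rows are ordinary differential equations, while the third row has no derivative on its left-hand side and therefore acts purely as the conservation constraint.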
ODE Function, Problem and Solution
We define and solve our ODE problem to generate the “labeled” data which will be used to train our Neural Network.
stiff_func = ODEFunction(f!, mass_matrix = M)
prob_stiff = ODEProblem(stiff_func, u₀, tspan, p)
sol_stiff = solve(prob_stiff, Rodas5(), saveat = 0.1)
retcode: Success
Interpolation: 1st order linear
t: 11-element Vector{Float64}:
0.0
0.1
0.2
0.3
0.4
0.5
0.6
0.7
0.8
0.9
1.0
u: 11-element Vector{Vector{Float64}}:
[1.0, 0.0, 0.0]
[0.9960777474341889, 3.5804372328739174e-5, 0.003886448193482536]
[0.9923059457218133, 3.512303015079638e-5, 0.007658931248036001]
[0.9886739385487276, 3.4477160464978214e-5, 0.011291584290807323]
[0.9851721109941391, 3.386396553552364e-5, 0.01479402504032534]
[0.9817917747099651, 3.328089042275513e-5, 0.018174944399613487]
[0.9785250342445795, 3.2725768110280034e-5, 0.021442239987310215]
[0.9753647131269014, 3.2196529785412034e-5, 0.02460309034331384]
[0.9723042979019034, 3.169123899638582e-5, 0.027664010859099343]
[0.9693377993879712, 3.120829683456607e-5, 0.030630992315192573]
[0.966459738805013, 3.0746266110151764e-5, 0.033509514928876834]
Because this is a DAE, we need to make sure to use a solver that supports singular mass matrices. Rodas5 works well for this example.
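One quick way to see that the solver respected the algebraic equation is to check the constraint residual on the saved states. The sketch below is self-contained: it copies three of the states printed above (t = 0.0, 0.1 and 1.0) into a hypothetical `states` vector, and the tolerance of 1e-8 is an arbitrary choice for illustration.

```julia
# A few saved states copied from the printed solution (t = 0.0, 0.1, 1.0).
states = [
    [1.0, 0.0, 0.0],
    [0.9960777474341889, 3.5804372328739174e-5, 0.003886448193482536],
    [0.966459738805013, 3.0746266110151764e-5, 0.033509514928876834],
]

# Residual of the algebraic equation y₁ + y₂ + y₃ - 1 = 0 at each state.
residuals = [sum(u) - 1 for u in states]

# Illustrative tolerance; the residuals should be at round-off level.
constraint_ok = all(r -> abs(r) < 1e-8, residuals)
println(constraint_ok)   # true
```

With the full solution object in scope, the same check could be applied to every saved state through `sol_stiff.u` instead of the hand-copied values.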
Neural Network Layers
Next, we create our layers using Lux.Chain. We use this instead of Flux.Chain because it is better suited to SciML applications (similarly for Lux.Dense). The network takes the three state variables as input and returns two outputs, one for each differential equation; the solve itself starts from the initial conditions u₀.
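As a rough illustration of how the shapes fit together, the sketch below replaces the Lux layers with a hand-rolled two-layer network and appends the constraint as a third row. The names `W1`, `b1`, `W2`, `b2`, `nn` and `rhs` are hypothetical, the weights are plain random numbers rather than Lux's initializers, and stacking the network output on top of the constraint is an assumption about how the right-hand side is assembled, inferred from the zero last row of M rather than copied from the library.

```julia
using Random

# Hypothetical stand-in for Dense(3 => 64, tanh) followed by Dense(64 => 2).
rng = Random.MersenneTwister(0)
W1, b1 = 0.1 .* randn(rng, 64, 3), zeros(64)
W2, b2 = 0.1 .* randn(rng, 2, 64), zeros(2)

# The network maps the 3-component state to the 2 differential rows.
nn(u) = W2 * tanh.(W1 * u .+ b1) .+ b2

# Algebraic constraint, same expression as the closure passed to NeuralODEMM.
constraint(u) = [u[1] + u[2] + u[3] - 1]

# Assumed assembly: network rows first, constraint row last,
# lining up with the zero last row of the mass matrix M.
rhs(u) = vcat(nn(u), constraint(u))

u0 = [1.0, 0.0, 0.0]
nparams = length(W1) + length(b1) + length(W2) + length(b2)

println(length(rhs(u0)))   # 3
println(rhs(u0)[3])        # 0.0
println(nparams)           # 386
```

The parameter count of 386 agrees with the 386-element vector reported by the optimizer in the full listing. The Lux definition of the same architecture, wrapped in NeuralODEMM, follows.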
nn_dudt2 = Lux.Chain(Lux.Dense(3, 64, tanh),
    Lux.Dense(64, 2))
pinit, st = Lux.setup(rng, nn_dudt2)
model_stiff_ndae = NeuralODEMM(nn_dudt2, (u, p, t) -> [u[1] + u[2] + u[3] - 1],
    tspan, M, Rodas5(autodiff = false), saveat = 0.1)
model_stiff_ndae(u₀, ComponentArray(pinit), st)
The call returns a tuple whose first element is a SciMLBase.ODESolution; the full type signature printed by the REPL is omitted here. For the untrained parameters, the solution holds the following saved times and states:
t = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
u =
[1.0, 0.0, 0.0]
[1.0452634611088794, -0.004776114956398698, -0.04048734617047308]
[1.0948048371877261, -0.009914790669844211, -0.08489004656506638]
[1.1489811032551278, -0.015443335061148873, -0.13353776815452406]
[1.2081656781870873, -0.02139046185636857, -0.18677521646106304]
[1.2727450222828023, -0.02778611070809808, -0.2449589115287577]
[1.3431148320985087, -0.034661218990736535, -0.30845361315625147]
[1.4196737010897262, -0.04204728765787751, -0.37762641326226226]
[1.5028163508288979, -0.049975854273937075, -0.4528404965518792]
[1.5929257370505159, -0.05847779471383109, -0.5344479423396598]
[1.690361097052936, -0.06758214430046436, -0.622778952752471]
Float64}(2.0, 3.040894194418781, 1.041747909077569, 2.576417536461461, 1.62208306077664, -0.9089668560264532, 2.760842080225597, 1.446624659844071, -0.3036980084553738, 0.2877498600325443, -14.09640773051259, 6.925207756232704, -41.47510893210728, 2.343771018586405, 24.13215229196062, -10.31323885133993, -21.04823117650003, -7.234992135176716, 32.22751541853323, -4.943732386540191, 19.44922031041879, -20.69865579590063, -8.816374604402768, 1.260436877740897, -0.7495647613787146, -46.22004352711257, -17.49534862857472, -289.6389582892057, 93.60855400400906, 318.3822534212147, 34.20013733472935, -14.1553540271769, 57.823356409884, 25.83362985412365, 1.408950972071624, -6.551835421242162, 42.57076742291101, -13.80770672017997, 93.98938432427124, 18.77919633714503, -31.5835918722337, -6.685968952921985, -5.810979938412932, 0.19, 0.19, -0.18230792253337147, -0.3192318321868749, 0.3449828624725343, -0.37741756439208984, 0.38, 0.3878509998321533, 0.483971893787384, 0.457047700881958, 27.354592673333357, -6.925207756232857, 26.40037733258859, 0.5635230501052979, -4.699151156849391, -1.6008677469422725, -1.5306074446748028, -1.3929872940716344, 44.19024239501722, 1.3677947663381929e-13, 202.93261852171622, -35.5669339789154, -181.91095152160645, 3.4116351403665033, 2.5793540257308067, 2.2435122582734066, -44.0988150021747, -5.755396159656812e-13, -181.26175034586677, 56.99302194811676, 183.21182741427398, -7.480257918273637, -5.792426076169686, -5.32503859794143), [2.5463949493e-313 3.8195924246e-313 4.03179200377e-313; 2.970794108e-313 4.03179200377e-313 5.09278989926e-313; 3.3951932663e-313 4.03179200377e-313 6.9408360326696e-310], LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}(Matrix{Float64}(undef, 0, 0), Int64[], 0), nothing)), false, 0, DiffEqBase.Stats
Number of function 1 evaluations: 122
Number of function 2 evaluations: 0
Number of W matrix evaluations: 10
Number of linear solves: 80
Number of Jacobians created: 10
Number of nonlinear solver iterations: 0
Number of nonlinear solver convergence failures: 0
Number of rootfind condition calls: 0
Number of accepted steps: 10
Number of rejected steps: 0
Return code: Success
Because this is a stiff problem, we manually impose the sum constraint via (u,p,t) -> [u[1] + u[2] + u[3] - 1], which makes the fitting easier. The network only has to produce the first two components of the right-hand side; the zero last row of M turns the third equation into the algebraic condition u[1] + u[2] + u[3] = 1.
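To make the role of the constraint concrete, here is a minimal sketch in plain Julia of how two learned rows and one constraint residual can be stacked into a three-component right-hand side. The names nn_stub, W, and rhs are hypothetical stand-ins chosen for illustration (a fixed linear map replaces the neural network), and the sketch shows the idea only, not DiffEqFlux's internal implementation.

```julia
# Singular mass matrix from the tutorial: the zero last row marks an algebraic equation.
M = [1.0 0 0
     0 1.0 0
     0 0 0]

# Hypothetical stand-in for the neural network (3 inputs -> 2 outputs).
W = [-0.04 0.01 0.0
     0.04 -0.01 0.0]
nn_stub(u) = W * u

# Constraint residual in the same form as the function passed to NeuralODEMM.
constraint(u, p, t) = [u[1] + u[2] + u[3] - 1]

# Combined right-hand side: two learned rows followed by one constraint residual.
rhs(u, p, t) = vcat(nn_stub(u), constraint(u, p, t))

u = [0.5, 0.25, 0.25]        # a state that satisfies the constraint
du = rhs(u, nothing, 0.0)

# With M * u' = rhs(u), the zero last row of M requires du[3] == 0,
# which is exactly u[1] + u[2] + u[3] == 1.
println(du[3])
```

A state that violates the constraint gives a nonzero third component, which is the residual the DAE solver drives to zero at every step.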
Prediction Function
For simplicity, we define a wrapper function that only takes in the model's parameters to make predictions.
function predict_stiff_ndae(p)
    return model_stiff_ndae(u₀, p, st)[1]
end
predict_stiff_ndae (generic function with 1 method)
Train Parameters
Training our network requires a loss function, an optimizer, and a callback function to display the progress.
Loss
We first make our predictions based on the current parameters, then calculate the loss from these predictions. In this case, we use least squares as our loss.
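Written out, the least-squares loss is a sum over the three state components and the saved time points. The notation below is an assumption made for illustration: y_i denotes the reference solution sol_stiff, \hat{y}_i the prediction for parameters p, and t_j the 11 save points implied by saveat = 0.1 on tspan = (0.0, 1.0).

```latex
L(p) = \sum_{j=0}^{10} \sum_{i=1}^{3}
       \bigl( y_i(t_j) - \hat{y}_i(t_j; p) \bigr)^2 ,
\qquad t_j = 0.1\, j .
```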
function loss_stiff_ndae(p)
    pred = predict_stiff_ndae(p)
    loss = sum(abs2, Array(sol_stiff) .- pred)
    return loss, pred
end
l1 = first(loss_stiff_ndae(ComponentArray(pinit)))
3.105655105730176
Notice that we are feeding the parameters of model_stiff_ndae to the loss_stiff_ndae function. ComponentArray(pinit) holds the weights of our NN and has 386 entries (4 * 64 + 65 * 2), including the biases.
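The parameter count can be checked by hand from the layer sizes Dense(3, 64) and Dense(64, 2). The short sketch below uses only arithmetic; the variable names are chosen for illustration.

```julia
# Layer sizes of the network in this tutorial: Dense(3, 64, tanh) followed by Dense(64, 2).
n_in, n_hidden, n_out = 3, 64, 2

layer_1 = n_in * n_hidden + n_hidden   # 192 weights + 64 biases
layer_2 = n_hidden * n_out + n_out     # 128 weights + 2 biases
n_params = layer_1 + layer_2

println((layer_1, layer_2, n_params))  # (256, 130, 386)
```

These totals match the index ranges 1:256 and 257:386 shown for layer_1 and layer_2 in the printed ComponentArray axes.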
Optimizer
The optimizer is L-BFGS, provided by NLopt as NLopt.LD_LBFGS() (see below).
Callback
The callback function displays the loss during training.
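A callback can also do more than display the loss. As a minimal sketch, assuming the same (p, l, pred) signature used in this tutorial and the Optimization.jl convention that returning true halts the optimization, the variant below records the loss history and stops as soon as the loss is no longer finite. The names loss_history and callback_with_history are hypothetical.

```julia
# Loss values collected during training (hypothetical helper, not part of the tutorial).
loss_history = Float64[]

callback_with_history = function (p, l, pred)
    push!(loss_history, l)   # record the current loss
    return !isfinite(l)      # true (halt) once the loss becomes NaN or Inf
end

# Exercising the callback by hand with made-up loss values:
keep_going = callback_with_history(nothing, 3.1, nothing)
halted = callback_with_history(nothing, NaN, nothing)
println((keep_going, halted, length(loss_history)))
```

Stopping on a non-finite loss is a design choice that makes a diverging run visible immediately rather than after all iterations have been spent.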
callback = function (p, l, pred) # callback function to observe training
    display(l)
    return false
end
#3 (generic function with 1 method)
Train
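Finally, we train with Optimization.solve, passing the optimization problem (built from the loss function and the initial model parameters), the optimizer, and the maximum number of iterations. The callback is not passed in the call below; to display the loss during training, add callback = callback to the keyword arguments.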
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_stiff_ndae(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ComponentArray(pinit))
result_stiff = Optimization.solve(optprob, NLopt.LD_LBFGS(), maxiters = 100)
u: 386-element Vector{Float64}:
NaN
NaN
NaN
NaN
NaN
NaN
NaN
NaN
NaN
NaN
⋮
NaN
NaN
NaN
NaN
NaN
NaN
NaN
NaN
NaN