Enforcing Physical Constraints via Universal Differential-Algebraic Equations

As shown in DiffEqDocs, differential-algebraic equations (DAEs) can be used to impose physical constraints. One way to define a DAE is through an ODE with a singular mass matrix. For example, if we write Mu' = f(u) where the last row of M is all zeros, the last equation becomes 0 = f₃(u), which is an algebraic constraint defined by the right-hand side. With NeuralODEMM, we can use this to define a neural ODE in which the three state variables must always sum to one. An example of this is as follows:

using DiffEqFlux
using Lux, ComponentArrays, Optimization, OptimizationOptimJL, OrdinaryDiffEq, Plots

using Random
rng = Random.default_rng()

function f!(du, u, p, t)
    y₁, y₂, y₃ = u
    k₁, k₂, k₃ = p
    du[1] = -k₁ * y₁ + k₃ * y₂ * y₃
    du[2] = k₁ * y₁ - k₃ * y₂ * y₃ - k₂ * y₂^2
    du[3] = y₁ + y₂ + y₃ - 1
    return nothing
end

u₀ = [1.0, 0, 0]
M = [1.0 0 0
     0 1.0 0
     0 0 0]

tspan = (0.0, 1.0)
p = [0.04, 3e7, 1e4]

stiff_func = ODEFunction(f!; mass_matrix = M)
prob_stiff = ODEProblem(stiff_func, u₀, tspan, p)
sol_stiff = solve(prob_stiff, Rodas5(); saveat = 0.1)

nn_dudt2 = Lux.Chain(Lux.Dense(3, 64, tanh), Lux.Dense(64, 2))

pinit, st = Lux.setup(rng, nn_dudt2)

model_stiff_ndae = NeuralODEMM(nn_dudt2, (u, p, t) -> [u[1] + u[2] + u[3] - 1],
    tspan, M, Rodas5(; autodiff = false); saveat = 0.1)

function predict_stiff_ndae(p)
    return model_stiff_ndae(u₀, p, st)[1]
end

function loss_stiff_ndae(p)
    pred = predict_stiff_ndae(p)
    loss = sum(abs2, Array(sol_stiff) .- pred)
    return loss
end

# callback = function (state, l) # callback function to observe training
#   display(l)
#   return false
# end

l1 = first(loss_stiff_ndae(ComponentArray(pinit)))

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_stiff_ndae(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ComponentArray(pinit))
result_stiff = Optimization.solve(optprob, OptimizationOptimJL.BFGS(); maxiters = 100)
retcode: Failure
u: ComponentVector{Float32}(layer_1 = (weight = Float32[1.6589212 -0.51684415 -0.11542491; -0.56002945 -0.77483946 0.98897135; … ; 0.32362738 -0.8294616 -1.2255882; -0.3018981 -1.0740048 -0.53662354], bias = Float32[-0.11863555, -0.34261033, 0.19099298, 0.13703088, 0.47875524, -0.3840804, 0.5107739, 0.1801709, 0.03740601, -0.1946894  …  -0.50942844, -0.28227004, -0.32945663, -0.22611655, 0.3987328, 0.0575539, 0.5053546, 0.07841898, -0.55574286, 0.44226784]), layer_2 = (weight = Float32[0.07540194 0.10634061 … -0.27560532 0.10994727; 0.14830337 -0.09621674 … 0.12354869 -0.060180508], bias = Float32[-0.03774584, -0.08045015]))
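To see why a zero row in M enforces the constraint no matter what the network learns, here is a minimal sketch in plain Julia (LinearAlgebra only). It makes two simplifying assumptions: implicit Euler stands in for Rodas5, and a fixed linear map A stands in for the untrained network. The names A, J, b, implicit_euler_step, and integrate are hypothetical and not part of DiffEqFlux.

```julia
using LinearAlgebra

# Singular mass matrix from the tutorial: the zero last row turns the third
# equation into the algebraic constraint 0 = y₁ + y₂ + y₃ - 1.
M = [1.0 0 0
     0 1.0 0
     0 0 0]

# Hypothetical stand-in for an untrained network: an arbitrary linear map that
# returns the derivatives of the two differential variables y₁ and y₂.
A = [-0.5 0.3 0.1
     0.2 -0.4 0.6]

# Full right-hand side F(u) = J*u + b: the "network" rows stacked on top of
# the constraint row y₁ + y₂ + y₃ - 1.
J = vcat(A, [1.0 1.0 1.0])
b = [0.0, 0.0, -1.0]

# Implicit Euler for M*(u_next - u)/h = F(u_next). Because F is affine here,
# each step is the linear solve (M/h - J) * u_next = M*u/h + b.
implicit_euler_step(u, h) = (M / h - J) \ (M * u / h + b)

function integrate(u0; h = 0.1, nsteps = 10)
    us = [u0]
    for _ in 1:nsteps
        push!(us, implicit_euler_step(us[end], h))
    end
    return us
end

us = integrate([1.0, 0.0, 0.0])

# Largest constraint violation over all steps: round-off level for any A.
println(maximum(abs(sum(u) - 1) for u in us))
```

The third row of each linear solve reads -(y₁ + y₂ + y₃) = -1, so the constraint holds to round-off for any choice of A. The untrained model_stiff_ndae output in the step-by-step section behaves the same way: its saved states sum to 1 even though the weights are random.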

Step-by-Step Description

Load Packages

using DiffEqFlux
using Lux, ComponentArrays, Optimization, OptimizationOptimJL, OrdinaryDiffEq, Plots

using Random
rng = Random.default_rng()
Random.TaskLocalRNG()

Differential Equation

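First, we define the differential equations. This is Robertson's chemical kinetics problem, with the third rate equation replaced by the conservation constraint y₁ + y₂ + y₃ = 1. Its rate constants differ by almost nine orders of magnitude, which makes the problem highly stiff and the fitting difficult.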

function f!(du, u, p, t)
    y₁, y₂, y₃ = u
    k₁, k₂, k₃ = p
    du[1] = -k₁ * y₁ + k₃ * y₂ * y₃
    du[2] = k₁ * y₁ - k₃ * y₂ * y₃ - k₂ * y₂^2
    du[3] = y₁ + y₂ + y₃ - 1
    return nothing
end
f! (generic function with 1 method)
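One way to quantify this stiffness is to compare the eigenvalues of the Jacobian. The sketch below assumes the standard ODE form of Robertson's problem (third equation y₃' = k₂y₂², which conserves y₁ + y₂ + y₃) rather than the DAE form used in f!. It evaluates a hand-derived Jacobian at the t = 0.1 state of the reference solution sol_stiff printed further down. robertson_jacobian is a hypothetical helper; only the LinearAlgebra standard library is used.

```julia
using LinearAlgebra

# Rate constants from the tutorial: k₁, k₂, k₃.
k₁, k₂, k₃ = 0.04, 3e7, 1e4

# Analytic Jacobian of the ODE form of Robertson's problem
# (third equation y₃' = k₂ y₂², whose conservation law is y₁ + y₂ + y₃ = 1).
function robertson_jacobian(y)
    y₁, y₂, y₃ = y
    return [-k₁   (k₃ * y₃)                   (k₃ * y₂)
            k₁    (-k₃ * y₃ - 2 * k₂ * y₂)    (-k₃ * y₂)
            0.0   (2 * k₂ * y₂)               0.0]
end

# State at t = 0.1, copied from the sol_stiff output of this tutorial.
y = [0.9960777474341874, 3.580437232874095e-5, 0.003886448193483852]
λ = eigvals(robertson_jacobian(y))

# One eigenvalue is numerically zero because of the conservation law;
# the stiffness ratio compares the two nonzero ones.
nonzero = sort(filter(x -> x > 1e-8, abs.(λ)))
stiffness_ratio = nonzero[end] / nonzero[1]
println("eigenvalues: ", λ)
println("stiffness ratio: ", stiffness_ratio)
```

Already at t = 0.1 the fastest and slowest nonzero modes differ by a factor of several thousand, which is why an implicit, stiffness-aware solver is used throughout.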

Parameters

u₀ = [1.0, 0, 0]

M = [1.0 0 0
     0 1.0 0
     0 0 0]

tspan = (0.0, 1.0)

p = [0.04, 3e7, 1e4]
3-element Vector{Float64}:
     0.04
     3.0e7
 10000.0
  • u₀ = initial conditions (consistent with the constraint, since 1 + 0 + 0 = 1)
  • M = semi-explicit mass matrix (the last row corresponds to the constraint equation and is therefore all zeros)
  • tspan = time span over which to integrate
  • p = rate constants k₁, k₂, and k₃ of the differential equations above
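The two structural properties in this list can be checked directly: M has rank 2 (one algebraic equation out of three), and u₀ satisfies the constraint, which DAE solvers generally expect of initial conditions. This is a minimal sketch using only LinearAlgebra; constraint_residual is a hypothetical helper that mirrors the third equation of f!.

```julia
using LinearAlgebra

u₀ = [1.0, 0, 0]
M = [1.0 0 0
     0 1.0 0
     0 0 0]

# Constraint residual g(u) = y₁ + y₂ + y₃ - 1 (the third equation of f!).
constraint_residual(u) = sum(u) - 1

println("rank(M) = ", rank(M))                  # one of three equations is algebraic
println("g(u₀)   = ", constraint_residual(u₀))   # consistent initial condition
println("M * du  = ", M * [0.1, -0.2, 0.3])      # third component is always zero
```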

ODE Function, Problem and Solution

We define and solve the problem to generate the “labeled” data that will be used to train the neural network. The mass matrix is attached through ODEFunction, so the DAE is still expressed as an ODEProblem.

stiff_func = ODEFunction(f!; mass_matrix = M)
prob_stiff = ODEProblem(stiff_func, u₀, tspan, p)
sol_stiff = solve(prob_stiff, Rodas5(); saveat = 0.1)
retcode: Success
Interpolation: 1st order linear
t: 11-element Vector{Float64}:
 0.0
 0.1
 0.2
 0.3
 0.4
 0.5
 0.6
 0.7
 0.8
 0.9
 1.0
u: 11-element Vector{Vector{Float64}}:
 [1.0, 0.0, 0.0]
 [0.9960777474341874, 3.580437232874095e-5, 0.003886448193483852]
 [0.9923059457218126, 3.5123030150794895e-5, 0.007658931248036548]
 [0.9886739385487181, 3.44771604650164e-5, 0.011291584290816925]
 [0.9851721109941468, 3.3863965535471e-5, 0.01479402504031772]
 [0.981791774709964, 3.328089042278384e-5, 0.01817494439961413]
 [0.9785250342445555, 3.2725768110357636e-5, 0.02144223998733427]
 [0.9753647131268983, 3.21965297854234e-5, 0.024603090343316574]
 [0.972304297901924, 3.1691238996228466e-5, 0.027664010859079317]
 [0.9693377993879673, 3.12082968346845e-5, 0.030630992315197364]
 [0.9664597388050115, 3.07462661101518e-5, 0.03350951492887844]

Because this is a DAE in mass-matrix form, we need a solver that supports singular mass matrices. The stiff Rosenbrock method Rodas5 works well for this example.
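As a sanity check on the labeled data, every saved state should satisfy y₁ + y₂ + y₃ = 1 up to round-off. To keep the sketch self-contained, it hard-codes three states copied from the sol_stiff output above (t = 0.1, 0.5, and 1.0); saved_states and residuals are illustrative names.

```julia
# Three saved states copied from the sol_stiff output (t = 0.1, 0.5, 1.0).
saved_states = [
    [0.9960777474341874, 3.580437232874095e-5, 0.003886448193483852],
    [0.981791774709964, 3.328089042278384e-5, 0.01817494439961413],
    [0.9664597388050115, 3.07462661101518e-5, 0.03350951492887844],
]

# Algebraic residual y₁ + y₂ + y₃ - 1 for each state; expected at round-off level.
residuals = [sum(u) - 1 for u in saved_states]
println(maximum(abs, residuals))
```

In a live session, the same check can be applied to every saved state with maximum(abs(sum(u) - 1) for u in sol_stiff.u).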

Neural Network Layers

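Next, we create our layers using Lux.Chain and Lux.Dense. Lux layers are explicitly parameterized: Lux.setup returns the parameters and state separately, and both are passed in at call time, which is the calling convention NeuralODEMM uses (model_stiff_ndae(u₀, p, st) below). The network takes the current state u (three values) as input and returns the time derivatives of the two differential variables; the third, algebraic equation is supplied by the constraint function passed to NeuralODEMM. The model is integrated starting from the initial state u₀.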
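The size of this network follows from the layer shapes: a Dense layer with bias mapping n_in inputs to n_out outputs has an n_out × n_in weight matrix plus n_out biases. A small arithmetic sketch (dense_params is a hypothetical helper, not a Lux function):

```julia
# Trainable parameters of a dense layer with bias: weights plus biases.
dense_params(n_in, n_out) = n_out * n_in + n_out

layer_1 = dense_params(3, 64)   # state u (3 values) -> 64 hidden units
layer_2 = dense_params(64, 2)   # 64 hidden units -> derivatives of y₁ and y₂
total = layer_1 + layer_2
println((layer_1, layer_2, total))
```

These 386 values are what the BFGS optimizer adjusts during training.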

nn_dudt2 = Lux.Chain(Lux.Dense(3, 64, tanh), Lux.Dense(64, 2))

pinit, st = Lux.setup(rng, nn_dudt2)

model_stiff_ndae = NeuralODEMM(nn_dudt2, (u, p, t) -> [u[1] + u[2] + u[3] - 1],
    tspan, M, Rodas5(; autodiff = false); saveat = 0.1)
model_stiff_ndae(u₀, ComponentArray(pinit), st)
[Printed output condensed: a tuple whose first element is the untrained model's ODESolution, followed by its full type signature. Its saved states at t = 0.0, 0.1, …, 1.0 are:]
 [1.0, 0.0, 0.0]
 [0.8994092394017354, 0.05870707748885017, 0.041883683109414066]
 [0.7968240622151814, 0.1040113728091391, 0.0991645649756796]
 [0.692617381383233, 0.13315512887585965, 0.1742274897409082]
 [0.5883937220371189, 0.1433051671786624, 0.2683011107842187]
 [0.48769339868096445, 0.13222793635894176, 0.3800786649600938]
 [0.3958305237853755, 0.0994011623163984, 0.5047683138982259]
 [0.3182248534264028, 0.046836661063995144, 0.634938485509602]
 [0.25817621627169646, -0.021308912485382574, 0.7631326962136856]
 [0.21627601572938707, -0.10026264760497983, 0.8839866318755928]
 [0.1912619482867357, -0.18585011024869832, 0.9945881619619625]
typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!)}}, Base.Pairs{Symbol, Float64, Tuple{Symbol}, @NamedTuple{saveat::Float64}}}, StatefulLuxLayer{Static.True, Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, Nothing, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}}}(NeuralODEMM{Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, Main.var"#1#2", Tuple{Float64, Float64}, Matrix{Float64}, Tuple{OrdinaryDiffEqRosenbrock.Rodas5{0, AutoFiniteDiff{Val{:forward}, Val{:forward}, Val{:hcentral}, Nothing, Nothing, Int64}, Nothing, typeof(OrdinaryDiffEqCore.DEFAULT_PRECS), Val{:forward}(), true, nothing, typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!)}}, Base.Pairs{Symbol, Float64, Tuple{Symbol}, @NamedTuple{saveat::Float64}}}(Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(3 => 64, tanh), layer_2 = Dense(64 => 2)), nothing), Main.var"#1#2"(), (0.0, 1.0), [1.0 0.0 0.0; 0.0 1.0 0.0; 0.0 0.0 0.0], (OrdinaryDiffEqRosenbrock.Rodas5{0, AutoFiniteDiff{Val{:forward}, Val{:forward}, Val{:hcentral}, Nothing, Nothing, Int64}, Nothing, typeof(OrdinaryDiffEqCore.DEFAULT_PRECS), Val{:forward}(), true, nothing, typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!)}(nothing, OrdinaryDiffEqCore.DEFAULT_PRECS, OrdinaryDiffEqCore.trivial_limiter!, OrdinaryDiffEqCore.trivial_limiter!, AutoFiniteDiff()),), Base.Pairs(:saveat => 0.1)), StatefulLuxLayer{Static.True, Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, 
layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, Nothing, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}}(Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(3 => 64, tanh), layer_2 = Dense(64 => 2)), nothing), nothing, (layer_1 = NamedTuple(), layer_2 = NamedTuple()), nothing, static(true))), [1.0 0.0 0.0; 0.0 1.0 0.0; 0.0 0.0 0.0], nothing, DiffEqFlux.basic_tgrad, nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, SciMLBase.DEFAULT_OBSERVED, nothing, nothing, nothing, nothing), [0.22019792174655076, -0.0907857881685706, 0.8705878664220199], (layer_1 = (weight = Float32[-1.1705778 -1.6588916 0.32729805; -0.28424123 -1.1089251 -1.4212596; … ; -1.6514931 0.7060482 -0.2694267; -1.0793198 -1.4750975 -1.3796407], bias = Float32[-0.29647824, -0.16483429, 0.45835268, 0.56277174, 0.09751839, -0.054524124, 0.34757167, -0.27479675, -0.014247845, 0.20296796  …  -0.28666598, -0.54918426, 0.22267614, -0.32032785, -0.14613543, -0.14987265, 0.5316119, -0.18366171, -0.2584572, -0.11741771]), layer_2 = (weight = Float32[0.012355092 -0.06454273 … -0.042979304 0.19365568; -0.09335904 0.09502378 … 0.16607136 0.03768091], bias = Float32[-0.104326874, 0.11542177]))), SciMLBase.UDerivativeWrapper{false, SciMLBase.ODEFunction{false, SciMLBase.FullSpecialize, DiffEqFlux.var"#f#37"{NeuralODEMM{Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, Main.var"#1#2", Tuple{Float64, Float64}, Matrix{Float64}, Tuple{OrdinaryDiffEqRosenbrock.Rodas5{0, AutoFiniteDiff{Val{:forward}, Val{:forward}, Val{:hcentral}, Nothing, Nothing, Int64}, Nothing, typeof(OrdinaryDiffEqCore.DEFAULT_PRECS), Val{:forward}(), true, nothing, 
typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!)}}, Base.Pairs{Symbol, Float64, Tuple{Symbol}, @NamedTuple{saveat::Float64}}}, StatefulLuxLayer{Static.True, Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, Nothing, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}}}, Matrix{Float64}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}, Float64, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:256, Axis(weight = ViewAxis(1:192, ShapedAxis((64, 3))), bias = ViewAxis(193:256, Shaped1DAxis((64,))))), layer_2 = ViewAxis(257:386, Axis(weight = ViewAxis(1:128, ShapedAxis((2, 64))), bias = ViewAxis(129:130, Shaped1DAxis((2,))))))}}}}(SciMLBase.ODEFunction{false, SciMLBase.FullSpecialize, DiffEqFlux.var"#f#37"{NeuralODEMM{Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, Main.var"#1#2", Tuple{Float64, Float64}, Matrix{Float64}, Tuple{OrdinaryDiffEqRosenbrock.Rodas5{0, AutoFiniteDiff{Val{:forward}, Val{:forward}, Val{:hcentral}, Nothing, Nothing, Int64}, Nothing, typeof(OrdinaryDiffEqCore.DEFAULT_PRECS), Val{:forward}(), true, nothing, typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!)}}, Base.Pairs{Symbol, Float64, Tuple{Symbol}, @NamedTuple{saveat::Float64}}}, StatefulLuxLayer{Static.True, Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, Nothing, @NamedTuple{layer_1::@NamedTuple{}, 
layer_2::@NamedTuple{}}}}, Matrix{Float64}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}(DiffEqFlux.var"#f#37"{NeuralODEMM{Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, Main.var"#1#2", Tuple{Float64, Float64}, Matrix{Float64}, Tuple{OrdinaryDiffEqRosenbrock.Rodas5{0, AutoFiniteDiff{Val{:forward}, Val{:forward}, Val{:hcentral}, Nothing, Nothing, Int64}, Nothing, typeof(OrdinaryDiffEqCore.DEFAULT_PRECS), Val{:forward}(), true, nothing, typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!)}}, Base.Pairs{Symbol, Float64, Tuple{Symbol}, @NamedTuple{saveat::Float64}}}, StatefulLuxLayer{Static.True, Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, Nothing, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}}}(NeuralODEMM{Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, Main.var"#1#2", Tuple{Float64, Float64}, Matrix{Float64}, Tuple{OrdinaryDiffEqRosenbrock.Rodas5{0, AutoFiniteDiff{Val{:forward}, Val{:forward}, Val{:hcentral}, Nothing, Nothing, Int64}, Nothing, typeof(OrdinaryDiffEqCore.DEFAULT_PRECS), Val{:forward}(), true, nothing, typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!)}}, Base.Pairs{Symbol, Float64, Tuple{Symbol}, @NamedTuple{saveat::Float64}}}(Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = 
Dense(3 => 64, tanh), layer_2 = Dense(64 => 2)), nothing), Main.var"#1#2"(), (0.0, 1.0), [1.0 0.0 0.0; 0.0 1.0 0.0; 0.0 0.0 0.0], (OrdinaryDiffEqRosenbrock.Rodas5{0, AutoFiniteDiff{Val{:forward}, Val{:forward}, Val{:hcentral}, Nothing, Nothing, Int64}, Nothing, typeof(OrdinaryDiffEqCore.DEFAULT_PRECS), Val{:forward}(), true, nothing, typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!)}(nothing, OrdinaryDiffEqCore.DEFAULT_PRECS, OrdinaryDiffEqCore.trivial_limiter!, OrdinaryDiffEqCore.trivial_limiter!, AutoFiniteDiff()),), Base.Pairs(:saveat => 0.1)), StatefulLuxLayer{Static.True, Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, Nothing, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}}(Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(3 => 64, tanh), layer_2 = Dense(64 => 2)), nothing), nothing, (layer_1 = NamedTuple(), layer_2 = NamedTuple()), nothing, static(true))), [1.0 0.0 0.0; 0.0 1.0 0.0; 0.0 0.0 0.0], nothing, DiffEqFlux.basic_tgrad, nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, SciMLBase.DEFAULT_OBSERVED, nothing, nothing, nothing, nothing), 0.8885141749623799, (layer_1 = (weight = Float32[-1.1705778 -1.6588916 0.32729805; -0.28424123 -1.1089251 -1.4212596; … ; -1.6514931 0.7060482 -0.2694267; -1.0793198 -1.4750975 -1.3796407], bias = Float32[-0.29647824, -0.16483429, 0.45835268, 0.56277174, 0.09751839, -0.054524124, 0.34757167, -0.27479675, -0.014247845, 0.20296796  …  -0.28666598, -0.54918426, 0.22267614, -0.32032785, -0.14613543, -0.14987265, 0.5316119, -0.18366171, -0.2584572, -0.11741771]), layer_2 = (weight = Float32[0.012355092 -0.06454273 … -0.042979304 0.19365568; -0.09335904 0.09502378 … 
0.16607136 0.03768091], bias = Float32[-0.104326874, 0.11542177]))), OrdinaryDiffEqRosenbrock.RodasTableau{Float64, Float64}([0.0 0.0 … 0.0 0.0; 2.0 0.0 … 0.0 0.0; … ; -14.09640773051259 6.925207756232704 … 0.0 0.0; -14.09640773051259 6.925207756232704 … 1.0 0.0], [0.0 0.0 … 0.0 0.0; -10.31323885133993 0.0 … 0.0 0.0; … ; 34.20013733472935 -14.1553540271769 … -6.551835421242162 0.0; 42.57076742291101 -13.80770672017997 … -6.685968952921985 -5.810979938412932], 0.19, [0.0, 0.38, 0.3878509998321533, 0.483971893787384, 0.457047700881958, 1.0, 1.0, 1.0], [0.19, -0.18230792253337147, -0.3192318321868749, 0.3449828624725343, -0.37741756439208984, 0.0, 0.0, 0.0], [27.354592673333357 -6.925207756232857 … -1.5306074446748028 -1.3929872940716344; 44.19024239501722 1.3677947663381929e-13 … 2.5793540257308067 2.2435122582734066; -44.0988150021747 -5.755396159656812e-13 … -5.792426076169686 -5.32503859794143]), [0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0], LinearAlgebra.LU{Float64, Matrix{Float64}, Vector{Int64}}(Matrix{Float64}(undef, 0, 0), Int64[], 0), nothing, AutoFiniteDiff(), 3), Bool[1, 1, 0], false), false, 0, SciMLBase.DEStats(146, 0, 12, 84, 12, 0, 0, 0, 0, 0, 12, 0, 0.0), nothing, SciMLBase.ReturnCode.Success, nothing, nothing, nothing), (layer_1 = NamedTuple(), layer_2 = NamedTuple()))

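Because this problem is stiff, we impose the sum constraint manually through the algebraic equation (u, p, t) -> [u[1] + u[2] + u[3] - 1]. The network only has to learn the two differential components (its output layer has 2 units), while the zero last row of M turns the third component into this constraint. This makes the fitting easier.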
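To make the mechanism concrete, here is a hand-written sketch of the same idea without DiffEqFlux. It is an illustration, not DiffEqFlux's source: `learned_rhs` is a hypothetical fixed stand-in for the neural network, and the names `rhs` and `constraint` are ours. The right-hand side stacks two "learned" differential rows on top of the algebraic row, and the singular mass matrix makes the solver enforce y₁ + y₂ + y₃ = 1 at every saved point.

```julia
using OrdinaryDiffEq

# Hypothetical stand-in for the neural network: any map R³ → R² would do.
learned_rhs(u) = [-0.5 * u[1], 0.5 * u[1] - 0.2 * u[2]]

# Algebraic part: the conservation law y₁ + y₂ + y₃ = 1, written as 0 = g(u).
constraint(u) = [u[1] + u[2] + u[3] - 1]

# Out-of-place right-hand side: two differential rows, one algebraic row.
rhs(u, p, t) = vcat(learned_rhs(u), constraint(u))

# Zero last row of M: the third equation is a constraint, not an ODE.
M = [1.0 0 0
     0 1.0 0
     0 0 0]

prob = ODEProblem(ODEFunction(rhs; mass_matrix = M), [1.0, 0.0, 0.0], (0.0, 1.0))
sol = solve(prob, Rodas5(); saveat = 0.1)

# Largest violation of the constraint over all saved time points.
residual = maximum(abs(sum(u) - 1) for u in sol.u)
```

Because the constraint is part of the equations rather than a penalty in the loss, `residual` should sit at round-off level regardless of what `learned_rhs` returns.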

Prediction Function

For simplicity, we define a wrapper function that takes only the model's parameters and returns the prediction; the initial condition u₀ and the Lux state st stay fixed.

function predict_stiff_ndae(p)
    return model_stiff_ndae(u₀, p, st)[1]
end
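As a sanity check, the sketch below rebuilds the same architecture from scratch (the names `nn`, `model` and `predict`, and the seeded `Random.Xoshiro(0)` generator, are our assumptions) and inspects an untrained prediction. With `saveat = 0.1` on (0.0, 1.0), the array should have 3 rows (states) and 11 columns (save times), and every column should already sum to one, since the constraint does not depend on the network weights.

```julia
using DiffEqFlux, Lux, ComponentArrays, OrdinaryDiffEq, Random

# Same layout as the tutorial: the network outputs only the 2 differential components.
nn = Lux.Chain(Lux.Dense(3, 64, tanh), Lux.Dense(64, 2))
ps, st = Lux.setup(Random.Xoshiro(0), nn)

M = [1.0 0 0
     0 1.0 0
     0 0 0]
u₀ = [1.0, 0, 0]

model = NeuralODEMM(nn, (u, p, t) -> [u[1] + u[2] + u[3] - 1],
    (0.0, 1.0), M, Rodas5(; autodiff = false); saveat = 0.1)

# Closure over u₀ and st, so only the parameters vary; the model returns (solution, state).
predict(p) = Array(first(model(u₀, p, st)))

pred = predict(ComponentArray(ps))

# Per-time-point constraint residual of the untrained model.
constraint_residual = maximum(abs.(vec(sum(pred; dims = 1)) .- 1))
```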

Train Parameters

Training our network requires a loss function and an optimizer; an optional callback function displays the progress.

Loss

We first predict the trajectory from the current parameters and then compare it with the reference solution sol_stiff. Here we use a least-squares loss: the sum of squared differences over all states and save times.

function loss_stiff_ndae(p)
    pred = predict_stiff_ndae(p)
    # Least squares against the reference solution, compared as plain arrays
    loss = sum(abs2, Array(sol_stiff) .- pred)
    return loss
end

l1 = first(loss_stiff_ndae(ComponentArray(pinit)))
6.143780226745511
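The reduction itself is easy to check by hand. This toy sketch uses made-up 3×2 arrays (rows are states, columns are save times) in place of the real solutions and applies the same `sum(abs2, A .- B)` expression as `loss_stiff_ndae`.

```julia
# Made-up "reference" and "prediction": 3 states at 2 save times.
target = [1.0 0.9
          0.0 0.05
          0.0 0.05]
pred = [1.0 0.8
        0.0 0.10
        0.0 0.10]

# Sum of squared elementwise differences: 0.1^2 + 0.05^2 + 0.05^2 ≈ 0.015
loss = sum(abs2, target .- pred)
```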

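Notice that we pass the network parameters to loss_stiff_ndae as a ComponentArray. With Lux, the parameters are not stored inside model_stiff_ndae; they are the separate object pinit returned by Lux.setup. They are the weights and biases of our network, 386 values in total: (3 + 1) * 64 = 256 for the first layer and (64 + 1) * 2 = 130 for the second.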
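The count can be confirmed programmatically. A short sketch, assuming the same two-layer Lux chain; `Lux.parameterlength` reports the total number of trainable parameters, and flattening with `ComponentArray` should give the same length.

```julia
using Lux, Random, ComponentArrays

nn = Lux.Chain(Lux.Dense(3, 64, tanh), Lux.Dense(64, 2))
ps, st = Lux.setup(Random.default_rng(), nn)

# Dense(in, out) holds an out×in weight matrix plus an out-length bias vector.
by_hand = (3 * 64 + 64) + (64 * 2 + 2)   # 256 + 130 = 386

n_lux = Lux.parameterlength(nn)          # count reported by Lux
n_flat = length(ComponentArray(ps))      # length of the flat vector the optimizer sees
```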

Optimizer

We use the BFGS optimizer from OptimizationOptimJL.
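The same Optimization.jl pattern can be tried on a problem whose answer is known. This sketch fits a line to noise-free synthetic data from y = 2x + 1 (the data and the names `line_loss`, `xs`, `ys` are our assumptions); like `(x, p) -> loss_stiff_ndae(x)` above, the loss ignores its second argument, Zygote supplies the gradient, and BFGS should recover slope 2 and intercept 1.

```julia
using Optimization, OptimizationOptimJL, Zygote

# Synthetic, noise-free data from y = 2x + 1.
xs = collect(0.0:0.1:1.0)
ys = 2 .* xs .+ 1

# Least-squares loss in θ = [slope, intercept]; the second argument is unused.
line_loss(θ, _) = sum(abs2, θ[1] .* xs .+ θ[2] .- ys)

optf = OptimizationFunction(line_loss, Optimization.AutoZygote())
prob = OptimizationProblem(optf, zeros(2))
sol = solve(prob, BFGS(); maxiters = 100)
```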

Callback

The callback function displays the current loss at every iteration. Returning false tells the optimizer to continue; returning true stops training early.

callback = function (state, l) #callback function to observe training
    display(l)
    return false
end
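A callback can do more than display. The sketch below, on the Rosenbrock test function rather than the neural DAE (the objective, the `losses` vector and the 1e-6 threshold are our assumptions), records every loss the optimizer reports and returns true to stop early once the loss is small enough.

```julia
using Optimization, OptimizationOptimJL, Zygote

# Rosenbrock function; p holds its two constants, minimum at x = [1, 1].
rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2

optf = OptimizationFunction(rosenbrock, Optimization.AutoZygote())
prob = OptimizationProblem(optf, zeros(2), [1.0, 100.0])

# Record each reported loss; returning true halts the optimizer.
losses = Float64[]
callback = function (state, l)
    push!(losses, l)
    return l < 1e-6
end

sol = solve(prob, BFGS(); callback = callback, maxiters = 100)
```

Afterwards, `losses` holds the training history and can be plotted to inspect convergence.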

Train

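Finally, we train with Optimization.solve. We wrap the loss in an OptimizationFunction with Zygote as the AD backend, build an OptimizationProblem from the initial parameters, and pass the optimizer, the callback, and the maximum number of iterations.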

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_stiff_ndae(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ComponentArray(pinit))
result_stiff = Optimization.solve(optprob, OptimizationOptimJL.BFGS(); callback = callback, maxiters = 100)
retcode: Failure
u: ComponentVector{Float32}(layer_1 = (weight = Float32[-1.1709194 -1.6590341 0.32653773; -0.28648138 -1.1096169 -1.4201448; … ; -1.6512328 0.7061335 -0.26913196; -1.0794419 -1.4742478 -1.3785014], bias = Float32[-0.29758528, -0.16675988, 0.44270724, 0.5573102, 0.1109735, -0.048796244, 0.36303684, -0.26894683, -0.006725271, 0.19442889  …  -0.28426915, -0.54853666, 0.24116378, -0.33312234, -0.14456835, -0.1421789, 0.5386521, -0.17506234, -0.25794986, -0.1158998]), layer_2 = (weight = Float32[-0.02447303 -0.11081891 … -0.077704765 0.15714404; -0.094358325 0.07074579 … 0.164667 0.032916736], bias = Float32[-0.07756947, 0.10898841]))
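The printed retcode is Failure, so this run did not report convergence. One common follow-up is to restart the optimizer from the parameters it returned by rebuilding the problem with `remake`. Sketched here on the Rosenbrock toy problem rather than the neural DAE; the deliberately short first run (`maxiters = 3`) is our assumption, chosen so that it stops early.

```julia
using Optimization, OptimizationOptimJL, Zygote

rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
optf = OptimizationFunction(rosenbrock, Optimization.AutoZygote())
prob = OptimizationProblem(optf, zeros(2), [1.0, 100.0])

# A short first run that stops well before convergence.
sol1 = solve(prob, BFGS(); maxiters = 3)

# Restart from where the first run stopped, with a larger budget.
prob2 = remake(prob; u0 = sol1.u)
sol2 = solve(prob2, BFGS(); maxiters = 100)
```

For the model above, the analogous call would be `remake(optprob; u0 = result_stiff.u)`.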