Enforcing Physical Constraints via Universal Differential-Algebraic Equations
As shown in the stiff ODE tutorial, differential-algebraic equations (DAEs) can be used to impose physical constraints. One way to define a DAE is through an ODE with a singular mass matrix. For example, if we write Mu' = f(u), where the last row of M is all zeros, then the last row of f(u) defines an algebraic constraint: it must equal zero at all times. Using NeuralODEMM, we can use this to define a neural ODE whose three states must sum to one. An example of this is as follows:
using Lux, DiffEqFlux, Optimization, OptimizationOptimJL, DifferentialEquations, Plots
using Random
rng = Random.default_rng()
function f!(du, u, p, t)
    y₁, y₂, y₃ = u
    k₁, k₂, k₃ = p
    du[1] = -k₁*y₁ + k₃*y₂*y₃
    du[2] = k₁*y₁ - k₃*y₂*y₃ - k₂*y₂^2
    du[3] = y₁ + y₂ + y₃ - 1
    return nothing
end
u₀ = [1.0, 0, 0]
M = [1. 0 0
     0 1. 0
     0 0 0]
tspan = (0.0,1.0)
p = [0.04, 3e7, 1e4]
stiff_func = ODEFunction(f!, mass_matrix = M)
prob_stiff = ODEProblem(stiff_func, u₀, tspan, p)
sol_stiff = solve(prob_stiff, Rodas5(), saveat = 0.1)
nn_dudt2 = Lux.Chain(Lux.Dense(3, 64, tanh),
                     Lux.Dense(64, 2))
pinit, st = Lux.setup(rng, nn_dudt2)
model_stiff_ndae = NeuralODEMM(nn_dudt2, (u, p, t) -> [u[1] + u[2] + u[3] - 1],
                               tspan, M, Rodas5(autodiff=false), saveat = 0.1)
model_stiff_ndae(u₀, Lux.ComponentArray(pinit), st)
function predict_stiff_ndae(p)
    return model_stiff_ndae(u₀, p, st)[1]
end
function loss_stiff_ndae(p)
    pred = predict_stiff_ndae(p)
    loss = sum(abs2, Array(sol_stiff) .- pred)
    return loss, pred
end
# callback = function (p, l, pred) #callback function to observe training
#     display(l)
#     return false
# end
l1 = first(loss_stiff_ndae(Lux.ComponentArray(pinit)))
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x,p) -> loss_stiff_ndae(x), adtype)
optprob = Optimization.OptimizationProblem(optf, Lux.ComponentArray(pinit))
result_stiff = Optimization.solve(optprob, BFGS(), maxiters=100)
u: ComponentVector{Float32}(layer_1 = (weight = Float32[-0.14778994 0.14410302 -0.039370313; 0.20214434 -0.16115378 0.02419426; … ; -0.23397037 0.20246558 -0.18861394; 0.12609924 -0.15247425 0.12415691], bias = Float32[-0.0072882273; 0.03228126; … ; -0.056865353; 0.022222443;;]), layer_2 = (weight = Float32[0.04378997 0.22626679 … -0.011999203 -0.2894333; 0.16644162 0.17503448 … -0.25506246 -0.14278124], bias = Float32[0.123588264; 0.1098209;;]))
Step-by-Step Description
Load Packages
using Lux, DiffEqFlux, Optimization, OptimizationOptimJL, DifferentialEquations, Plots
using Random
rng = Random.default_rng()
Random.TaskLocalRNG()
Differential Equation
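First, we define our differential equations as a highly stiff problem, which makes the fitting difficult. The third equation is not a differential equation but the algebraic constraint that the three states sum to one.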
function f!(du, u, p, t)
    y₁, y₂, y₃ = u
    k₁, k₂, k₃ = p
    du[1] = -k₁*y₁ + k₃*y₂*y₃
    du[2] = k₁*y₁ - k₃*y₂*y₃ - k₂*y₂^2
    du[3] = y₁ + y₂ + y₃ - 1
    return nothing
end
f! (generic function with 1 method)
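Written out, and combined with the mass matrix M defined in the Parameters step, f! describes the following semi-explicit DAE. This is only a restatement of the code above in mathematical notation, with y and k named as in f!; it appears to be the Robertson chemical kinetics problem in its constrained form.

```latex
\begin{aligned}
\frac{dy_1}{dt} &= -k_1 y_1 + k_3 y_2 y_3, \\
\frac{dy_2}{dt} &= k_1 y_1 - k_3 y_2 y_3 - k_2 y_2^2, \\
0 &= y_1 + y_2 + y_3 - 1.
\end{aligned}
```

The zero on the left-hand side of the last row is what the all-zero last row of M produces.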
Parameters
u₀ = [1.0, 0, 0]
M = [1. 0 0
     0 1. 0
     0 0 0]
tspan = (0.0,1.0)
p = [0.04, 3e7, 1e4]
3-element Vector{Float64}:
0.04
3.0e7
10000.0
u₀ = Initial conditions
M = Semi-explicit mass matrix (the last row corresponds to the constraint equation and is therefore all zeros)
tspan = Time span over which to evaluate
p = Parameters k1, k2 and k3 of the differential equation above
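As a quick sanity check on these definitions, the sketch below uses only plain Julia and the standard-library LinearAlgebra to confirm that M is singular with rank 2 and that u₀ satisfies the algebraic constraint, so the initial condition is consistent with the DAE. The helper name constraint_residual is introduced here for illustration only.

```julia
using LinearAlgebra

u₀ = [1.0, 0, 0]
M = [1. 0 0
     0 1. 0
     0 0 0]

# A singular mass matrix has zero determinant; its rank counts the
# differential equations (the remaining row is the algebraic constraint).
is_singular = det(M) == 0
n_differential = rank(M)

# Hypothetical helper: residual of the constraint y₁ + y₂ + y₃ - 1 = 0.
constraint_residual(u) = sum(u) - 1

println("singular: ", is_singular)
println("differential equations: ", n_differential)
println("constraint residual at u₀: ", constraint_residual(u₀))
```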
ODE Function, Problem and Solution
We define and solve our ODE problem to generate the "labeled" data which will be used to train our Neural Network.
stiff_func = ODEFunction(f!, mass_matrix = M)
prob_stiff = ODEProblem(stiff_func, u₀, tspan, p)
sol_stiff = solve(prob_stiff, Rodas5(), saveat = 0.1)
retcode: Success
Interpolation: 1st order linear
t: 11-element Vector{Float64}:
0.0
0.1
0.2
0.3
0.4
0.5
0.6
0.7
0.8
0.9
1.0
u: 11-element Vector{Vector{Float64}}:
[1.0, 0.0, 0.0]
[0.9960777474341889, 3.5804372328739174e-5, 0.003886448193482536]
[0.9923059457218133, 3.512303015079638e-5, 0.007658931248036001]
[0.9886739385487276, 3.4477160464978214e-5, 0.011291584290807323]
[0.9851721109941391, 3.386396553552364e-5, 0.01479402504032534]
[0.9817917747099651, 3.328089042275513e-5, 0.018174944399613487]
[0.9785250342445795, 3.2725768110280034e-5, 0.021442239987310215]
[0.9753647131269014, 3.2196529785412034e-5, 0.02460309034331384]
[0.9723042979019034, 3.169123899638582e-5, 0.027664010859099343]
[0.9693377993879712, 3.120829683456607e-5, 0.030630992315192573]
[0.966459738805013, 3.0746266110151764e-5, 0.033509514928876834]
Because this is a DAE, we need to make sure to use a compatible solver. Rodas5 works well for this example.
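One way to see that the solver respects the constraint is to check that each saved state sums to one. The sketch below does this for three of the states copied from the output above, so it runs without any solver packages; the names saved_states, residuals and max_violation are illustrative.

```julia
# Three of the saved states printed above (t = 0.1, 0.2, 0.3).
saved_states = [
    [0.9960777474341889, 3.5804372328739174e-5, 0.003886448193482536],
    [0.9923059457218133, 3.512303015079638e-5, 0.007658931248036001],
    [0.9886739385487276, 3.4477160464978214e-5, 0.011291584290807323],
]

# Residual of the algebraic constraint y₁ + y₂ + y₃ - 1 = 0 at each state.
residuals = [sum(u) - 1 for u in saved_states]
max_violation = maximum(abs, residuals)

println("largest constraint violation: ", max_violation)
```

With the full solution object available, an expression along the lines of maximum(abs, sum(Array(sol_stiff), dims=1) .- 1) should perform the same check over all saved time points.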
Neural Network Layers
Next, we create our layers using Lux.Chain. We use this instead of Flux.Chain because it is more suited to SciML applications (similarly for Lux.Dense). The input to our network will be the initial conditions fed in as u₀.
nn_dudt2 = Lux.Chain(Lux.Dense(3, 64, tanh),
                     Lux.Dense(64, 2))
pinit, st = Lux.setup(rng, nn_dudt2)
model_stiff_ndae = NeuralODEMM(nn_dudt2, (u, p, t) -> [u[1] + u[2] + u[3] - 1],
                               tspan, M, Rodas5(autodiff=false), saveat = 0.1)
model_stiff_ndae(u₀, Lux.ComponentArray(pinit), st)
The call returns a tuple of the ODE solution and the layer states, (layer_1 = NamedTuple(), layer_2 = NamedTuple()). With the full type signature omitted, the solution for the untrained network is:
retcode: Success
t: 11-element Vector{Float64}:
[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
u: 11-element Vector{Vector{Float64}}:
[1.0, 0.0, 0.0]
[0.9816976655472842, -0.01896731548620277, 0.03726964993891875]
[0.9629783942251191, -0.0383626337849838, 0.07538423955986451]
[0.9438279029679155, -0.05819708613619204, 0.11436918316827646]
[0.9242318115915826, -0.07848209326508564, 0.1542502816735033]
[0.9041758787907376, -0.09922935508790211, 0.19505347629716488]
[0.8836460106289883, -0.12045085027556049, 0.23680483964657234]
[0.8626282811459252, -0.14215883361555787, 0.27953055246963276]
[0.8411091300129556, -0.16436581053772364, 0.323256680524768]
[0.8190754036421792, -0.1870845351412067, 0.36800913149902775]
[0.796514551488686, -0.21032798013602405, 0.41381342864733794]
DEStats
Number of function 1 evaluations: 98
Number of function 2 evaluations: 0
Number of W matrix evaluations: 8
Number of linear solves: 64
Number of Jacobians created: 8
Number of nonlinear solver iterations: 0
Number of nonlinear solver convergence failures: 0
Number of rootfind condition calls: 0
Number of accepted steps: 8
Number of rejected steps: 0
Because this is a stiff problem, we have manually imposed the sum constraint via (u,p,t) -> [u[1] + u[2] + u[3] - 1], making the fitting easier.
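Conceptually, the constrained model stacks the two network outputs on top of the constraint residual to form the three-component right-hand side that the mass matrix acts on. The sketch below illustrates that idea in plain Julia. Here toy_net is a hypothetical stand-in for the network (a fixed linear map, not Lux), and rhs is an illustrative name rather than DiffEqFlux's internal function.

```julia
# Hypothetical stand-in for the neural network: maps 3 states to 2 derivatives.
W = [-0.04 0.0 0.0
      0.04 0.0 0.0]
toy_net(u) = W * u

# The constraint passed to NeuralODEMM in the tutorial.
constraint(u, p, t) = [u[1] + u[2] + u[3] - 1]

# Stack both pieces: rows 1-2 are differential, row 3 is algebraic.
rhs(u, p, t) = vcat(toy_net(u), constraint(u, p, t))

u₀ = [1.0, 0.0, 0.0]
du = rhs(u₀, nothing, 0.0)
println(du)
```

Because the last row of M is zero, the solver has to keep the third entry of this vector at zero, whatever the first two entries (the learned dynamics) turn out to be.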
Prediction Function
For simplicity, we define a wrapper function that only takes in the model's parameters to make predictions.
function predict_stiff_ndae(p)
    return model_stiff_ndae(u₀, p, st)[1]
end
predict_stiff_ndae (generic function with 1 method)
Train Parameters
Training our network requires a loss function, an optimizer and a callback function to display the progress.
Loss
We first make our predictions based on the current parameters, then calculate the loss from these predictions. In this case, we use least squares as our loss.
function loss_stiff_ndae(p)
    pred = predict_stiff_ndae(p)
    loss = sum(abs2, Array(sol_stiff) .- pred)
    return loss, pred
end
l1 = first(loss_stiff_ndae(Lux.ComponentArray(pinit)))
0.7897276992692538
Notice that we are feeding the parameters of model_stiff_ndae to the loss_stiff_ndae function. These parameters (pinit, wrapped in a Lux.ComponentArray) are the weights of our NN; there are 386 of them (4 * 64 + 65 * 2), including the biases.
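The count of 386 can be reproduced by hand from the layer sizes. A minimal sketch, where dense_param_count is a hypothetical helper written for this illustration and not part of Lux:

```julia
# Parameters of a dense layer: one weight per (input, output) pair
# plus one bias per output.
dense_param_count(n_in, n_out) = n_in * n_out + n_out

layer_1 = dense_param_count(3, 64)   # 3*64 weights + 64 biases
layer_2 = dense_param_count(64, 2)   # 64*2 weights + 2 biases
total = layer_1 + layer_2

println((layer_1, layer_2, total))
```

With the actual model, length(Lux.ComponentArray(pinit)) should report the same total.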
Optimizer
The optimizer is BFGS (see below).
Callback
The callback function displays the loss during training.
callback = function (p, l, pred) #callback function to observe training
    display(l)
    return false
end
#3 (generic function with 1 method)
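The return value of the callback controls whether training continues: returning false keeps going, while returning true asks the optimizer to stop. The sketch below shows a callback that also records the loss history and stops early, assuming the same (p, l, pred) signature used above. The names loss_history, target_loss and recording_callback are hypothetical, and the calls at the end only simulate what the optimizer would do, using made-up loss values.

```julia
loss_history = Float64[]
target_loss = 1e-3  # hypothetical stopping threshold

recording_callback = function (p, l, pred)
    push!(loss_history, l)     # keep the loss for later inspection
    return l < target_loss     # true halts training, false continues
end

# Simulate three optimizer iterations with made-up loss values.
stop_flags = [recording_callback(nothing, l, nothing) for l in (0.79, 0.05, 0.0005)]
println(stop_flags)
println(loss_history)
```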
Train
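Finally, we train with Optimization.solve by passing the optimization problem (built from the loss function and the initial model parameters), the optimizer and the maximum number of iterations. The callback defined above is not passed in the call below; supplying it through the callback keyword argument would display the loss at each iteration.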
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x,p) -> loss_stiff_ndae(x), adtype)
optprob = Optimization.OptimizationProblem(optf, Lux.ComponentArray(pinit))
result_stiff = Optimization.solve(optprob, BFGS(), maxiters=100)
u: ComponentVector{Float32}(layer_1 = (weight = Float32[0.06272323 -0.24562438 -0.15333647; -0.052775245 -0.073353164 0.27237546; … ; -0.1155097 0.011523392 0.009404958; 0.032456394 0.08666926 -0.26522142], bias = Float32[0.0041229255; -0.0051113516; … ; -0.016886711; 0.0039278693;;]), layer_2 = (weight = Float32[0.011764122 0.038023714 … -0.23903331 0.19029555; 0.006457003 -0.11739719 … 0.06665023 -0.27227452], bias = Float32[0.06974203; 0.04399327;;]))