Component based Universal Differential Equations with ModelingToolkit

ModelingToolkitNeuralNets provides 2 main interfaces for representing neural networks symbolically:

This tutorial will introduce the NeuralNetworkBlock. This representation is useful in the context of hierarchical acausal component-based model.

For such models we have a component representation that is converted to a a differential-algebraic equation (DAE) system, where the algebraic equations are given by the constraints and equalities between different component variables. The process of going from the component representation to the full DAE system at the end is referred to as structural simplification. In order to formulate Universal Differential Equations (UDEs) in this context, we could operate either operate before the structural simplification step or after that, on the resulting DAE system. We call these the component UDE formulation and the system UDE formulation.

The advantage of the component UDE formulation is that it allows us to represent the model discovery process in the block component formulation, allowing us to potentially target deeply nested parts of a model for the model discovery process. As such, this allows us to maximally reuse the existing information and also to have a result that can be expressed in the same manner as the original model.

In the following we will explore a simple example with a thermal model built with components from the ModelingToolkitStandardLibrary.

We will first start with a model that we generate synthetic data from and then we will build a simpler model that we will use for model discovery and try to recover the dynamics from the original model.

Our model represents a pot (HeatCapacitor) that is heated. In the complete model we consider that we have a plate that delays the heating of the pot. In the simple model we just consider that the heat source is directly connected to the pot.

Let's start with the definitions of the 2 models

using ModelingToolkitNeuralNets, Lux
using ModelingToolkitStandardLibrary.Thermal
using ModelingToolkitStandardLibrary.Blocks
using ModelingToolkit
using ModelingToolkit: t_nounits as t, D_nounits as D
using OrdinaryDiffEqTsit5
using Plots
using StableRNGs
using SciMLBase

input_f(t) = (1+sin(0.005 * t^2))/2

function PotWithPlate(; name)
    @parameters begin
        C1 = 1
        C2 = 15
    end

    @named input = Blocks.TimeVaryingFunction(f = input_f)
    @named source = PrescribedHeatFlow(T_ref = 373.15)
    @named plate = HeatCapacitor(C = C1, T = 273.15)
    @named pot = HeatCapacitor(C = C2, T = 273.15)
    @named conduction = ThermalConductor(G = 1)
    @named air = ThermalConductor(G = 0.1)
    @named env = FixedTemperature(T = 293.15)
    @named Tsensor = TemperatureSensor()

    eqs = [
        connect(input.output, :u, source.Q_flow),
        connect(source.port, plate.port),
        connect(plate.port, conduction.port_a),
        connect(conduction.port_b, pot.port),
        connect(pot.port, air.port_a),
        connect(air.port_b, env.port),
        connect(pot.port, Tsensor.port),
    ]

    System(eqs, t, [], [C1, C2]; name, systems = [input, source, plate, conduction, pot, air, env, Tsensor])
end
function SimplePot(; name)
    @parameters begin
        C2 = 15
    end

    @named input = Blocks.TimeVaryingFunction(f = input_f)
    @named source = PrescribedHeatFlow(T_ref = 373.15)
    @named pot = HeatCapacitor(C = C2, T = 273.15)
    @named air = ThermalConductor(G = 0.1)
    @named env = FixedTemperature(T = 293.15)
    @named Tsensor = TemperatureSensor()

    eqs = [
        connect(input.output, :u, source.Q_flow)
        connect(source.port, pot.port)
        connect(pot.port, Tsensor.port)
        connect(pot.port, air.port_a)
        connect(air.port_b, env.port)
    ]

    System(eqs, t, [], [C2]; name, systems = [input, source, pot, air, env, Tsensor])
end

@mtkcompile sys1 = PotWithPlate()
@mtkcompile sys2 = SimplePot()

## solve and plot the temperature of the pot in the 2 systems

prob1 = ODEProblem(sys1, Pair[], (0, 100.0))
sol1 = solve(prob1, Tsit5(), reltol = 1e-6)
prob2 = ODEProblem(sys2, Pair[], (0, 100.0))
sol2 = solve(prob2, Tsit5(), reltol = 1e-6)
plot(sol1, idxs = sys1.pot.T, label = "pot.T in original system")
plot!(sol2, idxs = sys1.pot.T, label = "pot.T in simplified system")
Example block output

If we take a closer look at the 2 models, the original system has 2 unknowns,

unknowns(sys1)
2-element Vector{SymbolicUtils.BasicSymbolicImpl.var"typeof(BasicSymbolicImpl)"{SymbolicUtils.SymReal}}:
 pot₊T(t)
 plate₊T(t)

while the simplified system only has 1 unknown

unknowns(sys2)
1-element Vector{SymbolicUtils.BasicSymbolicImpl.var"typeof(BasicSymbolicImpl)"{SymbolicUtils.SymReal}}:
 pot₊T(t)

With the component UDE approach, we want to add a new component in the model that would be connected between the source and the pot such that we can use it to discover the missing physics. To this end, we will build a ThermalNN component that adds a new state in the system (x) that is prescribed by the output of the neural network, but it also incorporates physics knoweledge about the system by formulating the component as something that outputs a heat flow rate based on some scaled input temperatures.

The system in this example is quite simple, so we will use a very small neural network with 1 hidden layer of size 4. For more complex dynamics, a larger network would be required. The inputs to out neural network will use scaled temperature values, such that we don't input numbers that would be too large for the chosen activation function. The chosen activation function will always output positive numbers for positive inputs, so this also makes physical sense for out component.

function ThermalNN(; name)
    n_input = 2
    n_output = 1
    chain = multi_layer_feed_forward(;
        n_input, n_output, depth = 1, width = 4, activation = Lux.swish)

    @named port_a = HeatPort()
    @named port_b = HeatPort()
    @named nn = NeuralNetworkBlock(; n_input, n_output, chain, rng = StableRNG(1337))

    @parameters begin
        T0 = 273.15
        T_range = 10
        C1 = 1
    end
    @variables begin
        dT(t), [guess = 0.0]
        Q_flow(t), [guess = 0.0]
        x(t) = T0
    end

    eqs = [
        dT ~ port_a.T - port_b.T
        port_a.Q_flow ~ Q_flow
        C1*D(x) ~ Q_flow - nn.outputs[1]
        port_a.T ~ x
        nn.outputs[1] + port_b.Q_flow ~ 0
        nn.inputs[1] ~ (x - T0) / T_range
        nn.inputs[2] ~ (port_b.T - T0) / T_range
    ]

    System(eqs, t; name, systems = [port_a, port_b, nn])
end

function NeuralPot(; name)
    @parameters begin
        C2 = 15
    end

    @named input = Blocks.TimeVaryingFunction(f = input_f)
    @named source = PrescribedHeatFlow(T_ref = 373.15)
    @named pot = HeatCapacitor(C = C2, T = 273.15)
    @named air = ThermalConductor(G = 0.1)
    @named env = FixedTemperature(T = 293.15)
    @named Tsensor = TemperatureSensor()
    @named thermal_nn = ThermalNN()

    eqs = [
        connect(input.output, :u, source.Q_flow)
        connect(pot.port, Tsensor.port)
        connect(pot.port, air.port_a)
        connect(air.port_b, env.port)
        connect(source.port, thermal_nn.port_a)
        connect(thermal_nn.port_b, pot.port)
    ]

    System(eqs, t, [], [C2]; name, systems = [input, source, pot, air, env, Tsensor, thermal_nn])
end

@named model = NeuralPot()
sys3 = mtkcompile(model)

# Let's check that we can successfully simulate the system in the
# initial state
prob3 = ODEProblem(sys3, Pair[], (0, 100.0))
sol3 = solve(prob3, Tsit5(), abstol = 1e-6, reltol = 1e-6)
@assert SciMLBase.successful_retcode(sol3)

Now that we have the system with the embedded neural network, we can start training the network. The training will be formulated as an optimization problem where we will minimize the mean absolute squared distance between the predictions of the new system and the data obtained from the original system. In order to gain some insight into the training process we will also add a callback that will plot various quantities in the system versus their equivalents in the original system. In a more realistic scenario we would not have access to the original system, but we could still monitor how well we fit the training data and the system predictions.

using SymbolicIndexingInterface
using Optimization
using OptimizationOptimJL
using LineSearches
using Statistics
using SciMLSensitivity
import Zygote

tp = Symbolics.scalarize(sys3.thermal_nn.nn.p)
x0 = prob3.ps[tp]

oop_update = setsym_oop(prob3, tp);

plot_cb = (opt_state,
    loss) -> begin
    opt_state.iter % 1000 ≠ 0 && return false
    @info "step $(opt_state.iter), loss: $loss"

    (new_u0, new_p) = oop_update(prob3, opt_state.u)
    new_prob = remake(prob3, u0 = new_u0, p = new_p)
    sol = solve(new_prob, Tsit5(), abstol = 1e-8, reltol = 1e-8)

    plt = plot(sol,
        layout = (2, 3),
        idxs = [
            sys3.thermal_nn.nn.inputs[1], sys3.thermal_nn.x,
            sys3.thermal_nn.nn.outputs[1], sys3.thermal_nn.port_b.T,
            sys3.pot.T, sys3.pot.port.Q_flow],
        size = (950, 800))
    plot!(plt,
        sol1,
        idxs = [
            (sys1.conduction.port_a.T-273.15)/10, sys1.conduction.port_a.T,
            sys1.conduction.port_a.Q_flow, sys1.conduction.port_b.T,
            sys1.pot.T, sys1.pot.port.Q_flow])
    display(plt)
    false
end

function cost(x, opt_ps)
    prob, oop_update, data, ts, get_T = opt_ps

    u0, p = oop_update(prob, x)
    new_prob = remake(prob; u0, p)

    new_sol = solve(new_prob, Tsit5(), saveat = ts, abstol = 1e-8,
        reltol = 1e-8, verbose = false, sensealg = GaussAdjoint())

    !SciMLBase.successful_retcode(new_sol) && return Inf

    mean(abs2.(get_T(new_sol) .- data))
end

of = OptimizationFunction(cost, AutoForwardDiff())

data = sol1[sys1.pot.T]
get_T = getsym(prob3, sys3.pot.T)
opt_ps = (prob3, oop_update, data, sol1.t, get_T);

op = OptimizationProblem(of, x0, opt_ps)

res = solve(op, Adam(); maxiters = 10_000, callback = plot_cb)
op2 = OptimizationProblem(of, res.u, opt_ps)
res2 = solve(op2, LBFGS(linesearch = BackTracking()); maxiters = 2000, callback = plot_cb)

(new_u0, new_p) = oop_update(prob3, res2.u)
new_prob1 = remake(prob3, u0 = new_u0, p = new_p)
new_sol1 = solve(new_prob1, Tsit5(), abstol = 1e-6, reltol = 1e-6)

plt = plot(new_sol1,
    layout = (2, 3),
    idxs = [
        sys3.thermal_nn.nn.inputs[1], sys3.thermal_nn.x,
        sys3.thermal_nn.nn.outputs[1], sys3.thermal_nn.port_b.T,
        sys3.pot.T, sys3.pot.port.Q_flow],
    size = (950, 800))
plot!(plt,
    sol1,
    idxs = [
        (sys1.conduction.port_a.T-273.15)/10, sys1.conduction.port_a.T,
        sys1.conduction.port_a.Q_flow, sys1.conduction.port_b.T,
        sys1.pot.T, sys1.pot.port.Q_flow],
    ls = :dash)
Example block output

As we can see from the final plot, the neural network fits very well and not only the training data fits, but also the rest of the predictions of the system match the original system. Let us also compare against the predictions of the incomplete system:

plot(sol1, label = ["original sys: pot T" "original sys: plate T"], lw = 3)
plot!(sol3; idxs = [sys3.pot.T], label = "untrained UDE", lw = 2.5)
plot!(sol2; idxs = [sys2.pot.T], label = "incomplete sys: pot T", lw = 2.5)
plot!(new_sol1; idxs = [sys3.pot.T, sys3.thermal_nn.x],
    label = "trained UDE", ls = :dash, lw = 2.5)
Example block output

Now that our neural network is trained, we can go a step further and use SymbolicRegression.jl to find a symbolic expression for the function represented by the neural network.

using SymbolicRegression

lux_model = new_sol1.ps[sys3.thermal_nn.nn.lux_model]
nn_p = new_sol1.ps[sys3.thermal_nn.nn.p]
T = new_sol1.ps[sys3.thermal_nn.nn.T]

sr_input = reduce(hcat, new_sol1[sys3.thermal_nn.nn.inputs])
sr_output = LuxCore.stateless_apply(lux_model, sr_input, convert(T, nn_p))

equation_search(sr_input, sr_output)
HallOfFame{...}:
    .exists[1] = true
    .members[1] = PopMember(tree = (0.3931390831449485), loss = 0.07455902083596212, cost = 0.3254184682212249)
    .exists[2] = false
    .members[2] = undef
    .exists[3] = true
    .members[3] = PopMember(tree = (1.0051801930129367 * 0.3911229278019697), loss = 0.07455902093465372, cost = 0.32541846865197177)
    .exists[4] = false
    .members[4] = undef
    .exists[5] = true
    .members[5] = PopMember(tree = ((x1 - x2) * 10.000652624854128), loss = 1.0494340764867852e-6, cost = 4.580334154613346e-6)
    .exists[6] = false
    .members[6] = undef
    .exists[7] = true
    .members[7] = PopMember(tree = (((x1 * 0.9999832867951277) - x2) * 10.002762217654158), loss = 1.0417653292618234e-6, cost = 4.546863329123117e-6)
    .exists[8] = false
    .members[8] = undef
    .exists[9] = true
    .members[9] = PopMember(tree = ((x1 - x2) * (x1 + (9.935653461169046 - x2))), loss = 4.831663500617063e-7, cost = 2.108815965797704e-6)
    .exists[10] = false
    .members[10] = undef
    .exists[11] = true
    .members[11] = PopMember(tree = ((x1 + (7.833841598142112e-5 - x2)) * ((x1 + 9.922236358342964) - x2)), loss = 2.8657968074810164e-7, cost = 1.250798624857926e-6)
    .exists[12] = false
    .members[12] = undef
    .exists[13] = true
    .members[13] = PopMember(tree = (((x1 - x2) * (((x1 - x2) * 1.5195238057149922) + 9.881301098173966)) - -0.0011994806435123519), loss = 1.8565587628014858e-7, cost = 8.103090705586996e-7)
    .exists[14] = false
    .members[14] = undef
    .exists[15] = true
    .members[15] = PopMember(tree = ((((0.9333093717176938 - ((x1 - x2) * -1.519523826610761)) + 8.947991724925291) * (x1 - x2)) - -0.0011994806550633726), loss = 1.8565587628015475e-7, cost = 8.103090705587266e-7)
    .exists[16] = false
    .members[16] = undef
    .exists[17] = true
    .members[17] = PopMember(tree = (((x1 + 0.00013773402982476688) - x2) * (((x1 * ((x1 - -0.939264402521055) - x2)) + 9.908420613766346) - x2)), loss = 1.381475090771385e-7, cost = 6.029552197495681e-7)
    .exists[18] = false
    .members[18] = undef
    .exists[19] = true
    .members[19] = PopMember(tree = (((x1 - x2) + 0.00012652252396514545) * ((9.930153091296502 - x2) + (x1 * ((x1 + 0.8669208578668843) - (x2 * 0.9450374495717463))))), loss = 7.036845862688737e-8, cost = 3.0712844348950107e-7)
    .exists[20] = false
    .members[20] = undef
    .exists[21] = true
    .members[21] = PopMember(tree = ((((x1 - x2) * 1.662343591504931) + 9.873391715383603) * ((x1 - -0.00022257643343216079) - (x2 + ((x1 * (x2 - 1.3018731662426826)) * -0.0003291895005916374)))), loss = 2.3491494234544257e-8, cost = 1.0253039785557266e-7)
    .exists[22] = false
    .members[22] = undef
    .exists[23] = true
    .members[23] = PopMember(tree = ((((((1.1274919520712294 - x1) * ((0.2249869839991028 - x1) * 0.0003274432357092217)) - x2) + x1) * (((x1 - x2) * 1.6543806804678638) + 9.871788520138661)) - -0.001475247272023147), loss = 1.9811069773205747e-8, cost = 8.64669077884522e-8)
    .exists[24] = false
    .members[24] = undef
    .exists[25] = true
    .members[25] = PopMember(tree = ((((x1 + (((0.22467644958037278 - x1) * 0.00034416865147348963) * ((1.1271808833847403 - x1) / 1.0527196640664493))) - x2) * (((x1 - x2) * 1.6427677605054836) + 9.873033667736886)) - -0.0014524725903929782), loss = 1.9995781414342143e-8, cost = 8.727309567353026e-8)
    .exists[26] = false
    .members[26] = undef
    .exists[27] = true
    .members[27] = PopMember(tree = ((((((0.2106472447919456 - x1) * (0.0004358382736906868 - (x2 * 0.00021705570619782379))) * (1.1134707091987337 - x1)) - x2) + (x1 - -0.00013682803010475994)) * (((x1 - x2) * 1.6277843346629088) + 9.875024295537356)), loss = 1.5304703016660842e-8, cost = 6.679853029749439e-8)
    .exists[28] = false
    .members[28] = undef
    .exists[29] = true
    .members[29] = PopMember(tree = (((((((0.2105671615663186 - x1) * 0.0004623260248170389) - ((x2 * x2) * -0.00020656432822272941)) * (1.1134293119900733 - x1)) + (x1 - x2)) - -0.00013266285276304223) * (((x1 - x2) * 1.62785306970639) + 9.875437719091375)), loss = 1.4106037792092689e-8, cost = 6.15668655449863e-8)
    .exists[30] = false
    .members[30] = undef

Looking at the results above, we see that we have a term linear in the difference between its inputs, x₁ - x₂, which were the scaled temperature values. This also makes sense physically, as the heat flow rate through the missing components in our system is proportional with the temperature difference between the ports. Having the symbolic expression for the UDE results thus helps us understand the missing physics better.