Learning Nonlinear Reaction Dynamics in the 2D Brusselator PDE Using Universal Differential Equations
Introduction
The Brusselator is a mathematical model used to describe oscillating chemical reactions and spatial pattern formation, capturing how concentrations of chemical species evolve over time and space. In this documentation, we simulate the two-dimensional Brusselator partial differential equation (PDE) on a periodic square domain, generate time-resolved data using a finite difference discretization, and use this data to train a Universal Differential Equation (UDE). Specifically, we replace the known nonlinear reaction term with a neural network, enabling us to learn complex dynamics directly from the generated data while preserving the known physical structure of the system.
The Brusselator PDE
The Brusselator PDE is defined on a unit square periodic domain as follows:
\[\frac{\partial U}{\partial t} = B + U^2V - (A+1)U + \alpha \nabla^2 U + f(x, y, t)\]
\[\frac{\partial V}{\partial t} = AU - U^2V + \alpha \nabla^2 V\]
where $A=3.4, B=1$ and the forcing term is:
\[f(x, y, t) = \begin{cases} 5 & \text{if } (x - 0.3)^2 + (y - 0.6)^2 \leq 0.1^2 \text{ and } t \geq 1.1 \\ 0 & \text{otherwise} \end{cases}\]
and the Laplacian operator is:
\[\nabla^2 = \frac{\partial^2}{\partial x^2} + \frac{\partial^2}{\partial y^2}\]
These equations are solved over the time interval:
\[t \in [0, 11.5]\]
with the initial conditions:
\[U(x, y, 0) = 22 \cdot \left( y(1 - y) \right)^{3/2}\]
\[V(x, y, 0) = 27 \cdot \left( x(1 - x) \right)^{3/2}\]
and the periodic boundary conditions:
\[U(x + 1, y, t) = U(x, y, t)\]
\[V(x, y + 1, t) = V(x, y, t)\]
Numerical Discretization
To numerically solve this PDE, we discretize the unit square domain using $N$ grid points along each spatial dimension. The variables $U[i,j]$ and $V[i,j]$ then denote the concentrations at the grid point $(i, j)$ at a given time $t$.
We represent the spatially discretized fields as:
\[U[i,j] = U(i \cdot \Delta x, j \cdot \Delta y), \quad V[i,j] = V(i \cdot \Delta x, j \cdot \Delta y),\]
where $\Delta x = \Delta y = \frac{1}{N}$ for a grid of size $N \times N$. To organize the simulation state efficiently, we store both $ U $ and $ V $ in a single 3D array:
\[u[i,j,1] = U[i,j], \quad u[i,j,2] = V[i,j],\]
giving us a field tensor of shape $(N, N, 2)$. This structure is flexible and extends naturally to systems with additional field variables.
Finite Difference Laplacian and Forcing
For spatial derivatives, we apply a second-order central difference scheme using a three-point stencil. The Laplacian is discretized as:
\[[\ 1,\ -2,\ 1\ ]\]
in both the $ x $ and $ y $ directions, forming a tridiagonal structure in both the x and y directions; applying this 1D stencil (scaled appropriately by $\frac{1}{Δx^2}$ or $\frac{1}{Δy^2}$) along each axis and summing the contributions yields the standard 5-point stencil computation for the 2D Laplacian. Periodic boundary conditions are incorporated by wrapping the stencil at the domain edges, effectively connecting the boundaries. The nonlinear interaction terms are computed directly at each grid point, making the implementation straightforward and local in nature.
Generating Training Data
This provides us with an ODEProblem
that can be solved to obtain training data.
using ComponentArrays, Random, Plots, OrdinaryDiffEq
N_GRID = 16
XYD = range(0f0, stop = 1f0, length = N_GRID)
dx = step(XYD)
T_FINAL = 11.5f0
SAVE_AT = 0.5f0
tspan = (0.0f0, T_FINAL)
t_points = range(tspan[1], stop=tspan[2], step=SAVE_AT)
A, B, alpha = 3.4f0, 1.0f0, 10.0f0
brusselator_f(x, y, t) = (((x - 0.3f0)^2 + (y - 0.6f0)^2) <= 0.01f0) * (t >= 1.1f0) * 5.0f0
limit(a, N) = a == 0 ? N : a == N+1 ? 1 : a
function init_brusselator(xyd)
println("[Init] Creating initial condition array...")
u0 = zeros(Float32, N_GRID, N_GRID, 2)
for I in CartesianIndices((N_GRID, N_GRID))
x, y = xyd[I[1]], xyd[I[2]]
u0[I,1] = 22f0 * (y * (1f0 - y))^(3f0/2f0)
u0[I,2] = 27f0 * (x * (1f0 - x))^(3f0/2f0)
end
println("[Init] Done.")
return u0
end
u0 = init_brusselator(XYD)
function pde_truth!(du, u, p, t)
A, B, alpha, dx = p
αdx = alpha / dx^2
for I in CartesianIndices((N_GRID, N_GRID))
i, j = Tuple(I)
x, y = XYD[i], XYD[j]
ip1, im1 = limit(i+1, N_GRID), limit(i-1, N_GRID)
jp1, jm1 = limit(j+1, N_GRID), limit(j-1, N_GRID)
U, V = u[i,j,1], u[i,j,2]
ΔU = u[im1,j,1] + u[ip1,j,1] + u[i,jp1,1] + u[i,jm1,1] - 4f0 * U
ΔV = u[im1,j,2] + u[ip1,j,2] + u[i,jp1,2] + u[i,jm1,2] - 4f0 * V
du[i,j,1] = αdx*ΔU + B + U^2 * V - (A+1f0)*U + brusselator_f(x, y, t)
du[i,j,2] = αdx*ΔV + A*U - U^2 * V
end
end
p_tuple = (A, B, alpha, dx)
@time sol_truth = solve(ODEProblem(pde_truth!, u0, tspan, p_tuple), FBDF(), saveat=t_points)
u_true = Array(sol_truth)
16×16×2×24 Array{Float32, 4}:
[:, :, 1, 1] =
0.0 0.341461 0.864189 1.408 1.90251 … 1.408 0.864189 0.341461 0.0
0.0 0.341461 0.864189 1.408 1.90251 1.408 0.864189 0.341461 0.0
0.0 0.341461 0.864189 1.408 1.90251 1.408 0.864189 0.341461 0.0
0.0 0.341461 0.864189 1.408 1.90251 1.408 0.864189 0.341461 0.0
0.0 0.341461 0.864189 1.408 1.90251 1.408 0.864189 0.341461 0.0
0.0 0.341461 0.864189 1.408 1.90251 … 1.408 0.864189 0.341461 0.0
0.0 0.341461 0.864189 1.408 1.90251 1.408 0.864189 0.341461 0.0
0.0 0.341461 0.864189 1.408 1.90251 1.408 0.864189 0.341461 0.0
0.0 0.341461 0.864189 1.408 1.90251 1.408 0.864189 0.341461 0.0
0.0 0.341461 0.864189 1.408 1.90251 1.408 0.864189 0.341461 0.0
0.0 0.341461 0.864189 1.408 1.90251 … 1.408 0.864189 0.341461 0.0
0.0 0.341461 0.864189 1.408 1.90251 1.408 0.864189 0.341461 0.0
0.0 0.341461 0.864189 1.408 1.90251 1.408 0.864189 0.341461 0.0
0.0 0.341461 0.864189 1.408 1.90251 1.408 0.864189 0.341461 0.0
0.0 0.341461 0.864189 1.408 1.90251 1.408 0.864189 0.341461 0.0
0.0 0.341461 0.864189 1.408 1.90251 … 1.408 0.864189 0.341461 0.0
[:, :, 2, 1] =
0.0 0.0 0.0 0.0 … 0.0 0.0 0.0
0.419066 0.419066 0.419066 0.419066 0.419066 0.419066 0.419066
1.0606 1.0606 1.0606 1.0606 1.0606 1.0606 1.0606
1.728 1.728 1.728 1.728 1.728 1.728 1.728
2.3349 2.3349 2.3349 2.3349 2.3349 2.3349 2.3349
2.82843 2.82843 2.82843 2.82843 … 2.82843 2.82843 2.82843
3.17454 3.17454 3.17454 3.17454 3.17454 3.17454 3.17454
3.35252 3.35252 3.35252 3.35252 3.35252 3.35252 3.35252
3.35252 3.35252 3.35252 3.35252 3.35252 3.35252 3.35252
3.17454 3.17454 3.17454 3.17454 3.17454 3.17454 3.17454
2.82843 2.82843 2.82843 2.82843 … 2.82843 2.82843 2.82843
2.3349 2.3349 2.3349 2.3349 2.3349 2.3349 2.3349
1.728 1.728 1.728 1.728 1.728 1.728 1.728
1.0606 1.0606 1.0606 1.0606 1.0606 1.0606 1.0606
0.419066 0.419066 0.419066 0.419066 0.419066 0.419066 0.419066
0.0 0.0 0.0 0.0 … 0.0 0.0 0.0
[:, :, 1, 2] =
0.900959 0.900958 0.900906 0.900949 … 0.900975 0.900924 0.900967
0.900941 0.900958 0.900952 0.900947 0.900935 0.90096 0.900951
0.900942 0.900923 0.900957 0.900939 0.90094 0.900916 0.900951
0.900966 0.900959 0.900941 0.900932 0.90093 0.900967 0.900914
0.900952 0.900941 0.900915 0.900955 0.900931 0.900974 0.90094
0.900939 0.900921 0.900929 0.900972 … 0.900922 0.90093 0.900941
0.900939 0.90093 0.900938 0.900971 0.900914 0.90094 0.900956
0.900932 0.900905 0.900947 0.900972 0.900955 0.900931 0.900904
0.900922 0.900922 0.900922 0.900955 0.900947 0.900905 0.900939
0.900958 0.900947 0.90093 0.900936 0.900931 0.900947 0.900915
0.900905 0.90093 0.900941 0.900953 … 0.900904 0.900932 0.900913
0.900933 0.900932 0.900957 0.900956 0.900949 0.900958 0.900959
0.900966 0.900942 0.900931 0.900916 0.900932 0.90095 0.900949
0.900942 0.900958 0.900938 0.900943 0.900957 0.900975 0.900949
0.900941 0.90096 0.900991 0.90095 0.900916 0.900967 0.90096
0.90096 0.90095 0.90096 0.900949 … 0.900934 0.90096 0.900976
[:, :, 2, 2] =
2.38291 2.38295 2.38301 2.38301 … 2.38304 2.38298 2.38301 2.38305
2.38308 2.38298 2.38301 2.38301 2.38298 2.38305 2.38308 2.38294
2.38305 2.38311 2.38301 2.38302 2.38301 2.38305 2.38301 2.38305
2.38308 2.383 2.38308 2.38305 2.38305 2.38308 2.38295 2.38298
2.38308 2.38301 2.38308 2.38298 2.38308 2.38301 2.38298 2.38298
2.38307 2.38302 2.38304 2.38308 … 2.38299 2.38301 2.38305 2.38308
2.38307 2.38302 2.38308 2.38302 2.38298 2.38308 2.38307 2.38305
2.38308 2.38307 2.38302 2.38311 2.38308 2.38301 2.38307 2.38302
2.38311 2.38301 2.38311 2.38295 2.38298 2.38305 2.38305 2.38302
2.38307 2.38301 2.38301 2.38305 2.38305 2.38308 2.38301 2.38308
2.38311 2.38305 2.38308 2.38305 … 2.38308 2.38298 2.38305 2.38304
2.38295 2.38305 2.38305 2.38301 2.38301 2.38301 2.38308 2.38298
2.38301 2.38298 2.38305 2.38305 2.38298 2.38301 2.38298 2.38305
2.38298 2.38305 2.38301 2.38305 2.38307 2.38301 2.38311 2.38301
2.38295 2.38305 2.38298 2.383 2.38301 2.38301 2.38305 2.38298
2.38308 2.38298 2.383 2.38312 … 2.38305 2.38297 2.38307 2.38298
[:, :, 1, 3] =
0.520114 0.520106 0.520109 0.520104 … 0.520105 0.52011 0.520109
0.520113 0.52011 0.52011 0.520115 0.520109 0.520109 0.520111
0.520109 0.520115 0.52011 0.520111 0.52011 0.520106 0.52011
0.52011 0.520109 0.520114 0.52011 0.52011 0.520114 0.520109
0.520109 0.520111 0.52011 0.520109 0.520109 0.520109 0.520111
0.520111 0.52011 0.52011 0.52011 … 0.520115 0.520109 0.52011
0.520106 0.520105 0.520111 0.52011 0.520109 0.520106 0.520109
0.52011 0.520106 0.520106 0.520107 0.520111 0.52011 0.520111
0.52011 0.520109 0.52011 0.52011 0.520111 0.520113 0.52011
0.520111 0.52011 0.520115 0.520115 0.52011 0.52011 0.52011
0.520113 0.520115 0.520109 0.520107 … 0.52011 0.52011 0.52011
0.520111 0.520106 0.52011 0.52011 0.52011 0.520109 0.52011
0.52011 0.520109 0.52011 0.520113 0.520107 0.52011 0.520106
0.52011 0.52011 0.52011 0.520106 0.520105 0.520106 0.52011
0.520113 0.52011 0.520114 0.520109 0.52011 0.520109 0.520109
0.52011 0.520114 0.520109 0.520114 … 0.520106 0.520114 0.520111
[:, :, 2, 3] =
2.91981 2.91982 2.91984 2.91982 … 2.91981 2.91984 2.9198 2.91984
2.9198 2.91984 2.91982 2.91983 2.91982 2.91982 2.9198 2.91984
2.91982 2.91983 2.91983 2.91981 2.9198 2.91982 2.91982 2.91982
2.91984 2.91983 2.91982 2.91982 2.91982 2.91984 2.91982 2.91982
2.91984 2.91981 2.91982 2.91983 2.91982 2.9198 2.91982 2.91982
2.91982 2.91983 2.91981 2.91982 … 2.91982 2.91984 2.91982 2.91983
2.91983 2.91983 2.91983 2.91979 2.91982 2.91982 2.91982 2.91984
2.91984 2.91983 2.91979 2.91981 2.9198 2.91982 2.91984 2.91982
2.9198 2.91981 2.91982 2.91982 2.9198 2.91982 2.91982 2.91984
2.91982 2.91983 2.91984 2.91986 2.91982 2.91982 2.91984 2.9198
2.91982 2.91983 2.91984 2.91984 … 2.91983 2.91981 2.91982 2.91984
2.91984 2.91984 2.91982 2.91982 2.91982 2.91983 2.91982 2.91983
2.91984 2.91982 2.91982 2.91979 2.91981 2.91981 2.91981 2.91982
2.91983 2.91981 2.91981 2.91982 2.91982 2.91981 2.91982 2.91982
2.91982 2.91982 2.91984 2.91982 2.91984 2.91983 2.91982 2.91983
2.91982 2.91982 2.91984 2.91982 … 2.91984 2.9198 2.9198 2.91983
;;;; …
[:, :, 1, 22] =
0.560204 0.560167 0.560194 0.560255 … 0.560535 0.560387 0.560265
0.560275 0.560254 0.560285 0.56038 0.560737 0.560524 0.56038
0.560391 0.560352 0.560381 0.560506 0.560954 0.560682 0.560503
0.560459 0.560405 0.560459 0.560598 0.561171 0.560811 0.560589
0.560498 0.560459 0.560505 0.560652 0.561314 0.56091 0.560659
0.560499 0.560474 0.560506 0.56066 … 0.561332 0.560933 0.560653
0.560483 0.560433 0.56048 0.560618 0.561231 0.560869 0.560616
0.560415 0.560374 0.560405 0.560538 0.56103 0.560731 0.560529
0.560323 0.560289 0.560326 0.560427 0.560819 0.560587 0.560417
0.560219 0.560211 0.560231 0.560309 0.560615 0.560432 0.560312
0.560152 0.56012 0.560149 0.560208 … 0.560428 0.560307 0.560209
0.560083 0.56006 0.560085 0.560125 0.560313 0.560219 0.560138
0.560044 0.560031 0.560044 0.560095 0.56023 0.560163 0.560085
0.56004 0.560024 0.56003 0.560084 0.56024 0.560151 0.560075
0.560068 0.560039 0.560056 0.560117 0.560276 0.56018 0.560115
0.560112 0.560102 0.560112 0.560168 … 0.560379 0.560281 0.560168
[:, :, 2, 22] =
4.97969 4.97971 4.97971 4.97971 … 4.97975 4.97977 4.97978 4.97978
4.9798 4.97971 4.97971 4.97972 4.97984 4.97971 4.9797 4.97978
4.9798 4.97974 4.9797 4.97985 4.97973 4.97978 4.97985 4.97972
4.9798 4.97978 4.97978 4.97987 4.97972 4.97977 4.97976 4.97979
4.97988 4.97978 4.97985 4.97971 4.97979 4.97981 4.97966 4.97969
4.97977 4.97988 4.97972 4.97977 … 4.97973 4.97975 4.97968 4.97979
4.97978 4.9798 4.97978 4.97977 4.97977 4.97983 4.97968 4.97977
4.9797 4.97978 4.9797 4.97977 4.97982 4.97978 4.97975 4.97987
4.97988 4.97978 4.9798 4.97978 4.97975 4.97976 4.97971 4.9798
4.97965 4.97971 4.97973 4.97978 4.97984 4.97969 4.97978 4.9798
4.97971 4.97979 4.97979 4.97979 … 4.97971 4.9798 4.97978 4.97971
4.97981 4.97971 4.97979 4.97987 4.97978 4.9798 4.97971 4.97987
4.97979 4.97979 4.97971 4.97971 4.97978 4.97978 4.97973 4.97971
4.97969 4.97979 4.97983 4.97973 4.97978 4.9798 4.97975 4.97987
4.97973 4.97979 4.97981 4.97973 4.97978 4.97972 4.97979 4.97963
4.97981 4.97973 4.97969 4.97981 … 4.97987 4.9797 4.97986 4.97981
[:, :, 1, 23] =
0.769321 0.769286 0.769321 0.7694 … 0.76967 0.76951 0.769397
0.769418 0.769383 0.769415 0.769504 0.769856 0.769662 0.769509
0.769512 0.769471 0.769507 0.769621 0.770083 0.76981 0.769626
0.769592 0.769541 0.769589 0.769724 0.7703 0.769953 0.769719
0.769636 0.769591 0.76963 0.769785 0.770445 0.770048 0.769786
0.769644 0.769596 0.769642 0.769796 … 0.770473 0.770061 0.769794
0.769615 0.769568 0.769608 0.769749 0.770368 0.770007 0.769757
0.769539 0.769503 0.769544 0.769664 0.770161 0.769877 0.769664
0.769451 0.769417 0.769451 0.769556 0.769954 0.769713 0.769564
0.769363 0.769328 0.769361 0.769437 0.769736 0.769569 0.769432
0.769276 0.769249 0.769278 0.769335 … 0.76957 0.76944 0.769341
0.769201 0.769185 0.769208 0.769254 0.769428 0.769344 0.769256
0.769169 0.769148 0.769172 0.769213 0.769367 0.76928 0.769213
0.76916 0.769144 0.769158 0.769203 0.769358 0.769272 0.769208
0.769189 0.769174 0.769186 0.769238 0.769411 0.76931 0.769238
0.76924 0.769228 0.76924 0.769299 … 0.769508 0.769389 0.769304
[:, :, 2, 23] =
5.0146 5.01464 5.0146 5.01464 … 5.01458 5.01461 5.01463 5.0146
5.01462 5.01466 5.0146 5.01459 5.01462 5.0146 5.01463 5.01457
5.01465 5.01462 5.01463 5.01465 5.01458 5.01462 5.01458 5.01465
5.01457 5.01463 5.01459 5.01463 5.01464 5.01459 5.01458 5.01459
5.01465 5.01459 5.01461 5.01464 5.01461 5.01458 5.01462 5.01458
5.01463 5.01469 5.01459 5.01464 … 5.01459 5.01464 5.01465 5.01462
5.01465 5.01453 5.01465 5.01462 5.01467 5.01461 5.01462 5.01459
5.01461 5.01462 5.01465 5.01461 5.01458 5.01465 5.01458 5.01465
5.01458 5.01466 5.01463 5.01465 5.01461 5.01458 5.01463 5.01463
5.0146 5.01456 5.01466 5.0146 5.01462 5.01463 5.01461 5.01464
5.01462 5.01469 5.01458 5.01462 … 5.01465 5.01459 5.01468 5.0146
5.01469 5.01461 5.01463 5.01462 5.01463 5.01466 5.0147 5.01462
5.01459 5.01467 5.01463 5.01463 5.0146 5.01462 5.01456 5.01463
5.01465 5.01463 5.01465 5.01463 5.0146 5.01464 5.01462 5.01461
5.01469 5.01465 5.01461 5.01462 5.01466 5.01464 5.01468 5.01466
5.01466 5.01465 5.0146 5.01466 … 5.01459 5.01461 5.0146 5.01466
[:, :, 1, 24] =
2.105 2.10497 2.105 2.10507 … 2.10553 2.10535 2.1052 2.10508
2.1051 2.10506 2.1051 2.10519 2.1058 2.10556 2.10535 2.10519
2.10519 2.10515 2.10519 2.10531 2.10613 2.10579 2.10551 2.10531
2.10528 2.10523 2.10528 2.10542 2.10648 2.10601 2.10566 2.10542
2.10533 2.10528 2.10533 2.10548 2.10676 2.10616 2.10575 2.10548
2.10533 2.10528 2.10533 2.10549 … 2.10681 2.10619 2.10577 2.10549
2.1053 2.10525 2.1053 2.10545 2.1066 2.10608 2.1057 2.10545
2.10523 2.10519 2.10523 2.10536 2.10627 2.10588 2.10557 2.10536
2.10514 2.1051 2.10514 2.10524 2.10593 2.10565 2.10541 2.10524
2.10504 2.10501 2.10504 2.10512 2.10563 2.10543 2.10526 2.10512
2.10495 2.10493 2.10495 2.10501 … 2.1054 2.10525 2.10512 2.10501
2.10488 2.10486 2.10488 2.10493 2.10523 2.10512 2.10502 2.10493
2.10484 2.10482 2.10484 2.10489 2.10514 2.10505 2.10496 2.10489
2.10483 2.10481 2.10483 2.10488 2.10513 2.10504 2.10495 2.10488
2.10486 2.10484 2.10486 2.10491 2.10519 2.10508 2.10499 2.10491
2.10492 2.1049 2.10492 2.10498 … 2.10532 2.10519 2.10507 2.10498
[:, :, 2, 24] =
3.65567 3.65568 3.65567 3.65567 … 3.65565 3.65566 3.65566 3.65567
3.65567 3.65567 3.65567 3.65566 3.65564 3.65565 3.65566 3.65566
3.65566 3.65566 3.65566 3.65566 3.65562 3.65564 3.65565 3.65566
3.65566 3.65566 3.65566 3.65565 3.65561 3.65563 3.65564 3.65565
3.65565 3.65566 3.65565 3.65565 3.65561 3.65562 3.65564 3.65565
3.65565 3.65565 3.65565 3.65565 … 3.65561 3.65562 3.65564 3.65565
3.65565 3.65566 3.65565 3.65565 3.65561 3.65562 3.65564 3.65565
3.65566 3.65566 3.65566 3.65565 3.65562 3.65563 3.65564 3.65565
3.65566 3.65567 3.65567 3.65566 3.65563 3.65564 3.65565 3.65566
3.65567 3.65567 3.65567 3.65567 3.65564 3.65565 3.65566 3.65567
3.65568 3.65568 3.65568 3.65567 … 3.65565 3.65566 3.65567 3.65567
3.65568 3.65568 3.65568 3.65568 3.65566 3.65567 3.65567 3.65568
3.65569 3.65569 3.65569 3.65568 3.65566 3.65567 3.65568 3.65568
3.65569 3.65569 3.65569 3.65568 3.65567 3.65567 3.65568 3.65568
3.65568 3.65569 3.65568 3.65568 3.65566 3.65567 3.65568 3.65568
3.65568 3.65568 3.65568 3.65568 … 3.65566 3.65566 3.65567 3.65568
Visualizing Mean Concentration Over Time
We can now use this code for training our UDE, and generating time-series plots of the concentrations of species of U and V using the code:
using Plots, Statistics
# Compute average concentration at each timestep
avg_U = [mean(snapshot[:, :, 1]) for snapshot in sol_truth.u]
avg_V = [mean(snapshot[:, :, 2]) for snapshot in sol_truth.u]
# Plot average concentrations over time
plot(sol_truth.t, avg_U, label="Mean U", lw=2, xlabel="Time", ylabel="Concentration",
title="Mean Concentration of U and V Over Time")
plot!(sol_truth.t, avg_V, label="Mean V", lw=2, linestyle=:dash)
With the ground truth data generated and visualized, we are now ready to construct a Universal Differential Equation (UDE) by replacing the nonlinear term $U^2V$ with a neural network. The next section outlines how we define this hybrid model and train it to recover the reaction dynamics from data.
Universal Differential Equation (UDE) Formulation
In the original Brusselator model, the nonlinear reaction term ( U^2V ) governs key dynamic behavior. In our UDE approach, we replace this known term with a trainable neural network ( \mathcal{N}_\theta(U, V) ), where ( \theta ) are the learnable parameters.
The resulting system becomes:
$
\frac{\partial U}{\partial t} = 1 + \mathcal{N}_\theta(U, V) - 4.4U + \alpha \nabla^2 U + f(x, y, t) $
$
\frac{\partial V}{\partial t} = 3.4U - \mathcal{N}_\theta(U, V) + \alpha \nabla^2 V $
Here, $\mathcal{N}_\theta(U, V)$ is trained to approximate the true interaction term $U^2V$ using simulation data. This hybrid formulation allows us to recover unknown or partially known physical processes while preserving the known structural components of the PDE.
First, we have to define and configure the neural network that has to be used for the training. The implementation for that is as follows:
using Lux, Random, Optimization, OptimizationOptimJL, SciMLSensitivity, Zygote
model = Lux.Chain(Dense(2 => 16, tanh), Dense(16 => 1))
rng = Random.default_rng()
ps_init, st = Lux.setup(rng, model)
ps_init = ComponentArray(ps_init)
ComponentVector{Float32}(layer_1 = (weight = Float32[-0.51828104 -0.569505; 1.0009451 0.8462184; … ; 1.4430382 -1.0838026; 0.69011223 0.3295439], bias = Float32[0.4597295, -0.61192226, -0.5055269, 0.16723245, 0.0009303495, -0.36198193, -0.43082693, 0.5786203, -0.51828796, -0.60951185, -0.12052009, 0.16507632, 0.34016326, -0.03494252, -0.10871948, 0.3553313]), layer_2 = (weight = Float32[-0.16546121 -0.33607212 … 0.32296163 0.26847205], bias = Float32[-0.02632606]))
We use a simple fully connected neural network with one hidden layer of 16 tanh-activated units to approximate the nonlinear interaction term.
To ensure consistency between the ground truth simulation and the learned Universal Differential Equation (UDE) model, we preserve the same spatial discretization scheme used in the original ODEProblem. This includes:
- the finite difference Laplacian,
- periodic boundary conditions, and
- the external forcing function.
The only change lies in the replacement of the known nonlinear term $U^2V$ with a neural network approximation $\mathcal{N}_\theta(U, V)$. This design enables the UDE to learn complex or unknown dynamics from data while maintaining the underlying physical structure of the system.
The function below implements this hybrid formulation:
function pde_ude!(du, u, ps_nn, t)
αdx = alpha / dx^2
for I in CartesianIndices((N_GRID, N_GRID))
i, j = Tuple(I)
x, y = XYD[i], XYD[j]
ip1, im1 = limit(i+1, N_GRID), limit(i-1, N_GRID)
jp1, jm1 = limit(j+1, N_GRID), limit(j-1, N_GRID)
U, V = u[i,j,1], u[i,j,2]
ΔU = u[im1,j,1] + u[ip1,j,1] + u[i,jp1,1] + u[i,jm1,1] - 4f0 * U
ΔV = u[im1,j,2] + u[ip1,j,2] + u[i,jp1,2] + u[i,jm1,2] - 4f0 * V
nn_val, _ = model([U, V], ps_nn, st)
val = nn_val[1]
du[i,j,1] = αdx*ΔU + B + val - (A+1f0)*U + brusselator_f(x, y, t)
du[i,j,2] = αdx*ΔV + A*U - val
end
end
prob_ude_template = ODEProblem(pde_ude!, u0, tspan, ps_init)
ODEProblem with uType Array{Float32, 3} and tType Float32. In-place: true
Non-trivial mass matrix: false
timespan: (0.0f0, 11.5f0)
u0: 16×16×2 Array{Float32, 3}:
[:, :, 1] =
0.0 0.341461 0.864189 1.408 1.90251 … 1.408 0.864189 0.341461 0.0
0.0 0.341461 0.864189 1.408 1.90251 1.408 0.864189 0.341461 0.0
0.0 0.341461 0.864189 1.408 1.90251 1.408 0.864189 0.341461 0.0
0.0 0.341461 0.864189 1.408 1.90251 1.408 0.864189 0.341461 0.0
0.0 0.341461 0.864189 1.408 1.90251 1.408 0.864189 0.341461 0.0
0.0 0.341461 0.864189 1.408 1.90251 … 1.408 0.864189 0.341461 0.0
0.0 0.341461 0.864189 1.408 1.90251 1.408 0.864189 0.341461 0.0
0.0 0.341461 0.864189 1.408 1.90251 1.408 0.864189 0.341461 0.0
0.0 0.341461 0.864189 1.408 1.90251 1.408 0.864189 0.341461 0.0
0.0 0.341461 0.864189 1.408 1.90251 1.408 0.864189 0.341461 0.0
0.0 0.341461 0.864189 1.408 1.90251 … 1.408 0.864189 0.341461 0.0
0.0 0.341461 0.864189 1.408 1.90251 1.408 0.864189 0.341461 0.0
0.0 0.341461 0.864189 1.408 1.90251 1.408 0.864189 0.341461 0.0
0.0 0.341461 0.864189 1.408 1.90251 1.408 0.864189 0.341461 0.0
0.0 0.341461 0.864189 1.408 1.90251 1.408 0.864189 0.341461 0.0
0.0 0.341461 0.864189 1.408 1.90251 … 1.408 0.864189 0.341461 0.0
[:, :, 2] =
0.0 0.0 0.0 0.0 … 0.0 0.0 0.0
0.419066 0.419066 0.419066 0.419066 0.419066 0.419066 0.419066
1.0606 1.0606 1.0606 1.0606 1.0606 1.0606 1.0606
1.728 1.728 1.728 1.728 1.728 1.728 1.728
2.3349 2.3349 2.3349 2.3349 2.3349 2.3349 2.3349
2.82843 2.82843 2.82843 2.82843 … 2.82843 2.82843 2.82843
3.17454 3.17454 3.17454 3.17454 3.17454 3.17454 3.17454
3.35252 3.35252 3.35252 3.35252 3.35252 3.35252 3.35252
3.35252 3.35252 3.35252 3.35252 3.35252 3.35252 3.35252
3.17454 3.17454 3.17454 3.17454 3.17454 3.17454 3.17454
2.82843 2.82843 2.82843 2.82843 … 2.82843 2.82843 2.82843
2.3349 2.3349 2.3349 2.3349 2.3349 2.3349 2.3349
1.728 1.728 1.728 1.728 1.728 1.728 1.728
1.0606 1.0606 1.0606 1.0606 1.0606 1.0606 1.0606
0.419066 0.419066 0.419066 0.419066 0.419066 0.419066 0.419066
0.0 0.0 0.0 0.0 … 0.0 0.0 0.0
Loss Function and Optimization
To train the neural network $\mathcal{N}_\theta(U, V)$ embedded in the UDE, we define a loss function that measures how closely the solution of the UDE matches the ground truth data generated earlier.
The loss is computed as the sum of squared errors between the predicted solution from the UDE and the true solution at each saved time point. If the solver fails (e.g., due to numerical instability or incorrect parameters), we return an infinite loss to discard that configuration during optimization. We use FBDF()
as the solver due to the stiff nature of the brusselators euqation. Other solvers like KenCarp47()
could also be used.
To efficiently compute gradients of the loss with respect to the neural network parameters, we use an adjoint sensitivity method (GaussAdjoint
), which performs high-accuracy quadrature-based integration of the adjoint equations. This approach enables scalable and memory-efficient training for stiff PDEs by avoiding full trajectory storage while maintaining accurate gradient estimates.
The loss function and initial evaluation are implemented as follows:
println("[Loss] Defining loss function...")
function loss_fn(ps, _)
prob = remake(prob_ude_template, p=ps)
sol = solve(prob, FBDF(), saveat=t_points)
# Failed solve
if !SciMLBase.successful_retcode(sol)
return Inf32
end
pred = Array(sol)
lval = sum(abs2, pred .- u_true) / length(u_true)
return lval
end
loss_fn (generic function with 1 method)
Once the loss function is defined, we use the ADAM optimizer to train the neural network. The optimization problem is defined using SciML's Optimization.jl
tools, and gradients are computed via automatic differentiation using AutoZygote()
from SciMLSensitivity
:
println("[Training] Starting optimization...")
using OptimizationOptimisers
optf = OptimizationFunction(loss_fn, AutoZygote())
optprob = OptimizationProblem(optf, ps_init)
loss_history = Float32[]
callback = (ps, l) -> begin
push!(loss_history, l)
println("Epoch $(length(loss_history)): Loss = $l")
false
end
#5 (generic function with 1 method)
Finally to run everything:
res = solve(optprob, Optimisers.Adam(0.01), callback=callback, maxiters=100)
retcode: Default
u: ComponentVector{Float32}(layer_1 = (weight = Float32[-0.57625866 -0.7303992; 0.9511621 0.72524095; … ; 1.4660717 -1.0316354; 0.8443421 0.5216115], bias = Float32[0.33522263, -0.695657, -0.19291273, 0.048209168, -0.042032395, -0.09816901, -0.27681547, 0.09919875, -0.65861756, -0.43743527, -0.33108848, 0.2989078, 0.6823301, 0.871544, -0.072246745, 0.53646594]), layer_2 = (weight = Float32[-0.3881507 -0.109445654 … 0.065602936 0.49085072], bias = Float32[0.19559306]))
res.objective
1.0755547f0
println("[Plot] Final U/V comparison plots...")
center = N_GRID ÷ 2
sol_final = solve(remake(prob_ude_template, p=res.u), FBDF(), saveat=t_points)
pred = Array(sol_final)
p1 = plot(t_points, u_true[center,center,1,:], lw=2, label="U True")
plot!(p1, t_points, pred[center,center,1,:], lw=2, ls=:dash, label="U Pred")
title!(p1, "Center U Concentration Over Time")
p2 = plot(t_points, u_true[center,center,2,:], lw=2, label="V True")
plot!(p2, t_points, pred[center,center,2,:], lw=2, ls=:dash, label="V Pred")
title!(p2, "Center V Concentration Over Time")
plot(p1, p2, layout=(1,2), size=(900,400))
Results and Conclusion
After training the Universal Differential Equation (UDE), we compared the predicted dynamics to the ground truth for both chemical species.
The low training loss shows us that the neural network in the UDE was able to understand the underlying dynamics, and it was able to learn the $U^2V$ term in the partial differential equation.