GPU-Accelerated Physics-Informed Neural Network (PINN) PDE Solvers

Machine learning is all the rage. Everybody thinks physics is cool.

Therefore, using machine learning to solve physics equations? 🧠💥

So let's be cool and use a physics-informed neural network (PINN) to solve the Heat Equation. Let's be even cooler by using GPUs (ironically, creating even more heat, but it's the heat equation so that's cool).

Step 1: Import Libraries

To solve PDEs using neural networks, we will use the NeuralPDE.jl package. This package takes ModelingToolkit's symbolic PDESystem as input and generates an Optimization.jl OptimizationProblem which, when solved, yields the weights of a neural network that satisfies the PDE. In other words, the trained neural network is itself the solution to the PDE. Thus our packages look like:

# High Level Interface
using NeuralPDE
import ModelingToolkit: Interval

# Optimization Libraries
using Optimization, OptimizationOptimisers

# Machine Learning Libraries and Helpers
using Lux, LuxCUDA, ComponentArrays
const gpud = gpu_device() # get a GPU device object

# Standard Libraries
using Printf, Random

# Plotting
using Plots

Problem Setup

Let's solve the 2+1-dimensional Heat Equation. This is the PDE:

\[∂_t u(x, y, t) = ∂^2_x u(x, y, t) + ∂^2_y u(x, y, t) \, ,\]

with the initial and boundary conditions:

\[\begin{align*} u(x, y, 0) &= e^{x+y} \cos(x + y) \, ,\\ u(0, y, t) &= e^{y} \cos(y + 4t) \, ,\\ u(2, y, t) &= e^{2+y} \cos(2 + y + 4t) \, ,\\ u(x, 0, t) &= e^{x} \cos(x + 4t) \, ,\\ u(x, 2, t) &= e^{x+2} \cos(x + 2 + 4t) \, , \end{align*}\]

on the space and time domain:

\[x \in [0, 2] \, ,\ y \in [0, 2] \, , \ t \in [0, 2] \, ,\]

with physics-informed neural networks.
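
The initial and boundary conditions above are all generated from the analytic solution u(t, x, y) = e^{x+y} cos(x + y + 4t). As a quick sanity check (a pure-Julia sketch, not part of the original tutorial), we can verify numerically via central finite differences that this function really satisfies the PDE at a sample interior point:

```julia
# Numerical check that u(t,x,y) = e^(x+y) cos(x+y+4t) satisfies u_t = u_xx + u_yy
u(t, x, y) = exp(x + y) * cos(x + y + 4t)

h = 1e-4
t0, x0, y0 = 0.7, 1.1, 0.4
u_t  = (u(t0 + h, x0, y0) - u(t0 - h, x0, y0)) / (2h)          # central difference in t
u_xx = (u(t0, x0 + h, y0) - 2u(t0, x0, y0) + u(t0, x0 - h, y0)) / h^2
u_yy = (u(t0, x0, y0 + h) - 2u(t0, x0, y0) + u(t0, x0, y0 - h)) / h^2
residual = abs(u_t - (u_xx + u_yy))  # should be ≈ 0 up to discretization error
```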

Step 2: Define the PDESystem

First, let's use ModelingToolkit's PDESystem to represent the PDE. To do this, we essentially transcribe the PDE definition directly into Julia code. This looks like:

@parameters t x y
@variables u(..)
Dxx = Differential(x)^2
Dyy = Differential(y)^2
Dt = Differential(t)
t_min = 0.0
t_max = 2.0
x_min = 0.0
x_max = 2.0
y_min = 0.0
y_max = 2.0

# 2D PDE
eq = Dt(u(t, x, y)) ~ Dxx(u(t, x, y)) + Dyy(u(t, x, y))

# Analytic solution, used to generate the initial and boundary conditions
analytic_sol_func(t, x, y) = exp(x + y) * cos(x + y + 4t)

# Initial and boundary conditions
bcs = [u(t_min, x, y) ~ analytic_sol_func(t_min, x, y),
    u(t, x_min, y) ~ analytic_sol_func(t, x_min, y),
    u(t, x_max, y) ~ analytic_sol_func(t, x_max, y),
    u(t, x, y_min) ~ analytic_sol_func(t, x, y_min),
    u(t, x, y_max) ~ analytic_sol_func(t, x, y_max)]

# Space and time domains
domains = [t ∈ Interval(t_min, t_max),
    x ∈ Interval(x_min, x_max),
    y ∈ Interval(y_min, y_max)]

@named pde_system = PDESystem(eq, bcs, domains, [t, x, y], [u(t, x, y)])

\[ \begin{align} \frac{\mathrm{d}}{\mathrm{d}t} u\left( t, x, y \right) &= \frac{\mathrm{d}}{\mathrm{d}y} \frac{\mathrm{d}}{\mathrm{d}y} u\left( t, x, y \right) + \frac{\mathrm{d}}{\mathrm{d}x} \frac{\mathrm{d}}{\mathrm{d}x} u\left( t, x, y \right) \end{align} \]

Note

We used the wildcard form of the variable definition @variables u(..) which then requires that we always specify what the dependent variables of u are. This is because in the boundary conditions we change from using u(t,x,y) to more specific points and lines, like u(t,x_max,y).

Step 3: Define the Lux Neural Network

Now let's define the neural network that will act as our solution. We will use a simple multi-layer perceptron, like:

using Lux
inner = 25
chain = Chain(Dense(3, inner, Lux.σ),
    Dense(inner, inner, Lux.σ),
    Dense(inner, inner, Lux.σ),
    Dense(inner, inner, Lux.σ),
    Dense(inner, 1))
ps = Lux.setup(Random.default_rng(), chain)[1]
(layer_1 = (weight = Float32[-0.6903188 0.79514754 -0.11460304; …], bias = Float32[0.46405852, …]), layer_2 = …, layer_3 = …, layer_4 = …, layer_5 = (weight = Float32[-0.25599185 0.28132337 … 0.015142494 -0.017154485], bias = Float32[-0.12392847]))
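
As a quick sanity check on the network size (a sketch, not in the original tutorial): each Dense(in, out) layer has in × out weights plus out biases, so the five layers above carry 2076 parameters in total, which matches the length of the flattened ComponentVector used below:

```julia
# Parameter count for the MLP: Dense(in, out) has in*out weights + out biases
layer_sizes = [(3, 25), (25, 25), (25, 25), (25, 25), (25, 1)]
nparams = sum(i * o + o for (i, o) in layer_sizes)  # total trainable parameters
```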

Step 4: Place it on the GPU

Just plop it on that sucker. We must ensure that the initial parameters for the neural network are on the GPU; once they are, the internal computations will all take place on the GPU as well. This is done by applying gpud (the GPU device object we created at the start) to the initial parameters, like:

ps = ps |> ComponentArray |> gpud .|> Float64
ComponentVector{Float64, CUDA.CuArray{Float64, 1, CUDA.DeviceMemory}, …}(layer_1 = (weight = [-0.6903188228607178 …], bias = […]), layer_2 = …, layer_3 = …, layer_4 = …, layer_5 = (weight = […], bias = [-0.12392847239971161]))

Step 5: Discretize the PDE via a PINN Training Strategy

We now discretize the PDE system with the PhysicsInformedNN discretizer, using GridTraining with a grid spacing of 0.05 to sample the training points, and pass our GPU-resident parameters as the initial parameters:

strategy = GridTraining(0.05)
discretization = PhysicsInformedNN(chain,
    strategy,
    init_params = ps)
prob = discretize(pde_system, discretization)
OptimizationProblem. In-place: true
u0: ComponentVector{Float64, CUDA.CuArray{Float64, 1, CUDA.DeviceMemory}, …}(layer_1 = …, layer_2 = …, layer_3 = …, layer_4 = …, layer_5 = …)
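
To get a feel for the cost of this strategy (a back-of-the-envelope sketch, not from the original tutorial): GridTraining(0.05) samples the loss on a uniform grid over the domain, and with t, x, y ∈ [0, 2] and spacing 0.05 that is 41 points per dimension, i.e. on the order of 41³ ≈ 69,000 collocation points per loss evaluation:

```julia
# Number of grid points implied by GridTraining(0.05) on the [0, 2]^3 domain
pts_per_dim = length(0.0:0.05:2.0)  # 41 points per dimension
total_pts = pts_per_dim^3           # total collocation points
```

This large, regular batch of points is exactly the kind of workload GPUs handle well.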

Step 6: Solve the Optimization Problem

We now solve the optimization problem with the Adam optimizer, using a callback to monitor the loss:

callback = function (state, l)
    println("Current loss is: $l")
    return false
end
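
The callback above prints the loss at every iteration, which floods the output over 2500 iterations. A throttled variant (a hypothetical sketch, not part of the original tutorial) prints only every 100 iterations:

```julia
# Hypothetical throttled callback: print the loss only every 100 iterations
iter_count = Ref(0)
throttled_callback = function (state, l)
    iter_count[] += 1
    iter_count[] % 100 == 0 && println("Iteration $(iter_count[]): loss = $l")
    return false  # returning `true` would halt the optimizer early
end
```

Pass callback = throttled_callback to Optimization.solve to use it in place of the verbose one.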

res = Optimization.solve(prob, Adam(0.01); callback = callback, maxiters = 2500);
retcode: Default
u: ComponentVector{Float64, CUDA.CuArray{Float64, 1, CUDA.DeviceMemory}, …}(layer_1 = …, layer_2 = …, layer_3 = …, layer_4 = …, layer_5 = …)

We then use the remake function to rebuild the PDE problem to start a new optimization at the optimized parameters, and continue with a lower learning rate:

prob = remake(prob, u0 = res.u)
res = Optimization.solve(prob, Adam(0.001); callback = callback, maxiters = 2500);
retcode: Default
u: ComponentVector{Float64, CUDA.CuArray{Float64, 1, CUDA.DeviceMemory}, …}(layer_1 = …, layer_2 = …, layer_3 = …, layer_4 = …, layer_5 = …)

Step 7: Inspect the PINN's Solution

Finally, we inspect the solution:

phi = discretization.phi
ts, xs, ys = [infimum(d.domain):0.1:supremum(d.domain) for d in domains]
u_real = [analytic_sol_func(t, x, y) for t in ts for x in xs for y in ys]
u_predict = [first(Array(phi([t, x, y], res.u))) for t in ts for x in xs for y in ys]
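
Beyond eyeballing plots, a single scalar summary of accuracy is useful. A relative L2 error between the prediction and the analytic solution (a sketch; apply it to the u_predict and u_real vectors computed above) can be defined as:

```julia
using LinearAlgebra

# Relative L2 error: ||pred - real|| / ||real||
rel_l2_error(pred, real) = norm(pred .- real) / norm(real)
```

For example, rel_l2_error(u_predict, u_real) gives the relative error over the whole space-time grid.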

function plot_(res)
    # Animate
    anim = @animate for (i, t) in enumerate(0:0.05:t_max)
        @info "Animating frame $i..."
        u_real = reshape([analytic_sol_func(t, x, y) for x in xs for y in ys],
            (length(xs), length(ys)))
        u_predict = reshape([Array(phi([t, x, y], res.u))[1] for x in xs for y in ys],
            length(xs), length(ys))
        u_error = abs.(u_predict .- u_real)
        title = @sprintf("predict, t = %.3f", t)
        p1 = plot(xs, ys, u_predict, st = :surface, label = "", title = title)
        title = @sprintf("real")
        p2 = plot(xs, ys, u_real, st = :surface, label = "", title = title)
        title = @sprintf("error")
        p3 = plot(xs, ys, u_error, st = :contourf, label = "", title = title)
        plot(p1, p2, p3)
    end
    gif(anim, "3pde.gif", fps = 10)
end

plot_(res)

(Animation: 3pde.gif, showing the predicted solution, the analytic solution, and the pointwise error over time.)