Using GPUs to train Physics-Informed Neural Networks (PINNs)

In this tutorial, we will solve the 2-dimensional PDE:

\[∂_t u(x, y, t) = ∂^2_x u(x, y, t) + ∂^2_y u(x, y, t) \, ,\]

with the initial and boundary conditions:

\[\begin{align*} u(x, y, 0) &= e^{x+y} \cos(x + y) \, ,\\ u(0, y, t) &= e^{y} \cos(y + 4t) \, ,\\ u(2, y, t) &= e^{2+y} \cos(2 + y + 4t) \, ,\\ u(x, 0, t) &= e^{x} \cos(x + 4t) \, ,\\ u(x, 2, t) &= e^{x+2} \cos(x + 2 + 4t) \, , \end{align*}\]

on the space and time domain:

\[x \in [0, 2] \, ,\ y \in [0, 2] \, , \ t \in [0, 2] \, ,\]
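These initial and boundary conditions all come from the analytic solution \(u(x, y, t) = e^{x+y} \cos(x + y + 4t)\) (used as `analytic_sol_func` below), which we can quickly verify satisfies the PDE:

\[\begin{align*} \partial_t u &= -4 e^{x+y} \sin(x + y + 4t) \, ,\\ \partial^2_x u &= \partial_x \left( e^{x+y} \left[ \cos(x + y + 4t) - \sin(x + y + 4t) \right] \right) = -2 e^{x+y} \sin(x + y + 4t) \, ,\\ \partial^2_y u &= -2 e^{x+y} \sin(x + y + 4t) \, , \end{align*}\]

so that \(\partial^2_x u + \partial^2_y u = \partial_t u\) as required.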

with physics-informed neural networks. The only major difference from the CPU case is that we must ensure that the initial parameters of the neural network are on the GPU; once they are, all of the internal computations take place on the GPU as well. This is done by moving the initial parameters over with the device object returned by gpu_device, like:

using Lux, LuxCUDA, ComponentArrays, Random
const gpud = gpu_device()
inner = 25
chain = Chain(Dense(3, inner, σ), Dense(inner, inner, σ), Dense(inner, inner, σ),
    Dense(inner, inner, σ), Dense(inner, 1))
ps = Lux.setup(Random.default_rng(), chain)[1]
ps = ps |> ComponentArray |> gpud .|> Float64
ComponentVector{Float64, CuArray{Float64, 1, CUDA.DeviceMemory}, ...}(layer_1 = (weight = [...], bias = [...]), layer_2 = (weight = [...], bias = [...]), layer_3 = (weight = [...], bias = [...]), layer_4 = (weight = [...], bias = [...]), layer_5 = (weight = [...], bias = [...]))
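To confirm that the parameters are actually resident on the GPU, one can inspect the storage underlying the ComponentVector (a quick sanity check; `getdata` is the ComponentArrays accessor for the wrapped flat vector):

```julia
using ComponentArrays, CUDA

# The flattened parameter vector should be backed by a CuArray, which is
# what makes all downstream loss and gradient computations run on the GPU.
@assert ComponentArrays.getdata(ps) isa CuArray{Float64, 1}
```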

In total, this looks like:

using NeuralPDE, Lux, LuxCUDA, Random, ComponentArrays
using Optimization
using OptimizationOptimisers
import ModelingToolkit: Interval, infimum, supremum
using Plots
using Printf

@parameters t x y
@variables u(..)
Dxx = Differential(x)^2
Dyy = Differential(y)^2
Dt = Differential(t)
t_min = 0.0
t_max = 2.0
x_min = 0.0
x_max = 2.0
y_min = 0.0
y_max = 2.0

# 2D PDE
eq = Dt(u(t, x, y)) ~ Dxx(u(t, x, y)) + Dyy(u(t, x, y))

analytic_sol_func(t, x, y) = exp(x + y) * cos(x + y + 4t)
# Initial and boundary conditions
bcs = [u(t_min, x, y) ~ analytic_sol_func(t_min, x, y),
    u(t, x_min, y) ~ analytic_sol_func(t, x_min, y),
    u(t, x_max, y) ~ analytic_sol_func(t, x_max, y),
    u(t, x, y_min) ~ analytic_sol_func(t, x, y_min),
    u(t, x, y_max) ~ analytic_sol_func(t, x, y_max)]

# Space and time domains
domains = [t ∈ Interval(t_min, t_max),
    x ∈ Interval(x_min, x_max),
    y ∈ Interval(y_min, y_max)]

# Neural network
inner = 25
chain = Chain(Dense(3, inner, σ), Dense(inner, inner, σ), Dense(inner, inner, σ),
    Dense(inner, inner, σ), Dense(inner, 1))

strategy = QuasiRandomTraining(100)
const gpud = gpu_device()
ps = Lux.setup(Random.default_rng(), chain)[1]
ps = ps |> ComponentArray |> gpud .|> Float64
discretization = PhysicsInformedNN(chain, strategy; init_params = ps)

@named pde_system = PDESystem(eq, bcs, domains, [t, x, y], [u(t, x, y)])
prob = discretize(pde_system, discretization)
symprob = symbolic_discretize(pde_system, discretization)

callback = function (p, l)
    println("Current loss is: $l")
    return false
end

res = Optimization.solve(
    prob, OptimizationOptimisers.Adam(1e-2); callback = callback, maxiters = 2500)
retcode: Default
u: ComponentVector{Float64, CuArray{Float64, 1, CUDA.DeviceMemory}, ...}(layer_1 = (weight = [...], bias = [...]), layer_2 = (weight = [...], bias = [...]), layer_3 = (weight = [...], bias = [...]), layer_4 = (weight = [...], bias = [...]), layer_5 = (weight = [...], bias = [...]))

We then use the remake function to rebuild the PDE problem to start a new optimization at the optimized parameters, and continue with a lower learning rate:

prob = remake(prob, u0 = res.u)
res = Optimization.solve(
    prob, OptimizationOptimisers.Adam(1e-3); callback = callback, maxiters = 2500)
retcode: Default
u: ComponentVector{Float64, CuArray{Float64, 1, CUDA.DeviceMemory}, ...}(layer_1 = (weight = [...], bias = [...]), layer_2 = (weight = [...], bias = [...]), layer_3 = (weight = [...], bias = [...]), layer_4 = (weight = [...], bias = [...]), layer_5 = (weight = [...], bias = [...]))

Finally, we inspect the solution:

phi = discretization.phi
ts, xs, ys = [infimum(d.domain):0.1:supremum(d.domain) for d in domains]
u_real = [analytic_sol_func(t, x, y) for t in ts for x in xs for y in ys]
u_predict = [first(Array(phi([t, x, y], res.u))) for t in ts for x in xs for y in ys]
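Beyond the pointwise comparison, a single summary number is often useful. A minimal sketch of a relative error over the sampled grid, using the `u_real` and `u_predict` vectors computed above:

```julia
# Relative ℓ2 error of the PINN prediction over the sampled (t, x, y) grid
rel_error = sqrt(sum(abs2, u_predict .- u_real) / sum(abs2, u_real))
println("Relative L2 error: ", rel_error)
```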

function plot_(res)
    # Animate
    anim = @animate for (i, t) in enumerate(0:0.05:t_max)
        @info "Animating frame $i..."
        u_real = reshape([analytic_sol_func(t, x, y) for x in xs for y in ys],
            (length(xs), length(ys)))
        u_predict = reshape([Array(phi([t, x, y], res.u))[1] for x in xs for y in ys],
            length(xs), length(ys))
        u_error = abs.(u_predict .- u_real)
        title = @sprintf("predict, t = %.3f", t)
        p1 = plot(xs, ys, u_predict, st = :surface, label = "", title = title)
        title = "real"
        p2 = plot(xs, ys, u_real, st = :surface, label = "", title = title)
        title = "error"
        p3 = plot(xs, ys, u_error, st = :contourf, label = "", title = title)
        plot(p1, p2, p3)
    end
    gif(anim, "3pde.gif", fps = 10)
end

plot_(res)

Performance benchmarks

Here are some performance benchmarks for the 2D PDE above with varying numbers of training points and neurons in the hidden layers, measuring the time for 100 iterations and comparing GPU against CPU runtimes.
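A benchmark of this kind can be sketched as follows; the helper below builds a fresh discretization for a given hidden-layer width and training-point count, then times 100 Adam iterations. The `run_benchmark` function is illustrative, not part of the tutorial:

```julia
using NeuralPDE, Lux, LuxCUDA, ComponentArrays, Random
using Optimization, OptimizationOptimisers

# Illustrative helper: time 100 Adam iterations for a network with `inner`
# hidden neurons and `npoints` quasi-random training points, on GPU or CPU.
function run_benchmark(pde_system, inner, npoints; on_gpu = true)
    chain = Chain(Dense(3, inner, σ), Dense(inner, inner, σ), Dense(inner, 1))
    ps = Lux.setup(Random.default_rng(), chain)[1] |> ComponentArray
    if on_gpu
        ps = ps |> gpu_device() .|> Float64
    end
    discretization = PhysicsInformedNN(chain, QuasiRandomTraining(npoints);
        init_params = ps)
    prob = discretize(pde_system, discretization)
    @elapsed Optimization.solve(prob, OptimizationOptimisers.Adam(1e-2);
        maxiters = 100)
end
```

Comparing `run_benchmark(pde_system, 25, 100; on_gpu = true)` against `on_gpu = false` over a range of sizes gives the CPU/GPU comparison. Note that the first call includes compilation time, so each configuration should be run once as a warm-up before timing.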

julia> CUDA.device()
