Imposing Constraints on Physics-Informed Neural Network (PINN) Solutions
Let's consider the stationary Fokker-Planck equation:
\[- \frac{\partial}{\partial x} \left[ \left( \alpha x - \beta x^3 \right) p(x) \right] + \frac{\sigma^2}{2} \frac{\partial^2}{\partial x^2} p(x) = 0 \, ,\]
whose solution, being a probability density, must satisfy the normalization condition:
\[\int_{-\infty}^{\infty} p(x) \, dx = 1 \, ,\]
together with the boundary conditions:
\[p(-2.2) = p(2.2) = 0 \, .\]
We solve this boundary value problem with Physics-Informed Neural Networks, enforcing the normalization condition through an additional loss term.
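For reference, this stationary equation can be integrated once in closed form. Taking the integration constant, the probability current \((\alpha x - \beta x^3) p - \tfrac{\sigma^2}{2} p'\), to be zero gives \(p'(x)/p(x) = \tfrac{2}{\sigma^2} \left( \alpha x - \beta x^3 \right)\), and hence
\[p(x) = C \exp\left( \frac{1}{2 \sigma^2} \left( 2 \alpha x^2 - \beta x^4 \right) \right) \, ,\]
with the constant \(C\) fixed by the normalization condition. This expression is used in the analysis at the end to compare against the network prediction.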
using NeuralPDE, Lux, ModelingToolkit, Optimization, OptimizationOptimJL, LineSearches
using Integrals, Cubature
using ModelingToolkit: Interval, infimum, supremum
# the example is taken from this article https://arxiv.org/abs/1910.10503
@parameters x
@variables p(..)
Dx = Differential(x)
Dxx = Differential(x)^2
# Model and domain parameters
α = 0.3
β = 0.5
_σ = 0.5
x_0 = -2.2
x_end = 2.2
eq = Dx((α * x - β * x^3) * p(x)) ~ (_σ^2 / 2) * Dxx(p(x))
# Initial and boundary conditions
bcs = [p(x_0) ~ 0.0, p(x_end) ~ 0.0]
# Space and time domains
domains = [x ∈ Interval(x_0, x_end)]
# Neural network
inn = 18  # number of neurons per hidden layer
chain = Lux.Chain(Dense(1, inn, Lux.σ),
    Dense(inn, inn, Lux.σ),
    Dense(inn, inn, Lux.σ),
    Dense(inn, 1))
# Bounds of the quadrature used in the additional (normalization) loss
lb = [x_0]
ub = [x_end]
function norm_loss_function(phi, θ, p)
    # Integrand of the constraint: the scaled network output minus 1
    function inner_f(x, θ)
        0.01 * phi(x, θ) .- 1
    end
    prob = IntegralProblem(inner_f, lb, ub, θ)
    norm2 = solve(prob, HCubatureJL(), reltol = 1e-8, abstol = 1e-8, maxiters = 10)
    abs(norm2[1])
end
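In terms of the network output \(p_\theta\), the additional loss defined above is
\[L_{\text{norm}}(\theta) = \left| \int_{-2.2}^{2.2} \left( 0.01 \, p_\theta(x) - 1 \right) dx \right| \, .\]
Driving it to zero fixes the overall scale of the solution: the PDE and the homogeneous boundary conditions alone determine \(p\) only up to a constant factor, and without the extra term the trivial solution \(p \equiv 0\) would already minimize the PDE and boundary losses.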
discretization = PhysicsInformedNN(chain,
    QuadratureTraining(),
    additional_loss = norm_loss_function)
@named pdesystem = PDESystem(eq, bcs, domains, [x], [p(x)])
prob = discretize(pdesystem, discretization)
phi = discretization.phi
sym_prob = NeuralPDE.symbolic_discretize(pdesystem, discretization)
pde_inner_loss_functions = sym_prob.loss_functions.pde_loss_functions
bcs_inner_loss_functions = sym_prob.loss_functions.bc_loss_functions
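As an optional sanity check (an addition, not part of the original example), these sub-loss functions can be evaluated at the untrained parameters stored in prob.u0 to get a baseline before optimization:
# Optional baseline check (assumption: prob.u0 holds the initial network parameters,
# the same object the callback later receives as p.u)
println("initial pde losses: ", map(l_ -> l_(prob.u0), pde_inner_loss_functions))
println("initial bc losses: ", map(l_ -> l_(prob.u0), bcs_inner_loss_functions))
println("initial additional loss: ", norm_loss_function(phi, prob.u0, nothing))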
# Callback printing the individual loss components during training
cb_ = function (p, l)
    println("loss: ", l)
    println("pde_losses: ", map(l_ -> l_(p.u), pde_inner_loss_functions))
    println("bcs_losses: ", map(l_ -> l_(p.u), bcs_inner_loss_functions))
    println("additional_loss: ", norm_loss_function(phi, p.u, nothing))
    return false
end
res = Optimization.solve(
    prob, BFGS(linesearch = BackTracking()), callback = cb_, maxiters = 600)
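If the printed losses are still high after these 600 iterations, a further refinement run can be started from the trained parameters. This continuation step is an addition to the original example and assumes remake is in scope:
# Optional continuation: rebuild the OptimizationProblem with the trained parameters
# as the new initial guess and run a few more BFGS iterations.
# (remake comes from SciMLBase; add "using SciMLBase: remake" if it is not re-exported
# by the packages loaded above.)
prob = remake(prob, u0 = res.u)
res = Optimization.solve(
    prob, BFGS(linesearch = BackTracking()), callback = cb_, maxiters = 200)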
(Solver output omitted: the returned result reports a retcode and the full ComponentVector of trained weights, accessed below as res.u.)
And some analysis:
using Plots
C = 142.88418699042 #fitting param
analytic_sol_func(x) = C * exp((1 / (2 * _σ^2)) * (2 * α * x^2 - β * x^4))
xs = [infimum(d.domain):0.01:supremum(d.domain) for d in domains][1]
u_real = [analytic_sol_func(x) for x in xs]
u_predict = [first(phi(x, res.u)) for x in xs]
plot(xs, u_real, label = "analytic")
plot!(xs, u_predict, label = "predict")
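Finally, as an extra check not in the original example, the constraint enforced by norm_loss_function can be verified on the plotting grid with a simple Riemann sum: at zero additional loss, 0.01 times the integral of the predicted density over the domain should equal the domain length of 4.4.
# Riemann-sum check of the normalization constraint on the plotting grid
# (step size 0.01 matches the grid defined above; integral_predict is a new helper name)
integral_predict = sum(u_predict) * 0.01
println("0.01 * ∫ p dx ≈ ", 0.01 * integral_predict, " (target: ", x_end - x_0, ")")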