Optimizing Parameters (Solving Inverse Problems) with Physics-Informed Neural Networks (PINNs)

Consider the Lorenz system,

\[\begin{align*} \frac{\mathrm{d} x}{\mathrm{d}t} &= \sigma (y -x) \, ,\\ \frac{\mathrm{d} y}{\mathrm{d}t} &= x (\rho - z) - y \, ,\\ \frac{\mathrm{d} z}{\mathrm{d}t} &= x y - \beta z \, , \end{align*}\]

which we solve with physics-informed neural networks. This time, however, the parameters $\sigma$, $\beta$, and $\rho$ are unknown, and we want to estimate them from data (an inverse problem).
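Schematically, the inverse problem is a joint minimization over the network weights $\theta$ and the unknown parameters $\sigma$, $\rho$, $\beta$. Writing $\hat x, \hat y, \hat z$ for the network outputs, the objective combines the residual of the ODEs, the initial conditions, and a data-mismatch term $\mathcal{L}_{\text{data}}$. For simplicity the residual is written as an average over $N$ collocation points $t_k$; the QuadratureTraining strategy used below integrates the squared residual over the domain instead, and NeuralPDE may weight the individual terms, so treat this as a sketch of the idea rather than the library's exact loss.

\[\begin{align*} \mathcal{L}(\theta, \sigma, \rho, \beta) &= \mathcal{L}_{\text{ODE}}(\theta, \sigma, \rho, \beta) + \mathcal{L}_{\text{IC}}(\theta) + \mathcal{L}_{\text{data}}(\theta) \, ,\\ \mathcal{L}_{\text{ODE}} &= \frac{1}{N} \sum_{k=1}^{N} \Big[ \big(\dot{\hat x} - \sigma (\hat y - \hat x)\big)^2 + \big(\dot{\hat y} - \hat x (\rho - \hat z) + \hat y\big)^2 + \big(\dot{\hat z} - \hat x \hat y + \beta \hat z\big)^2 \Big]_{t = t_k} \, ,\\ \mathcal{L}_{\text{IC}} &= \big(\hat x(0) - 1\big)^2 + \hat y(0)^2 + \hat z(0)^2 \, , \end{align*}\]

and $\mathcal{L}$ is minimized over $\theta$, $\sigma$, $\rho$, and $\beta$ together.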

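We start by defining the problem. The parameter σ is written σ_ in the code to keep it distinct from the sigmoid activation Lux.σ used in the networks.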

using NeuralPDE, Lux, ModelingToolkit, Optimization, OptimizationOptimJL, OrdinaryDiffEq,
      Plots, LineSearches
using ModelingToolkit: Interval, infimum, supremum
@parameters t, σ_, β, ρ          # independent variable t and the unknown parameters
@variables x(..), y(..), z(..)
Dt = Differential(t)
eqs = [Dt(x(t)) ~ σ_ * (y(t) - x(t)),
    Dt(y(t)) ~ x(t) * (ρ - z(t)) - y(t),
    Dt(z(t)) ~ x(t) * y(t) - β * z(t)]

bcs = [x(0) ~ 1.0, y(0) ~ 0.0, z(0) ~ 0.0]   # initial conditions
domains = [t ∈ Interval(0.0, 1.0)]           # time domain [0, 1]
dt = 0.01
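For reference, the symbolic right-hand side above can be written as an ordinary Julia function. The helper lorenz_rhs below is hypothetical and takes the parameter vector in the order [σ, ρ, β], matching the parameter list passed to PDESystem later. Evaluating it at the initial condition shows how different the initial velocity is under the true parameters and under the initial guess of 1.0 used later.

```julia
# Hypothetical plain-Julia right-hand side of the Lorenz system.
# Assumption: p is ordered [σ, ρ, β], like the parameter list given to PDESystem.
function lorenz_rhs(u, p)
    σ, ρ, β = p
    x, y, z = u
    return [σ * (y - x),
            x * (ρ - z) - y,
            x * y - β * z]
end

u0 = [1.0, 0.0, 0.0]                  # initial condition from bcs
lorenz_rhs(u0, [10.0, 28.0, 8 / 3])   # true parameters       → [-10.0, 28.0, 0.0]
lorenz_rhs(u0, [1.0, 1.0, 1.0])       # initial guess of 1.0  → [-1.0, 1.0, 0.0]
```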

Next, we define one neural network per dependent variable (x, y, and z):

input_ = length(domains)  # one input: t
n = 8                     # width of each hidden layer
chain1 = Lux.Chain(Dense(input_, n, Lux.σ), Dense(n, n, Lux.σ), Dense(n, n, Lux.σ),
    Dense(n, 1))
chain2 = Lux.Chain(Dense(input_, n, Lux.σ), Dense(n, n, Lux.σ), Dense(n, n, Lux.σ),
    Dense(n, 1))
chain3 = Lux.Chain(Dense(input_, n, Lux.σ), Dense(n, n, Lux.σ), Dense(n, n, Lux.σ),
    Dense(n, 1))
Chain(
    layer_1 = Dense(1 => 8, σ),         # 16 parameters
    layer_2 = Dense(8 => 8, σ),         # 72 parameters
    layer_3 = Dense(8 => 8, σ),         # 72 parameters
    layer_4 = Dense(8 => 1),            # 9 parameters
)         # Total: 169 parameters,
          #        plus 0 states.
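As a package-free illustration of what each of these chains computes, the sketch below counts the trainable parameters of a 1 => 8 => 8 => 8 => 1 network (reproducing the total of 169 shown above) and runs a forward pass with sigmoid hidden layers and a linear output layer. The helper names dense_params, sigmoid, and mlp_forward are hypothetical, and random normal weights stand in for Lux's own initialization.

```julia
# Hypothetical helper: trainable parameters of a Dense(in => out) layer
dense_params(in, out) = in * out + out      # weight matrix + bias vector

sigmoid(x) = 1 / (1 + exp(-x))              # logistic sigmoid

widths = [1, 8, 8, 8, 1]                    # layer widths of chain1/chain2/chain3
nparams = sum(dense_params(widths[k], widths[k + 1]) for k in 1:length(widths) - 1)

# Forward pass; inputs are a 1×N matrix (one row per input, one column per sample)
function mlp_forward(ts::AbstractMatrix, Ws, bs)
    h = ts
    for k in eachindex(Ws)
        z = Ws[k] * h .+ bs[k]
        h = k < length(Ws) ? sigmoid.(z) : z  # last layer is linear
    end
    return h
end

# Random weights (assumption: not Lux's initializer)
Ws = [randn(widths[k + 1], widths[k]) for k in 1:length(widths) - 1]
bs = [randn(widths[k + 1]) for k in 1:length(widths) - 1]
out = mlp_forward(reshape(collect(0.0:0.25:1.0), 1, :), Ws, bs)  # 1×5 output
```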

To estimate the parameters, we add a further loss term that measures how well the networks fit observed data.

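Here we generate this data by solving the Lorenz system with OrdinaryDiffEq.jl, using the true parameter values $\sigma = 10$, $\rho = 28$, and $\beta = 8/3$. The adaptive solver chooses its own, non-uniform time steps; the solution is then interpolated onto the uniform grid ts, which has 101 points between 0 and 1.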

# Reference model with the true parameters σ = 10, ρ = 28, β = 8/3
function lorenz!(du, u, p, t)
    du[1] = 10.0 * (u[2] - u[1])
    du[2] = u[1] * (28.0 - u[3]) - u[2]
    du[3] = u[1] * u[2] - (8 / 3) * u[3]
    return nothing
end

u0 = [1.0; 0.0; 0.0]
tspan = (0.0, 1.0)
prob = ODEProblem(lorenz!, u0, tspan)
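sol = solve(prob, Tsit5(), dt = 0.1)  # adaptive solver; dt is only the initial step size
ts = [infimum(d.domain):0.01:supremum(d.domain) for d in domains][1]  # uniform grid
function getData(sol)
    interp = sol(ts)          # evaluate the dense interpolant at the uniform ts
    us = hcat(interp.u...)    # 3×101 matrix; rows are x, y, z
    ts_ = hcat(interp.t...)   # 1×101 matrix of time points
    return [us, ts_]
end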
data = getData(sol)

(u_, t_) = data        # u_: 3×101 data matrix, t_: 1×101 time points
len = length(data[2])  # number of data points (101)
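The layout produced by getData can be mimicked without any packages. The sketch below swaps the adaptive Tsit5 solver plus interpolation for a fixed-step classical fourth-order Runge-Kutta (RK4) integrator whose step of 0.01 coincides with the sampling grid, so no interpolation is needed. The names lorenz_true, rk4_step, us, and ts_row are illustrative assumptions; the result has the same shapes as u_ (3×101) and t_ (1×101).

```julia
# Lorenz right-hand side with the true parameters used to generate the data
function lorenz_true(u; σ = 10.0, ρ = 28.0, β = 8 / 3)
    return [σ * (u[2] - u[1]), u[1] * (ρ - u[3]) - u[2], u[1] * u[2] - β * u[3]]
end

# One classical fourth-order Runge-Kutta step of size h
function rk4_step(f, u, h)
    k1 = f(u)
    k2 = f(u .+ (h / 2) .* k1)
    k3 = f(u .+ (h / 2) .* k2)
    k4 = f(u .+ h .* k3)
    return u .+ (h / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
end

ts_grid = 0.0:0.01:1.0                   # same grid as ts: 101 points
us = zeros(3, length(ts_grid))           # rows: x, y, z
us[:, 1] = [1.0, 0.0, 0.0]               # initial condition
for j in 2:length(ts_grid)
    us[:, j] = rk4_step(lorenz_true, us[:, j - 1], step(ts_grid))
end
ts_row = reshape(collect(ts_grid), 1, :) # 1×101, like t_
```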

Then we define the additional loss function additional_loss(phi, θ, p), which takes three arguments:

  • phi, the trial solutions (one per network);
  • θ, the parameters of the neural networks;
  • p, the hyperparameters.

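For Lux networks, θ is a ComponentArray whose subsets hold the parameters of the individual networks: θ.x for the network approximating x, and likewise for y and z. A subset can equivalently be accessed as θ[:x]. The additional loss is therefore the mean squared error between each network's prediction and the data, summed over the three variables: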

depvars = [:x, :y, :z]
# Mean squared error between each network's prediction and the data, summed over x, y, z
function additional_loss(phi, θ, p)
    return sum(sum(abs2, phi[i](t_, θ[depvars[i]]) .- u_[[i], :]) / len for i in 1:3)
end
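In formula form, with $N$ = len = 101 time points $t_j$, data $u_i(t_j)$, and trial solutions $\varphi_i$, this loss is

\[ \mathcal{L}_{\text{data}}(\theta) = \sum_{i=1}^{3} \frac{1}{N} \sum_{j=1}^{N} \big(\varphi_i(t_j; \theta_i) - u_i(t_j)\big)^2 \, . \]

The sketch below checks this structure on mock inputs only: the trial solutions are hypothetical shifted functions, the "data" is not the Lorenz data, and a NamedTuple (which also supports θ[:x]) stands in for the ComponentArray θ. NeuralPDE is not used.

```julia
t_ = reshape(collect(0.0:0.01:1.0), 1, :)   # 1×101 time points
len = length(t_)
u_ = vcat(sin.(t_), cos.(t_), t_ .^ 2)      # mock 3×101 "data"

# Mock trial solutions: data shifted by a scalar θi, so θi = 0 reproduces the data
phi = [(t, θi) -> sin.(t) .+ θi,
       (t, θi) -> cos.(t) .+ θi,
       (t, θi) -> t .^ 2 .+ θi]

depvars = [:x, :y, :z]

# Same structure as the tutorial's additional_loss
function additional_loss(phi, θ, p)
    return sum(sum(abs2, phi[i](t_, θ[depvars[i]]) .- u_[[i], :]) / len for i in 1:3)
end

θ_exact = (x = 0.0, y = 0.0, z = 0.0)   # NamedTuple standing in for θ
θ_shift = (x = 1.0, y = 1.0, z = 1.0)
additional_loss(phi, θ_exact, nothing)  # → 0.0
additional_loss(phi, θ_shift, nothing)  # ≈ 3.0 (each variable contributes a mean squared error of 1)
```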

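Finally, we build the discretization with the PhysicsInformedNN interface, setting param_estim = true so that the ODE parameters are optimized together with the network weights, and solve the resulting optimization problem.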

discretization = NeuralPDE.PhysicsInformedNN([chain1, chain2, chain3],
    NeuralPDE.QuadratureTraining(; abstol = 1e-6, reltol = 1e-6, batch = 200),
    param_estim = true,                  # treat σ_, ρ, β as trainable
    additional_loss = additional_loss)
# The parameters start from the initial guess 1.0 given in `defaults`
@named pde_system = PDESystem(eqs, bcs, domains, [t], [x(t), y(t), z(t)], [σ_, ρ, β],
    defaults = Dict([p .=> 1.0 for p in [σ_, ρ, β]]))
prob = NeuralPDE.discretize(pde_system, discretization)
callback = function (p, l)
    println("Current loss is: $l")
    return false
end
res = Optimization.solve(prob, BFGS(linesearch = BackTracking()); callback = callback,
    maxiters = 1000)
p_ = res.u[(end - 2):end] # estimated [σ, ρ, β]; true values are [10, 28, 8/3]
3-element Vector{Float64}:
  9.617285395789576
 27.87829363462226
  2.4836115593811994
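To quantify how close these estimates are, the snippet below computes the relative error of each printed value against the true parameters hard-coded in lorenz!. The numbers are copied from the output above; a rerun will typically give somewhat different estimates because the networks are initialized randomly.

```julia
# Estimates printed above, in the PDESystem parameter order [σ, ρ, β]
p_est  = [9.617285395789576, 27.87829363462226, 2.4836115593811994]
p_true = [10.0, 28.0, 8 / 3]   # values used in lorenz!

rel_err = abs.(p_est .- p_true) ./ abs.(p_true)   # ≈ [0.038, 0.0043, 0.069]
```

For this run, ρ is recovered to within about 0.4%, σ to about 3.8%, and β to about 6.9%.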

Finally, we analyze the result by plotting the PINN predictions against the reference ODE solution.

minimizers = [res.u.depvar[depvars[i]] for i in 1:3]  # trained weights of each network
ts = [infimum(d.domain):(0.001):supremum(d.domain) for d in domains][1]  # finer grid for plotting
u_predict = [[discretization.phi[i]([t], minimizers[i])[1] for t in ts] for i in 1:3]
plot(sol)
plot!(ts, u_predict, label = ["x(t)" "y(t)" "z(t)"])
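Besides the visual check, the fit can be summarized numerically. The hypothetical helper below computes the root-mean-square error per variable between predictions and reference values, both given as a vector of three vectors sampled at the same time points, which is the layout of u_predict. The toy numbers are made up and are not results from this tutorial.

```julia
# Hypothetical helper: RMSE per variable for vectors of vectors (x, y, z)
function rmse_per_variable(pred, ref)
    return [sqrt(sum(abs2, pred[i] .- ref[i]) / length(ref[i])) for i in eachindex(ref)]
end

# Toy usage with made-up values
ref  = [[1.0, 2.0], [0.0, 0.0], [3.0, 3.0]]
pred = [[1.0, 2.0], [1.0, 1.0], [3.0, 5.0]]
rmse_per_variable(pred, ref)   # → [0.0, 1.0, 1.4142135623730951]
```

With the objects from this tutorial, a matching reference could be built from the ODE solution's interpolant as `[[sol(t)[i] for t in ts] for i in 1:3]` and compared with u_predict.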