Partial Differential Equation (PDE) Constrained Optimization
This example uses a prediction model to recover the parameters of the one-dimensional heat equation via PDE-constrained optimization. (Step-by-step description below.)
using DelimitedFiles, Plots
using DifferentialEquations, Optimization, OptimizationPolyalgorithms, Zygote,
      OptimizationOptimJL, SciMLSensitivity # SciMLSensitivity supplies the adjoint rules Zygote needs to differentiate solve
# Problem setup parameters:
Lx = 10.0
x = 0.0:0.01:Lx
dx = x[2] - x[1]
Nx = length(x) # total number of grid points
u0 = exp.(-(x .- 3.0) .^ 2) # initial condition: Gaussian centered at x = 3
## Problem Parameters
p = [1.0,1.0] # True solution parameters
xtrs = [dx,Nx] # Extra parameters
dt = 0.40*dx^2 # CFL condition
t0, tMax = 0.0, 1000*dt
tspan = (t0,tMax)
t = t0:dt:tMax;
## Definition of Auxiliary functions
"""
2nd-order central difference for the 1st derivative (zero-padded at the boundaries).
"""
function ddx(u, dx)
    return [[zero(eltype(u))]; (u[3:end] - u[1:end-2]) ./ (2.0 * dx); [zero(eltype(u))]]
end

"""
2nd-order central difference for the 2nd derivative (zero-padded at the boundaries).
"""
function d2dx(u, dx)
    return [[zero(eltype(u))]; (u[3:end] - 2.0 .* u[2:end-1] + u[1:end-2]) ./ (dx^2); [zero(eltype(u))]]
end
## ODE description of the Physics:
function heat(u, p, t)
    # Model parameters
    a0, a1 = p
    dx, Nx = xtrs # unpack grid spacing and grid size (Nx is unused here)
    # Semi-discretized PDE: du/dt = 2*a0*u + a1*d2u/dx2
    return 2.0 * a0 .* u + a1 .* d2dx(u, dx)
end
# Testing Solver on linear PDE
prob = ODEProblem(heat,u0,tspan,p)
sol = solve(prob,Tsit5(), dt=dt,saveat=t);
plot(x, sol.u[1], lw=3, label="t0", size=(800,500))
plot!(x, sol.u[end],lw=3, ls=:dash, label="tMax")
ps = [0.1, 0.2]; # Initial guess for model parameters
function predict(θ)
    Array(solve(prob, Tsit5(), p = θ, dt = dt, saveat = t))
end
## Defining Loss function
function loss(θ)
    pred = predict(θ)
    l = pred .- Array(sol) # reuse the prediction rather than solving twice
    return sum(abs2, l), pred # squared-error loss (sum of squared differences)
end
l,pred = loss(ps)
size(pred), size(sol), size(t) # Checking sizes
LOSS = [] # Loss accumulator
PRED = [] # prediction accumulator
PARS = [] # parameters accumulator
callback = function (θ, l, pred) # callback function to observe training
    display(l)
    append!(PRED, [pred])
    append!(LOSS, l)
    append!(PARS, [θ])
    return false # false means "continue optimizing"
end
callback(ps,loss(ps)...) # Testing callback function
# Let's see prediction vs. truth
scatter(sol[:,end], label="Truth", size=(800,500))
plot!(PRED[end][:,end], lw=2, label="Prediction")
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x,p)->loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ps)
res = Optimization.solve(optprob, PolyOpt(), callback = callback)
@show res.u # returns [0.999999999613485, 0.9999999991343996]
Step-by-step Description
Load Packages
using DelimitedFiles, Plots
using DifferentialEquations, Optimization, OptimizationPolyalgorithms,
      Zygote, OptimizationOptimJL, SciMLSensitivity
Parameters
First, we set up the one-dimensional space over which our equations will be evaluated: x spans from 0.0 to 10.0 in steps of 0.01, and t spans from 0.00 to 0.04 in steps of 4.0e-5.
# Problem setup parameters:
Lx = 10.0
x = 0.0:0.01:Lx
dx = x[2] - x[1]
Nx = length(x) # total number of grid points
u0 = exp.(-(x .- 3.0) .^ 2) # initial condition: Gaussian centered at x = 3
## Problem Parameters
p = [1.0,1.0] # True solution parameters
xtrs = [dx,Nx] # Extra parameters
dt = 0.40*dx^2 # CFL condition
t0, tMax = 0.0, 1000*dt
tspan = (t0,tMax)
t = t0:dt:tMax;
0.0:4.0e-5:0.04
In plain terms, the quantities that were defined are:

- x (0.0 to Lx) spans the specified 1D space
- dx = distance between two points
- Nx = total number of grid points
- u0 = initial condition
- p = true solution parameters
- xtrs = convenient grouping of dx and Nx into an Array
- dt = time distance between two points
- t (t0 to tMax) spans the specified time frame
- tspan = span of t
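On the dt line: dt = 0.40*dx^2 mirrors the CFL-type stability bound dt ≤ dx²/(2a₁) of an explicit finite-difference scheme for the heat equation (Tsit5 is adaptive, so here dt mainly sets the initial step and the save grid). A quick sanity check, reusing the names defined above:

# Explicit-scheme stability bound for diffusion: dt <= dx^2/(2*a1).
# With a1 = p[2] = 1.0 and dx = 0.01 the bound is 5.0e-5; dt = 4.0e-5 satisfies it.
@assert dt <= dx^2 / (2 * p[2])
@show dx, Nx, dt, length(t)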
Auxiliary Functions
We then define two functions to compute the derivatives numerically. A second-order central difference is used for both the 1st and 2nd derivatives.
## Definition of Auxiliary functions
"""
2nd-order central difference for the 1st derivative (zero-padded at the boundaries).
"""
function ddx(u, dx)
    return [[zero(eltype(u))]; (u[3:end] - u[1:end-2]) ./ (2.0 * dx); [zero(eltype(u))]]
end

"""
2nd-order central difference for the 2nd derivative (zero-padded at the boundaries).
"""
function d2dx(u, dx)
    return [[zero(eltype(u))]; (u[3:end] - 2.0 .* u[2:end-1] + u[1:end-2]) ./ (dx^2); [zero(eltype(u))]]
end
d2dx (generic function with 1 method)
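As a quick sanity check, we can apply the stencils to a function with known derivatives: away from the zero-padded boundary points, ddx of sin.(x) should approximate cos.(x) and d2dx should approximate -sin.(x).

# Interior errors of the 2nd-order stencils should be O(dx^2), i.e. ~1e-5 for dx = 0.01.
utest = sin.(x)
err1 = maximum(abs.(ddx(utest, dx)[2:end-1] .- cos.(x[2:end-1])))
err2 = maximum(abs.(d2dx(utest, dx)[2:end-1] .+ sin.(x[2:end-1])))
@show err1, err2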
Heat Differential Equation
Next, we setup our desired set of equations in order to define our problem.
## ODE description of the Physics:
function heat(u, p, t)
    # Model parameters
    a0, a1 = p
    dx, Nx = xtrs # unpack grid spacing and grid size (Nx is unused here)
    # Semi-discretized PDE: du/dt = 2*a0*u + a1*d2u/dx2
    return 2.0 * a0 .* u + a1 .* d2dx(u, dx)
end
heat (generic function with 1 method)
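In other words, heat encodes the semi-discretized linear reaction-diffusion equation du/dt = 2*a0*u + a1*d2u/dx2. A quick shape check before handing it to the solver:

# The RHS must return a vector of the same length as the state.
du = heat(u0, p, 0.0)
@assert length(du) == length(u0)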
Solve and Plot Ground Truth
We then solve and plot our partial differential equation. This is the true solution, which we will compare against further on.
# Testing Solver on linear PDE
prob = ODEProblem(heat,u0,tspan,p)
sol = solve(prob,Tsit5(), dt=dt,saveat=t);
plot(x, sol.u[1], lw=3, label="t0", size=(800,500))
plot!(x, sol.u[end],lw=3, ls=:dash, label="tMax")
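Before treating sol as ground truth, it is worth confirming that the solve succeeded; a minimal sketch, assuming SciMLBase (a dependency of DifferentialEquations) is available to import:

using SciMLBase: successful_retcode
@assert successful_retcode(sol) # sol.retcode can also be inspected directly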
Building the Prediction Model
Now we start building our prediction model to try to recover the values of p. We make an initial guess for the parameters and name it ps here. The predict function wraps solve as a single differentiable transformation: given candidate parameters θ, it re-solves the ODE problem and returns the solution as an Array.
ps = [0.1, 0.2]; # Initial guess for model parameters
function predict(θ)
    Array(solve(prob, Tsit5(), p = θ, dt = dt, saveat = t))
end
predict (generic function with 1 method)
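As a sanity check, predict evaluated at the true parameters p should reproduce the ground truth, while the initial guess ps should not:

# predict(p) re-solves with the true parameters, so it matches the ground truth.
@assert predict(p) ≈ Array(sol)
@assert !(predict(ps) ≈ Array(sol))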
Train Parameters
Training our model requires a loss function, an optimizer and a callback function to display the progress.
Loss
We first make a prediction based on the current values of the parameters, then take the difference between the predicted solution and the truth above. For the loss, we use the sum of squared errors.
## Defining Loss function
function loss(θ)
    pred = predict(θ)
    l = pred .- Array(sol) # reuse the prediction rather than solving twice
    return sum(abs2, l), pred # squared-error loss (sum of squared differences)
end
l,pred = loss(ps)
size(pred), size(sol), size(t) # Checking sizes
((1001, 1001), (1001, 1001), (1001,))
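Because SciMLSensitivity supplies adjoint rules for solve, the loss is differentiable end-to-end; a quick gradient check at the initial guess:

# Gradient of the scalar part of the loss w.r.t. the parameters, through the solver.
grad = Zygote.gradient(θ -> loss(θ)[1], ps)[1]
@show grad # 2-element vector: one sensitivity per parameter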
Optimizer
The optimizers ADAM (with a learning rate of 0.01) and BFGS are combined in the PolyOpt() polyalgorithm, which is passed directly to Optimization.solve in training (see below).
Callback
The callback function displays the loss during training. We also keep a history of the loss, the previous predictions, and the previous parameters with the LOSS, PRED, and PARS accumulators.
LOSS = [] # Loss accumulator
PRED = [] # prediction accumulator
PARS = [] # parameters accumulator
callback = function (θ, l, pred) # callback function to observe training
    display(l)
    append!(PRED, [pred])
    append!(LOSS, l)
    append!(PARS, [θ])
    return false # false means "continue optimizing"
end
callback(ps,loss(ps)...) # Testing callback function
false
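Returning false tells Optimization.solve to keep iterating; returning true halts training early. For example, a hedged variant (callback_tol and its tolerance are illustrative names, not part of the original script) that stops once the loss falls below a threshold:

callback_tol = function (θ, l, pred)
    tol = 1e-6 # illustrative tolerance
    append!(LOSS, l); append!(PRED, [pred]); append!(PARS, [θ])
    return l < tol # true stops the optimization early
end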
Plotting Prediction vs Ground Truth
The scatter points plotted here are the ground truth obtained from the actual solution we solved for above. The solid line represents our prediction. The goal is for both to overlap almost perfectly once training finishes and the loss is close to 0.
# Let's see prediction vs. truth
scatter(sol[:,end], label="Truth", size=(800,500))
plot!(PRED[end][:,end], lw=2, label="Prediction")
Train
The parameters are trained using Optimization.solve and adjoint sensitivities. The result is stored in res, and res.u returns the parameters that minimize the cost function.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x,p)->loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ps)
res = Optimization.solve(optprob, PolyOpt(), callback = callback)
@show res.u # returns [0.999999999613485, 0.9999999991343996]
We successfully recover the final ps to be approximately [1.0, 1.0] (here, [0.999999999999975, 1.0000000000000213]), matching the true solution p = [1.0, 1.0].
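For reference, PolyOpt() is roughly equivalent to running its two optimizers in sequence yourself. A hedged sketch (it assumes OptimizationOptimisers is installed to provide Adam, and the stage length maxiters = 200 is illustrative, not tuned); res2.u should closely agree with res.u above:

using OptimizationOptimisers # extra dependency beyond the imports above; provides Adam

# Stage 1: Adam with learning rate 0.01 for a fixed number of iterations.
res1 = Optimization.solve(optprob, OptimizationOptimisers.Adam(0.01);
                          callback = callback, maxiters = 200)

# Stage 2: refine with BFGS, warm-started from the Adam result.
optprob2 = Optimization.OptimizationProblem(optf, res1.u)
res2 = Optimization.solve(optprob2, BFGS(); callback = callback)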