Physics Informed Neural Operator for ODEs
This tutorial provides an example of how to use the Physics Informed Neural Operator (PINO) for solving a family of parametric ordinary differential equations (ODEs).
Operator Learning for a family of parametric ODEs
In this section, we define a parametric ODE and then learn it with a PINO using PINOODE. The PINO is trained to learn the mapping from the parameters of the ODE to its solution.
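For reference, the ODE family solved below can be written as follows. The bounds on p are taken from the code.

The loss line is an assumption, included only for orientation. It sketches a generic physics-informed objective: the squared ODE residual over sampled parameters p^(i) and times t_j. This tutorial does not spell out how PINOODE weights its terms or how it enforces u(0) = u_0.

```latex
\begin{aligned}
\frac{du}{dt} &= p_1 \cos(p_2 t) + p_3, \qquad u(0) = u_0 = 1, \qquad t \in [0, 1],\\
p_1 &\in [1, \pi], \quad p_2 \in [1, 2], \quad p_3 \in [2, 3],\\
\mathcal{G}_\theta &: \; p = (p_1, p_2, p_3) \;\mapsto\; u(\cdot\,; p),\\
\mathcal{L}(\theta) &\approx \frac{1}{N_p N_t} \sum_{i=1}^{N_p} \sum_{j=1}^{N_t}
  \Big( \partial_t \mathcal{G}_\theta\big(p^{(i)}\big)(t_j) - p^{(i)}_1 \cos\big(p^{(i)}_2 t_j\big) - p^{(i)}_3 \Big)^2 .
\end{aligned}
```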
using Test
using OptimizationOptimisers
using Lux
using Statistics, Random
using NeuralOperators
using NeuralPDE
# Define the parametric ODE equation
equation = (u, p, t) -> p[1] * cos(p[2] * t) + p[3]
tspan = (0.0, 1.0)
u0 = 1.0
prob = ODEProblem(equation, u0, tspan)
# Set the number of parameters for the ODE
num_params = 3
# Define the DeepONet architecture for the PINO
deeponet = NeuralOperators.DeepONet(
Chain(
Dense(num_params => 10, Lux.tanh_fast), Dense(10 => 10, Lux.tanh_fast), Dense(10 => 10)),
Chain(Dense(1 => 10, Lux.tanh_fast), Dense(10 => 10, Lux.tanh_fast),
Dense(10 => 10, Lux.tanh_fast)))
# Define the bounds for the parameters
bounds = [(1.0, pi), (1.0, 2.0), (2.0, 3.0)]
number_of_parameter_samples = 50
# Define the training strategy
strategy = StochasticTraining(20)
# Define the optimizer
opt = OptimizationOptimisers.Adam(0.03)
# Define the `PINOODE` algorithm
alg = PINOODE(deeponet, opt, bounds, num_params; strategy = strategy)
# Solve the ODE problem using the PINOODE algorithm
sol = solve(prob, alg, verbose = false, maxiters = 4000)
Displaying sol gives the following output:
retcode: Success
Interpolation: Trained neural network interpolation
t: 1×20×1 Array{Float64, 3}:
[:, :, 1] =
0.304311 0.828689 0.973667 0.831179 … 0.307964 0.431287 0.0340639
u: 20×3 Matrix{Float64}:
2.4408 2.17469 2.4306
4.60305 3.81871 4.28785
5.07459 4.15705 4.61846
4.61165 3.82494 4.29417
1.09643 1.10913 1.15224
1.86848 1.72296 1.88849
1.32673 1.29272 1.37222
4.52731 3.76371 4.23176
3.43619 2.94852 3.34196
4.92126 4.04774 4.5146
3.49238 2.9915 3.39116
2.27703 2.04577 2.27608
3.10527 2.6936 3.04634
5.10911 4.18157 4.64135
1.34973 1.31103 1.39417
5.14001 4.2035 4.66167
2.65075 2.33945 2.62746
2.4574 2.18774 2.44622
3.01078 2.62032 2.96032
1.16696 1.16539 1.21963
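The DeepONet defined above has two subnetworks:

- The branch net reads the 3 ODE parameters.
- The trunk net reads a single time value.

Each ends in a 10-dimensional feature vector. In the standard DeepONet formulation, the operator's output is the inner product of the two: G(p)(t) ≈ Σ_k b_k(p) τ_k(t).

The sketch below illustrates only that combination rule, in plain Julia with random, untrained weights. The helpers `dense`, `run_layers` and `deeponet_eval` are hypothetical. This is not how NeuralOperators.jl implements the layer internally; for instance, any extra bias term it may add is omitted.

```julia
using Random, LinearAlgebra

# Hypothetical dense layer: returns a closure computing σ.(W * x .+ b)
dense(W, b, σ) = x -> σ.(W * x .+ b)

# Hypothetical helper: apply a vector of layers in sequence
run_layers(layers, x) = foldl((h, layer) -> layer(h), layers; init = x)

rng = MersenneTwister(0)
# Random, untrained weights with the same widths and activations as the tutorial's networks
branch_layers = [dense(randn(rng, 10, 3), randn(rng, 10), tanh),
                 dense(randn(rng, 10, 10), randn(rng, 10), tanh),
                 dense(randn(rng, 10, 10), randn(rng, 10), identity)]
trunk_layers = [dense(randn(rng, 10, 1), randn(rng, 10), tanh),
                dense(randn(rng, 10, 10), randn(rng, 10), tanh),
                dense(randn(rng, 10, 10), randn(rng, 10), tanh)]

# DeepONet combination rule: inner product of branch features b(p) and trunk features τ(t)
deeponet_eval(p, t) = dot(run_layers(branch_layers, p), run_layers(trunk_layers, [t]))

p_demo = [2.0, 1.5, 2.5]                                   # one parameter set inside the bounds
u_demo = [deeponet_eval(p_demo, t) for t in 0.0:0.25:1.0]  # untrained outputs at 5 times
```

In the tutorial itself, Lux and NeuralOperators.jl manage these weights, and training replaces the random values.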
Now let's compare the predictions of the learned operator with the ground-truth solution, which is obtained from the analytic solution of the parametric ODE.
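Integrating the right-hand side from 0 to t gives the closed form used below: u(t) = u0 + (p1/p2) sin(p2 t) + p3 t.

As a sanity check, this self-contained sketch integrates the same ODE and compares the result with the closed form for one parameter set. It uses a hand-written classical fourth-order Runge-Kutta loop rather than a library solver, only to keep the example dependency-free. The names `rhs`, `analytic` and `rk4` are illustrative.

```julia
# Right-hand side of the parametric ODE (same as `equation` in the tutorial)
rhs(u, p, t) = p[1] * cos(p[2] * t) + p[3]

# Closed form obtained by integrating rhs from 0 to t
analytic(u0, p, t) = u0 + p[1] / p[2] * sin(p[2] * t) + p[3] * t

# Classical fourth-order Runge-Kutta with n fixed steps over tspan
function rk4(f, u0, p, tspan, n)
    h = (tspan[2] - tspan[1]) / n
    u, t = u0, tspan[1]
    for _ in 1:n
        k1 = f(u, p, t)
        k2 = f(u + h / 2 * k1, p, t + h / 2)
        k3 = f(u + h / 2 * k2, p, t + h / 2)
        k4 = f(u + h * k3, p, t + h)
        u += h / 6 * (k1 + 2k2 + 2k3 + k4)
        t += h
    end
    return u
end

p_demo = [2.0, 1.5, 2.5]                        # a parameter set inside the bounds
u_num = rk4(rhs, 1.0, p_demo, (0.0, 1.0), 100)  # numerical u(1)
u_ref = analytic(1.0, p_demo, 1.0)              # closed-form u(1)
abs(u_num - u_ref)                              # expected to be very small (RK4 error is O(h^4))
```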
using Plots
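# Build a test set: `number_of_parameters` evenly spaced values per parameter range
# (column i pairs the i-th value of every range) and time points spaced `dt` apart
function get_trainset(bounds, tspan, number_of_parameters, dt)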
p_ = [range(start = b[1], length = number_of_parameters, stop = b[2]) for b in bounds]
p = vcat([collect(reshape(p_i, 1, size(p_i, 1))) for p_i in p_]...)
t_ = collect(tspan[1]:dt:tspan[2])
t = collect(reshape(t_, 1, size(t_, 1), 1))
(p, t)
end
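get_trainset returns a 3×n parameter matrix and a 1×n_t×1 time array. All three ranges are indexed by the same column, so the test set walks along the diagonal of the parameter box rather than covering a full grid.

The sketch below copies the function verbatim and checks the shapes for the first test call used later. The variable names `p_test` and `t_test` are chosen here for illustration.

```julia
# Copied from the tutorial: evenly spaced parameter samples and a time grid
function get_trainset(bounds, tspan, number_of_parameters, dt)
    p_ = [range(start = b[1], length = number_of_parameters, stop = b[2]) for b in bounds]
    p = vcat([collect(reshape(p_i, 1, size(p_i, 1))) for p_i in p_]...)
    t_ = collect(tspan[1]:dt:tspan[2])
    t = collect(reshape(t_, 1, size(t_, 1), 1))
    (p, t)
end

bounds = [(1.0, pi), (1.0, 2.0), (2.0, 3.0)]
p_test, t_test = get_trainset(bounds, (0.0, 1.0), 50, 0.025)
size(p_test)   # (3, 50): one row per ODE parameter, one column per sample
size(t_test)   # (1, 41, 1): 0.0:0.025:1.0 contains 41 time points
# The first and last samples are the lower and upper corners of the parameter box
p_test[:, 1], p_test[:, end]
```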
# Compute the ground truth solution for each parameter
ground_solution = (u0, p, t) -> u0 + p[1] / p[2] * sin(p[2] * t) + p[3] * t
function ground_solution_f(p, t)
reduce(hcat,
[[ground_solution(u0, p[:, i], t[j]) for j in axes(t, 2)] for i in axes(p, 2)])
end
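ground_solution_f evaluates the closed form for every (time, parameter) pair and stacks the results column by column. The returned matrix therefore has one row per time point and one column per parameter sample. The error computation and the per-sample plot at the end rely on this layout.

A quick check on a tiny hand-made input confirms the layout. It also confirms that the first row equals u0, since every trajectory starts at t = 0. The input values (`p_small`, `t_small`) are illustrative only.

```julia
u0 = 1.0
ground_solution = (u0, p, t) -> u0 + p[1] / p[2] * sin(p[2] * t) + p[3] * t
# Same definition as the tutorial: rows index time, columns index parameter samples
function ground_solution_f(p, t)
    reduce(hcat,
        [[ground_solution(u0, p[:, i], t[j]) for j in axes(t, 2)] for i in axes(p, 2)])
end

p_small = [1.0 2.0; 1.0 1.5; 2.0 3.0]          # 3 parameters × 2 samples (illustrative values)
t_small = reshape([0.0, 0.5, 1.0], 1, 3, 1)    # 1 × 3 × 1 time array, as from get_trainset
U = ground_solution_f(p_small, t_small)
size(U)    # (3, 2): 3 time points × 2 parameter samples
U[1, :]    # [1.0, 1.0]: every trajectory starts at u0
```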
# generate new parameter samples and time points to test the model
(p, t) = get_trainset(bounds, tspan, 50, 0.025)
# compute the ground truth solution
ground_solution_ = ground_solution_f(p, t)
# predict the solution with the PINO model
predict = sol.interp(p, t)
# calculate the errors between the ground truth solution and the predicted solution
errors = ground_solution_ - predict
# calculate the mean error
mean_error = mean(errors)
# calculate the standard deviation of the errors
std_error = std(errors)
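mean(errors) averages signed errors, so positive and negative deviations can cancel. A value near zero does not by itself show that the prediction is accurate. Unsigned metrics are often easier to interpret, such as:

- the mean absolute error,
- the root-mean-square error,
- the maximum pointwise error,
- a relative L2 error.

The sketch below shows the difference on a toy error vector. Its values are made up and are not results from this tutorial. `relative_l2` is a hypothetical helper.

```julia
using Statistics, LinearAlgebra

toy_errors = [0.5, -0.5, 0.25, -0.25]   # toy signed errors, not tutorial output

mean(toy_errors)               # signed mean: positive and negative errors cancel
mean(abs.(toy_errors))         # mean absolute error
sqrt(mean(toy_errors .^ 2))    # root-mean-square error
maximum(abs.(toy_errors))      # worst-case pointwise error

# Hypothetical helper: relative L2 error between ground-truth and predicted arrays of equal size
relative_l2(truth, pred) = norm(truth .- pred) / norm(truth)
```

With the tutorial's arrays, the same lines apply directly, e.g. mean(abs.(errors)) or relative_l2(ground_solution_, predict).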
# repeat the evaluation on a finer test grid: 100 parameter samples and dt = 0.01
p, t = get_trainset(bounds, tspan, 100, 0.01)
ground_solution_ = ground_solution_f(p, t)
predict = sol.interp(p, t)
errors = ground_solution_ - predict
mean_error = mean(errors)
std_error = std(errors)
# Plot the predicted and ground-truth solutions side by side as filled contour plots
# (rows of each matrix are time points, columns are parameter samples)
plt_pred = plot(predict, linetype = :contourf, title = "Predicted")
plt_true = plot(ground_solution_, linetype = :contourf, title = "Ground truth")
plot(plt_pred, plt_true, layout = (1, 2))
# `i` selects one parameter sample (a column of the test set)
i = 20
# `predict[:, i]` is the PINO prediction over time for that sample
plot(predict[:, i], label = "Predicted")
# `ground_solution_[:, i]` is the corresponding ground-truth trajectory
plot!(ground_solution_[:, i], label = "Ground truth")