DeepONet for 1D Poisson Equation

Mathematical Formulation

Problem Statement

We consider the one-dimensional Poisson equation on the domain $\Omega = [0,1]$:

\[-\frac{d^2u(x)}{dx^2} = f(x), \quad x \in [0,1]\]

subject to homogeneous Dirichlet boundary conditions:

\[u(0) = u(1) = 0\]

In our specific case, the forcing function $f(x)$ is parameterized by $\alpha$:

\[f(x) = \alpha \sin(\pi x)\]

Analytical Solution

For this forcing function, the boundary-value problem has the closed-form solution

\[u(x) = \frac{\alpha}{\pi^2} \sin(\pi x)\]
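
Substituting back into the equation verifies this:

\[-\frac{d^2}{dx^2}\left[\frac{\alpha}{\pi^2}\sin(\pi x)\right] = \frac{\alpha}{\pi^2}\,\pi^2\sin(\pi x) = \alpha\sin(\pi x) = f(x)\]

and the boundary conditions hold because $\sin(0) = \sin(\pi) = 0$.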

DeepONet Architecture

The DeepONet architecture consists of:

  • Branch network: Processes the discretized input function $f(x)$
  • Trunk network: Processes the spatial coordinates $x$
  • Output: $u(x) \approx \sum_{k=1}^p b_k(f) t_k(x)$, where $b_k$ are outputs from the branch network and $t_k$ are outputs from the trunk network

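In the output expression above, the predicted value at each query point is a dot product over the latent dimension $p$ between the branch and trunk outputs. The snippet below is a minimal conceptual sketch of that combination (illustrative arrays only, not the NeuralOperators.jl implementation):

# Conceptual sketch: combine branch coefficients b_k(f) with trunk bases t_k(x)
p, n_points = 32, 64
b = randn(Float32, p)            # b_k(f): branch output for one discretized f
t = randn(Float32, p, n_points)  # t_k(x): trunk outputs at n_points locations
u_hat = vec(b' * t)              # u(x_j) ≈ Σ_k b_k(f) t_k(x_j)
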
Implementation

using NeuralOperators, Lux, Random, Optimisers, Reactant, Statistics

using CairoMakie, AlgebraOfGraphics
const AoG = AlgebraOfGraphics
AoG.set_aog_theme!()

# Devices: CPU for post-processing and plotting, Reactant (XLA) for training
const cdev = cpu_device()
const xdev = reactant_device(; force=true)

# Seed the default RNG for reproducibility
rng = Random.default_rng()
Random.seed!(rng, 42)

# Problem setup
m = 64           # number of grid (sensor) points
data_size = 128  # number of training samples
xrange = Float32.(range(0, 1; length=m))

forcing_function(x, α) = α * sinpi.(x)
poisson_solution(x, α) = (α / Float32(π)^2) * sinpi.(x)

# Generate training data
α_train = 0.5f0 .+ 0.8f0 .* rand(Float32, data_size)  # α ∈ [0.5, 1.3]
f_data = stack(Base.Fix1(forcing_function, xrange), α_train)
u_data = stack(Base.Fix1(poisson_solution, xrange), α_train)
x_data = reshape(xrange, 1, m)
# Normalize the targets so the network learns values of order one
max_u = maximum(abs.(u_data))
u_data ./= max_u
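
The arrays above follow the layout used by this example: columns of f_data are input-function samples and x_data holds the query coordinates. A quick sanity check of the shapes (these follow directly from the construction above):

# Sanity check of the array layout used below:
#   f_data :: (m, data_size) — each column is one discretized forcing function
#   u_data :: (m, data_size) — matching (normalized) solution samples
#   x_data :: (1, m)         — query coordinates shared by all samples
@assert size(f_data) == (m, data_size)
@assert size(u_data) == (m, data_size)
@assert size(x_data) == (1, m)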

# Define DeepONet
deeponet = DeepONet(
    Chain(Dense(m => 64, tanh), Dense(64 => 64, tanh), Dense(64 => 32, tanh)),  # Branch
    Chain(Dense(1 => 32, tanh), Dense(32 => 32, tanh))                          # Trunk
)
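
Before moving anything to the accelerator, an optional CPU forward pass can confirm that the model output has the same shape as u_data, which the elementwise MSE loss below requires. This is a sketch; it uses a throwaway RNG (Xoshiro(0)) so the seeded default RNG is left untouched:

# Optional shape check on the CPU (assumed output layout: same as u_data)
let rng_check = Xoshiro(0)
    ps_chk, st_chk = Lux.setup(rng_check, deeponet)
    u_chk, _ = deeponet((f_data, x_data), ps_chk, st_chk)
    @assert size(u_chk) == size(u_data)
end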

ps, st = Lux.setup(rng, deeponet) |> xdev;

# Training
function train_model!(model, ps, st, data; epochs=3000)
    train_state = Training.TrainState(model, ps, st, Adam(0.001f0))
    losses = Float32[]

    for epoch in 1:epochs
        # One optimization step using Enzyme AD; gradients are not returned
        _, loss, _, train_state = Training.single_train_step!(
            AutoEnzyme(), MSELoss(), data, train_state; return_gradients=Val(false)
        )
        epoch % 500 == 0 && println("Epoch: $epoch, Loss: $loss")
        push!(losses, loss)
    end

    return train_state.parameters, train_state.states, losses
end

# Move the training data to the Reactant device
f_data = f_data |> xdev;
x_data = x_data |> xdev;
u_data = u_data |> xdev;
data = ((f_data, x_data), u_data)
ps_trained, st_trained, losses = train_model!(deeponet, ps, st, data)
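
The recorded losses can be used to verify convergence; a minimal sketch using the already-loaded CairoMakie:

# Plot the training loss on a log scale to inspect convergence
let
    fig_loss = Figure()
    ax = Axis(fig_loss[1, 1]; xlabel="Epoch", ylabel="MSE Loss", yscale=log10)
    lines!(ax, 1:length(losses), losses)
    fig_loss
end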

# Prediction function (rescales the normalized output back by max_u)
function predict(model, f_input, x_input, ps, st)
    pred, _ = model((f_input, x_input), ps, st)
    return vec(pred .* max_u)
end

# Compile the prediction function, using high precision for dot products
compiled_predict_fn = Reactant.with_config(; dot_general_precision=PrecisionConfig.HIGH) do
    @compile predict(deeponet, f_data[:, 1:1], x_data, ps_trained, st_trained)
end
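
The compiled function is traced for the argument shapes used above, so it can be reused for any forcing function sampled on the same grid with a single-column branch input. For example (α_demo is an illustrative value):

# Reuse the compiled prediction function for a new amplitude
α_demo = 1.0f0
f_demo = xdev(reshape(forcing_function(xrange, α_demo), :, 1))
u_demo = compiled_predict_fn(deeponet, f_demo, x_data, ps_trained, st_trained) |> cdev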

# Testing and visualization
begin
    results = Float32[]
    labels = AbstractString[]
    abs_errors = Float32[]
    x_values = Float32[]
    x_values2 = Float32[]
    alphas = AbstractString[]
    alphas2 = AbstractString[]

    for (i, α) in enumerate([0.6f0, 0.8f0, 1.2f0])
        f_test = reshape(forcing_function(xrange, α), :, 1)
        u_pred = compiled_predict_fn(
            deeponet, xdev(f_test), x_data, ps_trained, st_trained
        ) |> cdev
        u_true = reshape(poisson_solution(xrange, α), :, 1)

        l2_error = sqrt(mean(abs2, u_pred .- u_true))
        rel_error = l2_error / sqrt(mean(abs2, u_true)) * 100

        text = L"$ \alpha = %$(α) $ (Rel. Error: $ %$(round(rel_error, digits=2))% $)"

        append!(results, vec(u_pred))
        append!(labels, repeat(["Predictions"], length(xrange)))
        append!(results, vec(u_true))
        append!(labels, repeat(["Ground Truth"], length(xrange)))
        append!(x_values, repeat(vec(xrange), 2))
        append!(alphas, repeat([text], length(xrange) * 2))

        append!(abs_errors, abs.(vec(u_pred .- u_true)))
        append!(x_values2, vec(xrange))
        append!(alphas2, repeat([text], length(xrange)))
    end
end

plot_data = (; results, abs_errors, x_values, alphas, labels, x_values2, alphas2)

begin
    fig = Figure(;
        size=(1024, 512),
        title="DeepONet Results for 1D Poisson Equation",
        titlesize=25
    )

    axis_common = (;
        xlabelsize=20, ylabelsize=20, titlesize=20, xticklabelsize=20, yticklabelsize=20
    )

    axs1 = draw!(
        fig[1, 1],
        AoG.data(plot_data) *
        mapping(
            :x_values => L"x",
            :results => L"u(x)";
            color=:labels => "",
            col=:alphas => "",
            linestyle=:labels => "",
        ) *
        visual(Lines; linewidth=4),
        scales(; Color=(; palette=[:orange, :blue]), LineStyle=(; palette=[:solid, :dash]));
        axis=merge(axis_common, (; xlabel="")),
    )
    for ax in axs1
        hidexdecorations!(ax; grid=false)
    end

    axislegend(
        axs1[1, 1].axis,
        [
            LineElement(; linestyle=:solid, color=:orange),
            LineElement(; linestyle=:dash, color=:blue),
        ],
        ["Ground Truth", "Predictions"],
        labelsize=20,
    )

    axs2 = draw!(
        fig[2, 1],
        AoG.data(plot_data) *
        mapping(
            :x_values2 => L"x",
            :abs_errors => L"|u(x) - u(x_{true})|";
            col=:alphas2 => "",
        ) *
        visual(Lines; linewidth=4, color=:green);
        axis=merge(axis_common, (; titlevisible=false)),
    )

    fig
end
[Figure: DeepONet predictions vs. ground truth (top row) and absolute errors (bottom row) for α = 0.6, 0.8, 1.2]