DeepONets

DeepONets are another class of networks that learn mappings between two function spaces. A branch network encodes the input function, sampled at a fixed set of sensor points, and a trunk network encodes the location at which the output function is evaluated. The output is the inner product of the two latent codes. Because the query location enters only through the trunk, the input and output functions can be defined on different domains.

\[\begin{align*} u \xrightarrow{\text{branch}} & \; b \\ & \quad \searrow\\ &\quad \quad \mathcal{G}_{\theta} u(y) = \sum_k b_k t_k \\ & \quad \nearrow \\ y \; \; \xrightarrow{\text{trunk}} \; \; & t \end{align*}\]
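To make the diagram concrete, here is a minimal plain-Julia sketch of the forward pass, assuming two small sigmoid MLPs for the branch and trunk. `TinyDense`, `mlp` and `deeponet_forward` are hypothetical names used only for illustration, not NeuralOperators.jl API, and the library's `DeepONet` layer may differ in details. The point is the shape bookkeeping: `b` is `p × batch`, `t` is `p × n_points`, and the output entry `(j, i)` is $\sum_k b_{k,i} t_{k,j}$, which has the same `n_points × batch` layout as the targets used later.

```julia
using Random

# Hypothetical dense layer with sigmoid activation (illustration only).
sigmoid(x) = one(x) / (one(x) + exp(-x))

struct TinyDense
    W::Matrix{Float32}
    b::Vector{Float32}
end
TinyDense(rng, nin, nout) =
    TinyDense(randn(rng, Float32, nout, nin) ./ sqrt(Float32(nin)), zeros(Float32, nout))
(l::TinyDense)(x) = sigmoid.(l.W * x .+ l.b)

# Apply the layers left to right, like a Chain.
mlp(layers, x) = foldl((h, l) -> l(h), layers; init=x)

# Branch output b: p × batch, trunk output t: p × n_points.
# Result: n_points × batch with entry (j, i) = Σₖ b[k, i] * t[k, j].
deeponet_forward(branch, trunk, u, y) = transpose(mlp(trunk, y)) * mlp(branch, u)

rng = Random.Xoshiro(0)
m, p, batch, npts = 32, 8, 4, 5
branch = [TinyDense(rng, m, p), TinyDense(rng, p, p)]
trunk = [TinyDense(rng, 1, 4), TinyDense(rng, 4, p)]

u = rand(rng, Float32, m, batch)   # each column: one input function at m sensors
y = rand(rng, Float32, 1, npts)    # query locations shared by all functions
out = deeponet_forward(branch, trunk, u, y)
```

The inner product makes the output linear in the trunk features for a fixed input function, so the branch net effectively predicts coefficients of a learned basis `t`.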

Usage

Let's try to learn the anti-derivative operator for

\[u(x) = \sin(\alpha x)\]

That is, we want to learn

\[\mathcal{G} : u \mapsto v\]

such that

\[\frac{dv}{dx}(x) = u(x) \quad \forall \; x \in [0, 2\pi], \; \alpha \in [0.5, 1]\]
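As a quick sanity check on the target operator, the sketch below integrates $u(x) = \sin(\alpha x)$ numerically with the cumulative trapezoid rule for one illustrative value $\alpha = 0.75$ (an arbitrary choice inside $[0.5, 1]$) and compares it with the closed form. The training code uses $v(x) = -\cos(\alpha x)/\alpha$, which differs from the integral starting at $0$ by the constant $1/\alpha$; an anti-derivative is only determined up to such a constant.

```julia
# Illustrative frequency (assumption) and a fine grid on [0, 2π]
α = 0.75
xs = range(0, 2π; length=1001)
u = sin.(α .* xs)
h = step(xs)

# Cumulative trapezoid rule: V[j] ≈ ∫₀^{x_j} sin(αs) ds
V = zeros(length(xs))
for j in 2:length(xs)
    V[j] = V[j-1] + h * (u[j-1] + u[j]) / 2
end

# Closed form shifted to start at 0: -cos(αx)/α + 1/α = (1 - cos(αx))/α
v_exact = (1 .- cos.(α .* xs)) ./ α
max_err = maximum(abs.(V .- v_exact))   # trapezoid error is O(h²)
```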

Copy-pastable code

using NeuralOperators, Lux, Random, Optimisers, Reactant

using CairoMakie, AlgebraOfGraphics
set_aog_theme!()
const AoG = AlgebraOfGraphics

rng = Random.default_rng()
Random.seed!(rng, 1234)

xdev = reactant_device()

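eval_points = 1   # number of query locations y per input function
batch_size = 64   # number of input functions (one α each)
dim_y = 1         # dimension of the query location y
m = 32            # number of sensor points at which u is sampled

# Sensor locations and a random frequency α ∈ [0.5, 1) per input function
xrange = range(0, 2π; length=m) .|> Float32
α = 0.5f0 .+ 0.5f0 .* rand(rng, Float32, batch_size)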

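u_data = zeros(Float32, m, batch_size)                      # branch input: u at the sensors
y_data = rand(rng, Float32, 1, eval_points) .* Float32(2π)  # trunk input: query locations
v_data = zeros(Float32, eval_points, batch_size)            # targets: v at the query locations

for i in 1:batch_size
    u_data[:, i] .= sin.(α[i] .* xrange)
    # Anti-derivative of sin(αx), up to an additive constant
    v_data[:, i] .= -inv(α[i]) .* cos.(α[i] .* y_data[1, :])
end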

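# Branch net encodes the m sensor values of u; trunk net encodes the query location y.
# Both end in 8 features, which are combined by the inner product Σₖ bₖ tₖ.
deeponet = DeepONet(
    Chain(Dense(m => 8, σ), Dense(8 => 8, σ), Dense(8 => 8, σ)),  # branch
    Chain(Dense(1 => 4, σ), Dense(4 => 8, σ))                      # trunk
)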

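# Initialize parameters and states, then move them and the data to the Reactant device
ps, st = Lux.setup(rng, deeponet) |> xdev;

u_data = u_data |> xdev;
y_data = y_data |> xdev;
v_data = v_data |> xdev;
# A single batch of the form ((branch input, trunk input), target)
data = [((u_data, y_data), v_data)];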

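# Train with Adam on the mean-squared error between predicted and target v
function train!(model, ps, st, data; epochs=10)
    losses = Float32[]
    tstate = Training.TrainState(model, ps, st, Adam(0.001f0))
    for _ in 1:epochs, (x, y) in data
        # One optimization step; gradients are not returned since only the loss is logged
        (_, loss, _, tstate) = Training.single_train_step!(
            AutoEnzyme(), MSELoss(), (x, y), tstate; return_gradients=Val(false)
        )
        push!(losses, Float32(loss))
    end
    return losses
end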

losses = train!(deeponet, ps, st, data; epochs=1000)

# Plot the training loss on a log scale
draw(
    AoG.data((; losses, iteration=1:length(losses))) *
    mapping(:iteration => "Iteration", :losses => "Loss (log10 scale)") *
    visual(Lines);
    axis=(; yscale=log10),
    figure=(; title="Using DeepONet to learn the anti-derivative operator")
)
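A held-out batch built with a fresh seed, and therefore new values of α, is a natural way to check whether the trained operator generalizes. The sketch below packages the data generation from the training code into a function that returns arrays in the same layout (`u`: `m × n_funcs`, `y`: `1 × n_points`, `v`: `n_points × n_funcs`). `make_antiderivative_data` is a hypothetical helper, not part of NeuralOperators.jl, and the sizes in the example call are arbitrary.

```julia
using Random

# Hypothetical helper (not library API): same layout as the training data.
#   u: m × n_funcs        input functions sin(αx) sampled on [0, 2π]
#   y: 1 × n_points       query locations shared by all functions
#   v: n_points × n_funcs targets -cos(αy)/α
function make_antiderivative_data(rng, m, n_funcs, n_points)
    xrange = Float32.(range(0, 2π; length=m))
    α = 0.5f0 .+ 0.5f0 .* rand(rng, Float32, n_funcs)   # α ∈ [0.5, 1)
    y = rand(rng, Float32, 1, n_points) .* Float32(2π)
    u = sin.(xrange .* α')               # broadcast to m × n_funcs
    v = -cos.(vec(y) .* α') ./ α'        # broadcast to n_points × n_funcs
    return u, y, v
end

u_test, y_test, v_test = make_antiderivative_data(Random.Xoshiro(42), 32, 16, 10)
```

Since training uses `eval_points = 1`, each function is seen at a single shared location; querying a held-out set at several points, as here, is likely a more informative test.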