Fourier Neural Operators (FNOs)

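FNOs are a subclass of Neural Operators that learn a kernel integral operator $\Kappa_{\theta}$, parameterized by $\theta$, that maps between function spaces:

\[(\Kappa_{\theta}u)(x) = \int_D \kappa_{\theta}(a(x), a(y), x, y) \, u(y) \, dy \quad \forall x \in D\]

where $a$ is the input function, $u$ is the function the operator acts on, and $D$ is the spatial domain. Each layer (block) of the network combines this integral operator with a pointwise linear transformation $W^{(t)}$ and a nonlinear activation function $\sigma$, and passes the result to the next layer as: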

\[v^{(t+1)}(x) = \sigma((W^{(t)}v^{(t)} + \Kappa^{(t)}v^{(t)})(x))\]
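The two equations above can be made concrete on a grid. The following is a minimal Julia sketch, not the NeuralOperators.jl implementation. It makes several assumptions:

- the integral is discretized with a Riemann sum on a uniform grid over $D = [0, 2\pi)$;
- there is a single channel, so $W^{(t)}$ is a scalar;
- the learned $\kappa_{\theta}$ is replaced by a hypothetical fixed kernel `kern_fn` that ignores $a$.

The helper names `kernel_integral` and `nop_layer` are illustrative.

```julia
# Minimal sketch of one neural-operator layer on a uniform periodic 1-D grid.
# Hypothetical assumptions: one channel (so W is a scalar), a fixed kernel
# instead of a learned κ_θ, and a Riemann sum for the integral over D = [0, 2π).

# (K v)(x_i) ≈ Σ_j κ(x_i, y_j) v(y_j) Δy
function kernel_integral(kern, v::AbstractVector, xs::AbstractVector, dy::Real)
    return [sum(kern(x, y) * vy for (y, vy) in zip(xs, v)) * dy for x in xs]
end

# v^(t+1) = σ.(W v^(t) + K v^(t)), evaluated pointwise on the grid
nop_layer(σ, W, kern, v, xs, dy) = σ.(W .* v .+ kernel_integral(kern, v, xs, dy))

relu_fn = z -> max(z, zero(z))              # nonlinear activation σ

n = 64
dy = 2π / n
xs = [(i - 1) * dy for i in 1:n]            # grid points in [0, 2π)
kern_fn = (x, y) -> exp(-(x - y)^2)         # hypothetical kernel κ(x, y)
v0 = sin.(xs)                               # v^(0) sampled on the grid
v1 = nop_layer(relu_fn, 0.5, kern_fn, v0, xs, dy)
```

This direct quadrature costs $O(n^2)$ operations for $n$ grid points. The convolution structure that FNOs impose is what allows an FFT to be used instead.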

FNOs drop the dependence on $a$ and choose a translation-invariant kernel, $\kappa_{\theta}(x, y) = \kappa_{\theta}(x - y)$. The integral then becomes a convolution, which, by the convolution theorem, can be computed efficiently in the Fourier domain:

\[\begin{align*} (\Kappa_{\theta}u)(x) &= \int_D \kappa_{\theta}(x - y) \, u(y) \, dy \quad \forall x \in D\\ &= \mathcal{F}^{-1}(\mathcal{F}(\kappa_{\theta}) \cdot \mathcal{F}(u))(x) \quad \forall x \in D \end{align*}\]

where $\mathcal{F}$ denotes the Fourier transform. In practice, FNOs learn $\mathcal{F}(\kappa_{\theta})$ directly as a set of complex weights, and only the lowest Fourier modes are kept; the higher modes are truncated.
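The convolution theorem and the mode truncation can be checked numerically. The sketch below is plain Julia with a naive $O(n^2)$ discrete Fourier transform; a real implementation would use an FFT library. It uses a hypothetical kernel sampled on the grid, and `R = dft(kern)` plays the role of $\mathcal{F}(\kappa_{\theta})$.

- With all modes kept, multiplying in Fourier space reproduces a direct circular convolution.
- With `kmax = 4`, only the frequencies $0, \dots, 3$ survive, so the high-frequency component of the input is filtered out.

The function names are illustrative and are not part of NeuralOperators.jl.

```julia
# Naive O(n^2) DFT and inverse DFT; a real implementation would use an FFT.
dft(u) = [sum(u[j] * cis(-2π * (k - 1) * (j - 1) / length(u)) for j in eachindex(u)) for k in eachindex(u)]
idft(U) = [sum(U[k] * cis(2π * (k - 1) * (j - 1) / length(U)) for k in eachindex(U)) for j in eachindex(U)] ./ length(U)

# Direct circular convolution: (κ ⋆ u)_i = Σ_j κ_{(i - j) mod n} u_j
function circular_conv(kern, u)
    n = length(u)
    return [sum(kern[mod(i - j, n) + 1] * u[j + 1] for j in 0:n-1) for i in 0:n-1]
end

# Multiply by R = F(κ) in Fourier space and keep only frequencies |f| < kmax.
function spectral_conv(R, u; kmax = length(u))
    n = length(u)
    U = dft(u) .* R                  # pointwise product in Fourier space
    for k in 0:n-1
        freq = min(k, n - k)         # |frequency| represented by DFT index k
        if freq >= kmax
            U[k + 1] = 0             # truncate the higher modes
        end
    end
    return real.(idft(U))            # real input and kernel: drop round-off imaginary parts
end

n = 32
xs = [2π * (i - 1) / n for i in 1:n]
signal = sin.(xs) .+ 0.3 .* sin.(10 .* xs)      # low- plus high-frequency content
kern = collect(exp.(-range(0, 4; length = n)))  # hypothetical kernel samples
R = dft(kern)                                   # stands in for F(κ_θ)
full = spectral_conv(R, signal)                 # all modes kept
direct = circular_conv(kern, signal)
truncated = spectral_conv(R, signal; kmax = 4)  # only frequencies 0..3 kept
```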

Usage

Let's try to learn the anti-derivative operator for

\[u(x) = \sin(\alpha x)\]

That is, we want to learn

\[\mathcal{G} : u \mapsto v\]

such that

\[\frac{dv}{dx} = u(x) \quad \forall \; x \in [0, 2\pi], \; \alpha \in [0.5, 1]\]
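As a quick sanity check of the target operator (a sketch that is not part of the original example), consider $v(x) = -\cos(\alpha x)/\alpha$, the function the training data below uses. It satisfies $dv/dx = u(x)$. A central finite difference confirms this numerically for one hypothetical value of $\alpha$:

```julia
# Numerically check that v(x) = -cos(αx)/α is an anti-derivative of u(x) = sin(αx).
α0 = 0.75                               # hypothetical value of α in [0.5, 1]
u_fn(x) = sin(α0 * x)
v_fn(x) = -cos(α0 * x) / α0

h = 1e-5                                # finite-difference step
xs_check = range(0, 2π; length = 100)
dvdx = [(v_fn(x + h) - v_fn(x - h)) / (2h) for x in xs_check]   # central difference
max_err = maximum(abs.(dvdx .- u_fn.(xs_check)))
```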

Copy-pastable code

The complete script for this example:
using NeuralOperators, Lux, Random, Optimisers, Reactant

using CairoMakie, AlgebraOfGraphics
set_aog_theme!()
const AoG = AlgebraOfGraphics

rng = Random.default_rng()
Random.seed!(rng, 1234)

xdev = reactant_device()

batch_size = 128
m = 32

xrange = range(0, 2π; length=m) .|> Float32;
u_data = zeros(Float32, m, 1, batch_size);
α = 0.5f0 .+ 0.5f0 .* rand(Float32, batch_size);
v_data = zeros(Float32, m, 1, batch_size);

for i in 1:batch_size
    u_data[:, 1, i] .= sin.(α[i] .* xrange)
    v_data[:, 1, i] .= -inv(α[i]) .* cos.(α[i] .* xrange)
end

fno = FourierNeuralOperator(gelu; chs=(1, 64, 64, 128, 1), modes=(16,))

ps, st = Lux.setup(rng, fno) |> xdev;
u_data = u_data |> xdev;
v_data = v_data |> xdev;
data = [(u_data, v_data)];

function train!(model, ps, st, data; epochs=10)
    losses = []
    tstate = Training.TrainState(model, ps, st, Adam(0.003f0))
    for _ in 1:epochs, (x, y) in data
        (_, loss, _, tstate) = Training.single_train_step!(
            AutoEnzyme(), MSELoss(), (x, y), tstate; return_gradients=Val(false)
        )
        push!(losses, Float32(loss))
    end
    return losses
end

losses = train!(fno, ps, st, data; epochs=1000)

draw(
    AoG.data((; losses, iteration=1:length(losses))) *
    mapping(:iteration => "Iteration", :losses => "Loss (log10 scale)") *
    visual(Lines);
    axis=(; yscale=log10),
    figure=(; title="Using Fourier Neural Operator to learn the anti-derivative operator")
)
using NeuralOperators, Lux, Random, Optimisers, Reactant

We will use Reactant.jl to accelerate the training process.

xdev = reactant_device()
(::MLDataDevices.ReactantDevice{Missing, Missing, Missing}) (generic function with 1 method)

Constructing training data

First, we construct our training data.

rng = Random.default_rng()
Random.TaskLocalRNG()

batch_size is the number of observations.

batch_size = 128
128

m is the length of a single observation; equivalently, it is the number of grid points on which we evaluate our function.

m = 32
32

We instantiate the domain that the function operates on as a range from 0 to 2π with m points (the grid size), converted to Float32.

xrange = range(0, 2π; length=m) .|> Float32;

Each entry of the array α is the frequency (the multiplicative factor on x inside the sine) for one observation, drawn uniformly between 0.5 and 1.

α = 0.5f0 .+ 0.5f0 .* rand(Float32, batch_size);

Now, we create our data arrays. We store all of the training data in a single array of size (m, 1, batch_size), that is (grid points, channels, batch), so that the whole batch can be processed at once.

u_data = zeros(Float32, m, 1, batch_size);
v_data = zeros(Float32, m, 1, batch_size);

Next, we fill the data arrays. Here, u_data holds the inputs sin(α x), and v_data holds the corresponding anti-derivatives -cos(α x)/α:

for i in 1:batch_size
    u_data[:, 1, i] .= sin.(α[i] .* xrange)
    v_data[:, 1, i] .= -inv(α[i]) .* cos.(α[i] .* xrange)
end
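The loop above can also be written without explicit indexing. The sketch below is an alternative, not what the example uses. It forms the `m × batch_size` matrix `xrange .* α'`, whose column `i` is `α[i] .* xrange`, and reshapes it into the `(grid points, channels, batch)` layout. So that it runs on its own without overwriting the tutorial's variables, it regenerates the data under separate names (`npts`, `nbatch`, `xgrid`, `αs`) with a hypothetical seed. It then checks that both constructions agree.

```julia
using Random

# Separate names so the tutorial's own arrays are not overwritten.
npts, nbatch = 32, 128
demo_rng = Random.MersenneTwister(0)                    # hypothetical seed
xgrid = Float32.(range(0, 2π; length = npts))
αs = 0.5f0 .+ 0.5f0 .* rand(demo_rng, Float32, nbatch)  # frequencies between 0.5 and 1

# Loop version, as in the tutorial
u_loop = zeros(Float32, npts, 1, nbatch)
v_loop = zeros(Float32, npts, 1, nbatch)
for i in 1:nbatch
    u_loop[:, 1, i] .= sin.(αs[i] .* xgrid)
    v_loop[:, 1, i] .= -inv(αs[i]) .* cos.(αs[i] .* xgrid)
end

# Broadcast version: column i of `xgrid .* αs'` equals αs[i] .* xgrid
u_bc = reshape(sin.(xgrid .* αs'), npts, 1, nbatch)
v_bc = reshape(-inv.(αs') .* cos.(xgrid .* αs'), npts, 1, nbatch)
```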

Creating the model

Finally, we get to the model itself. We instantiate a FourierNeuralOperator and pass it several arguments.

The first (positional) argument is the activation function used inside the network, here gelu.

The keyword arguments are:

  • chs is a tuple of channel sizes: the single input channel is lifted to 64 channels, passed through the Fourier layer (64 → 64), projected to 128 channels and finally to a single output channel.
  • modes is a tuple with one entry per spatial dimension (a 1-tuple here, since the problem is one-dimensional); each entry is the number of Fourier modes retained along that dimension.
fno = FourierNeuralOperator(
    gelu;                    # activation function
    chs=(1, 64, 64, 128, 1), # channel sizes
    modes=(16,),             # number of Fourier modes to retain
)
FourierNeuralOperator(
    model = Chain(
        layer_1 = Conv((1,), 1 => 64),  # 128 parameters
        layer_2 = Chain(
            layer_1 = OperatorKernel(
                layer = Parallel(
                    connection = Fix1(add_act, gelu_tanh),
                    layer_1 = Conv((1,), 64 => 64, use_bias=false),  # 4_096 parameters
                    layer_2 = Chain(
                        stabilizer = WrappedFunction(Base.Broadcast.BroadcastFunction(identity)),
                        conv_layer = OperatorConv(64 => 64, FourierTransform{ComplexF32}((16,), shift=false)),  # 65_536 parameters
                    ),
                ),
            ),
        ),
        layer_3 = Chain(
            layer_1 = Conv((1,), 64 => 128, gelu_tanh),  # 8_320 parameters
            layer_2 = Conv((1,), 128 => 1),  # 129 parameters
        ),
    ),
)         # Total: 78_209 parameters,
          #        plus 0 states.
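As a cross-check of the printed summary, the total of 78,209 parameters can be reproduced by hand. The helper `fno_param_count` below is hypothetical. It only assumes the structure visible in the output above:

- a 1×1 lifting convolution with bias;
- one Fourier layer for each consecutive pair in `chs[2:end-2]`, made of a bias-free 1×1 convolution plus one complex weight per input channel, output channel and retained mode;
- a two-layer 1×1 projection with biases.

```julia
# Hypothetical helper: parameter count for the layer structure printed above.
conv1x1_params(cin, cout; bias = true) = cin * cout + (bias ? cout : 0)

function fno_param_count(chs::Tuple, modes::Tuple)
    lifting = conv1x1_params(chs[1], chs[2])
    hidden = chs[2:end-2]                        # channels seen by the Fourier layers
    fourier = sum(conv1x1_params(hidden[i], hidden[i+1]; bias = false) +
                  hidden[i] * hidden[i+1] * prod(modes)   # one complex weight per mode
                  for i in 1:length(hidden)-1; init = 0)
    projection = conv1x1_params(chs[end-2], chs[end-1]) + conv1x1_params(chs[end-1], chs[end])
    return lifting + fourier + projection
end

fno_param_count((1, 64, 64, 128, 1), (16,))
```

For the model above this gives 128 + (4,096 + 65,536) + (8,320 + 129) = 78,209, matching the summary.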

Now, we set up the model. Lux.setup returns two things: the parameters and the states. Since the operator is "stateless", the states are empty and remain so. The parameters are the weights of the neural network, which the training loop modifies. Piping the result through xdev moves both to the Reactant device.

ps, st = Lux.setup(rng, fno) |> xdev;

We move the training arrays to the device and collect them in data, a vector of (input, output) tuples. Here there is a single pre-batched tuple; with more training data, we could instead load batches dynamically or split the data into multiple batches.

u_data = u_data |> xdev;
v_data = v_data |> xdev;
data = [(u_data, v_data)];
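If the training data did not fit in one batch, `data` could hold several `(input, output)` tuples. Below is a minimal sketch of splitting arrays in the `(grid points, channels, batch)` layout into mini-batches along the last dimension, using `Iterators.partition` from Julia's Base. The helper `make_batches`, the batch size of 32 and the random CPU stand-ins for `u_data`/`v_data` are hypothetical.

```julia
# Split (grid, channel, batch) arrays into mini-batches along the batch dimension.
function make_batches(x::AbstractArray{T,3}, y::AbstractArray{T,3}, bs::Integer) where {T}
    n = size(x, 3)
    return [(x[:, :, idx], y[:, :, idx]) for idx in Iterators.partition(1:n, bs)]
end

x_demo = rand(Float32, 32, 1, 128)   # CPU stand-ins for u_data / v_data
y_demo = rand(Float32, 32, 1, 128)
batches = make_batches(x_demo, y_demo, 32)
```

The `train!` function in this tutorial loops over every `(x, y)` tuple in `data`, so a vector like this could be passed to it directly. Each batch would first be moved to the device with `xdev`.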

Training the model

Now, we write a function to train the model. An "epoch" is one pass over all of the training data; training for more epochs usually lowers the training loss, up to a point.

function train!(model, ps, st, data; epochs=10)
    # The `losses` array is used only for visualization,
    # you don't actually need it to train.
    losses = []
    # Initialize a training state and an optimizer (Adam, in this case).
    tstate = Training.TrainState(model, ps, st, Adam(0.003f0))
    # Loop over epochs, then loop over each batch of training data, and step into the
    # training:
    for _ in 1:epochs
        for (x, y) in data
            (_, loss, _, tstate) = Training.single_train_step!(
                AutoEnzyme(), MSELoss(), (x, y), tstate; return_gradients=Val(false)
            )
            push!(losses, Float32(loss))
        end
    end
    return losses, tstate.parameters, tstate.states
end
train! (generic function with 1 method)
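To show what `MSELoss()` and `Adam(0.003f0)` amount to in each call to `single_train_step!`, here is a plain-Julia sketch on a hypothetical toy problem: fitting a single slope `w` to data from $y = 2x$. The sketch differs from the tutorial in a few ways:

- the gradient is derived by hand instead of with Enzyme;
- the learning rate is a larger, hypothetical `η = 0.05`, suited to the toy problem;
- β₁ = 0.9, β₂ = 0.999 and ϵ = 1e-8 are the commonly used Adam defaults.

This illustrates the standard Adam update; it is not the Lux or Optimisers.jl code.

```julia
# Mean squared error, averaged over all elements.
mse(ŷ, y) = sum(abs2, ŷ .- y) / length(y)

# Hypothetical toy data: y = 2x, model ŷ = w * x
xs_toy = collect(range(-1, 1; length = 50))
ys_toy = 2 .* xs_toy

function train_toy(; epochs = 500, η = 0.05, β1 = 0.9, β2 = 0.999, ϵ = 1e-8)
    w, m, v = 0.0, 0.0, 0.0        # parameter, first and second moment estimates
    losses = Float64[]
    for t in 1:epochs
        ŷ = w .* xs_toy
        push!(losses, mse(ŷ, ys_toy))
        g = 2 * sum((ŷ .- ys_toy) .* xs_toy) / length(ys_toy)   # d(mse)/dw by hand
        m = β1 * m + (1 - β1) * g
        v = β2 * v + (1 - β2) * g^2
        m̂ = m / (1 - β1^t)                                      # bias correction
        v̂ = v / (1 - β2^t)
        w -= η * m̂ / (sqrt(v̂) + ϵ)
    end
    return w, losses
end

w_fit, toy_losses = train_toy()
```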

Now we train our model!

losses, ps, st = @time train!(fno, ps, st, data; epochs=500)
(Any[1.2563593f0, 0.90409446f0, 0.75699973f0, 0.46846542f0, 0.27422562f0, 0.23793976f0, 0.30596852f0, 0.32843375f0, 0.24045478f0, 0.15608808f0  …  0.00018922474f0, 0.0001889902f0, 0.00018901388f0, 0.00018787259f0, 0.0001862567f0, 0.00018580753f0, 0.00018616868f0, 0.00018565395f0, 0.00018488613f0, 0.00018340204f0], (layer_1 = (weight = …, bias = …), layer_2 = (layer_1 = (layer_1 = (weight = …,), layer_2 = (stabilizer = NamedTuple(), conv_layer = (weight = …,))),), layer_3 = (layer_1 = (weight = …, bias = …), layer_2 = (weight = …, bias = …))), (layer_1 = NamedTuple(), layer_2 = (layer_1 = (layer_1 = NamedTuple(), layer_2 = (stabilizer = NamedTuple(), conv_layer = NamedTuple())),), layer_3 = (layer_1 = NamedTuple(), layer_2 = NamedTuple())))

Applying the model

Let's apply the trained model to some input data.

input_data = u_data[:, 1, 1]
32-element Reactant.ConcretePJRTArray{Float32,1}:
  0.0
  0.11259386
  0.22375578
  0.33207196
  0.436165
  0.5347109
  0.62645644
  0.7102349
  0.7849807
  0.84974325
  ⋮
  0.51964104
  0.4201385
  0.31529233
  0.20643656
  0.09495513
 -0.017733686
 -0.13019721
 -0.2410049
 -0.3487473

This is our input data. It is currently one-dimensional, but the network expects batched input of size (grid points, channels, batch), so we reshape it (a no-cost operation) into a 32×1×1 array with singleton channel and batch dimensions.

reshaped_input = reshape(input_data, length(input_data), 1, 1)
32×1×1 reshape(::Reactant.ConcretePJRTArray{Float32,1}, 32, 1, 1) with eltype Float32:
[:, :, 1] =
  0.0
  0.11259386
  0.22375578
  0.33207196
  0.436165
  0.5347109
  0.62645644
  0.7102349
  0.7849807
  0.84974325
  ⋮
  0.51964104
  0.4201385
  0.31529233
  0.20643656
  0.09495513
 -0.017733686
 -0.13019721
 -0.2410049
 -0.3487473
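The claim that reshaping costs nothing can be checked on an ordinary CPU `Array`. This is a sketch: the example itself reshapes a Reactant array, and the names `flat` and `batched` are illustrative. `reshape` returns an array that shares memory with its parent, so a write through the reshaped array is visible in the original and no data is copied.

```julia
flat = collect(Float32, 1:32)                # stand-in for input_data on the CPU
batched = reshape(flat, length(flat), 1, 1)  # 32×1×1 view of the same memory
batched[1, 1, 1] = -1f0                      # write through the reshaped array
```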

Now we can pass this to Lux.apply (@jit is used to run the function with Reactant.jl):

output_data, st = @jit Lux.apply(fno, reshaped_input, ps, st)
(Reactant.ConcretePJRTArray{Float32, 3, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[-1.7649521; -1.7795782; … ; 1.7657027; 1.6976722;;;]), (layer_1 = NamedTuple(), layer_2 = (layer_1 = (layer_1 = NamedTuple(), layer_2 = (stabilizer = NamedTuple(), conv_layer = NamedTuple())),), layer_3 = (layer_1 = NamedTuple(), layer_2 = NamedTuple())))

and plot it:

using CairoMakie, AlgebraOfGraphics
const AoG = AlgebraOfGraphics
AoG.set_aog_theme!()

f, a, p = lines(dropdims(Array(reshaped_input); dims=(2, 3)); label="u")
lines!(a, dropdims(Array(output_data); dims=(2, 3)); label="Predicted")
lines!(a, Array(v_data)[:, 1, 1]; label="Expected")
axislegend(a)
# Compute the absolute error and plot that too,
# on a separate axis.
absolute_error = abs.(Array(v_data)[:, 1, 1] .- dropdims(Array(output_data); dims=(2, 3)))
a2, p2 = lines(f[2, 1], absolute_error; axis=(; ylabel="Error"))
rowsize!(f.layout, 2, Aspect(1, 1 / 8))
linkxaxes!(a, a2)
f
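Besides plotting, the error can be summarized in two numbers: the maximum absolute error and the relative $L^2$ error. Below is a minimal sketch, assuming plain CPU vectors such as `dropdims(Array(output_data); dims=(2, 3))` for the prediction and `Array(v_data)[:, 1, 1]` for the target. Hypothetical stand-in data is used here so the block runs on its own.

```julia
# Summary error metrics between two vectors of equal length.
max_abs_error(pred, truth) = maximum(abs.(pred .- truth))
rel_l2_error(pred, truth) = sqrt(sum(abs2, pred .- truth)) / sqrt(sum(abs2, truth))

# Hypothetical stand-in data: the exact anti-derivative and a shifted copy.
xs_err = range(0, 2π; length = 32)
expected = -cos.(0.75 .* xs_err) ./ 0.75
predicted = expected .+ 0.01
```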