Fourier Neural Operators (FNOs)
FNOs are a subclass of Neural Operators that learn a kernel integral operator $\mathcal{K}_{\theta}$, parameterized by $\theta$, between function spaces:
\[(\mathcal{K}_{\theta}u)(x) = \int_D \kappa_{\theta}(a(x), a(y), x, y) \, u(y) \, dy \quad \forall x \in D\]
This kernel operator is applied inside each block, whose output $v^{(t)}(x)$ is passed to the next block as:
\[v^{(t+1)}(x) = \sigma((W^{(t)}v^{(t)} + \mathcal{K}^{(t)}v^{(t)})(x))\]
FNOs choose a translation-invariant kernel $\kappa_{\theta}(x, y) = \kappa_{\theta}(x - y)$, which turns the integral into a convolution that can be computed efficiently in the Fourier domain.
\[\begin{align*} (\mathcal{K}_{\theta}u)(x) &= \int_D \kappa_{\theta}(x - y) \, u(y) \, dy \quad \forall x \in D\\ &= \mathcal{F}^{-1}(\mathcal{F}(\kappa_{\theta}) \cdot \mathcal{F}(u))(x) \quad \forall x \in D \end{align*}\]
where $\mathcal{F}$ denotes the Fourier transform. Usually not all modes in the frequency domain are kept; the higher modes are typically truncated.
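As a rough illustration of a spectral convolution with mode truncation, here is a minimal sketch (assuming FFTW.jl; this is an illustration only, not the NeuralOperators.jl implementation):
using FFTW

# One spectral convolution on a single-channel signal: transform, keep only
# the lowest `n_modes` frequencies, weight them, and transform back.
function spectral_conv(v::AbstractVector{<:Real}, W::AbstractVector{<:Complex}, n_modes::Int)
    v_hat = rfft(v)                     # Fourier coefficients of the input
    out_hat = zero(v_hat)
    out_hat[1:n_modes] .= W[1:n_modes] .* v_hat[1:n_modes]  # truncate higher modes
    return irfft(out_hat, length(v))    # back to physical space
end

v = sin.(range(0f0, 2f0π; length=32))
W = randn(ComplexF32, 16)
spectral_conv(v, W, 16)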
Usage
Let's try to learn the anti-derivative operator for
\[u(x) = \sin(\alpha x)\]
That is, we want to learn the operator
\[\mathcal{G} : u \rightarrow v\]
such that
\[\frac{dv}{dx}(x) = u(x) \quad \forall \; x \in [0, 2\pi], \; \alpha \in [0.5, 1]\]
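For this family of inputs the anti-derivative has a closed form (taking the constant of integration to be zero), which is exactly what we will generate as training targets below:
\[v(x) = \int \sin(\alpha x) \, dx = -\frac{1}{\alpha}\cos(\alpha x)\]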
Copy-pastable code
using NeuralOperators, Lux, Random, Optimisers, Reactant
using CairoMakie, AlgebraOfGraphics
set_aog_theme!()
const AoG = AlgebraOfGraphics
rng = Random.default_rng()
Random.seed!(rng, 1234)
xdev = reactant_device()
batch_size = 128
m = 32
xrange = range(0, 2π; length=m) .|> Float32;
u_data = zeros(Float32, m, 1, batch_size);
α = 0.5f0 .+ 0.5f0 .* rand(Float32, batch_size);
v_data = zeros(Float32, m, 1, batch_size);
for i in 1:batch_size
    u_data[:, 1, i] .= sin.(α[i] .* xrange)
    v_data[:, 1, i] .= -inv(α[i]) .* cos.(α[i] .* xrange)
end
fno = FourierNeuralOperator(gelu; chs=(1, 64, 64, 128, 1), modes=(16,))
ps, st = Lux.setup(rng, fno) |> xdev;
u_data = u_data |> xdev;
v_data = v_data |> xdev;
data = [(u_data, v_data)];
function train!(model, ps, st, data; epochs=10)
    losses = []
    tstate = Training.TrainState(model, ps, st, Adam(0.003f0))
    for _ in 1:epochs, (x, y) in data
        (_, loss, _, tstate) = Training.single_train_step!(
            AutoEnzyme(), MSELoss(), (x, y), tstate; return_gradients=Val(false)
        )
        push!(losses, Float32(loss))
    end
    return losses
end
losses = train!(fno, ps, st, data; epochs=1000)
draw(
    AoG.data((; losses, iteration=1:length(losses))) *
    mapping(:iteration => "Iteration", :losses => "Loss (log10 scale)") *
    visual(Lines);
    axis=(; yscale=log10),
    figure=(; title="Using Fourier Neural Operator to learn the anti-derivative operator")
)

using NeuralOperators, Lux, Random, Optimisers, Reactant
We will use Reactant.jl to accelerate the training process.
xdev = reactant_device()
(::MLDataDevices.ReactantDevice{Missing, Missing, Missing}) (generic function with 1 method)
Constructing training data
First, we construct our training data.
rng = Random.default_rng()
Random.TaskLocalRNG()
batch_size is the number of observations.
batch_size = 128
128
m is the length of a single observation; you can also interpret this as the size of the grid we're evaluating our function on.
m = 32
32
We instantiate the domain that the function operates on as a range from 0 to 2π, whose length is the grid size.
xrange = range(0, 2π; length=m) .|> Float32;
Each value in the array α will be the multiplicative factor on the input to the sine function.
α = 0.5f0 .+ 0.5f0 .* rand(Float32, batch_size);
Now, we create our data arrays. We are storing all of the training data in a single array, in order to batch process them more efficiently.
u_data = zeros(Float32, m, 1, batch_size);
v_data = zeros(Float32, m, 1, batch_size);
and fill the data arrays with values. Here, u_data holds the inputs sin(α x) evaluated on the grid, and v_data holds the corresponding anti-derivatives -cos(α x)/α:
for i in 1:batch_size
    u_data[:, 1, i] .= sin.(α[i] .* xrange)
    v_data[:, 1, i] .= -inv(α[i]) .* cos.(α[i] .* xrange)
end
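The arrays follow the (grid point, channel, batch) layout that the model expects. As an optional sanity check (not part of the original example):
size(u_data)  # (32, 1, 128) — grid points × channels × batch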
Creating the model
Finally, we get to the model itself. We instantiate a FourierNeuralOperator and provide it several parameters.
The first argument is the "activation function" for each neuron.
The keyword arguments are:
- chs is a tuple giving the channel size of each layer.
- modes is a 1-tuple, where the number represents how many Fourier modes are preserved, and the size of the tuple represents the number of spatial dimensions.
fno = FourierNeuralOperator(
    gelu;                    # activation function
    chs=(1, 64, 64, 128, 1), # channel sizes
    modes=(16,),             # number of Fourier modes to retain
)
FourierNeuralOperator(
    model = Chain(
        layer_1 = Conv((1,), 1 => 64),      # 128 parameters
        layer_2 = Chain(
            layer_1 = OperatorKernel(
                layer = Parallel(
                    connection = Fix1(add_act, gelu_tanh),
                    layer_1 = Conv((1,), 64 => 64, use_bias=false),  # 4_096 parameters
                    layer_2 = Chain(
                        stabilizer = WrappedFunction(Base.Broadcast.BroadcastFunction(identity)),
                        conv_layer = OperatorConv(64 => 64, FourierTransform{ComplexF32}((16,), shift=false)),  # 65_536 parameters
                    ),
                ),
            ),
        ),
        layer_3 = Chain(
            layer_1 = Conv((1,), 64 => 128, gelu_tanh),  # 8_320 parameters
            layer_2 = Conv((1,), 128 => 1),              # 129 parameters
        ),
    ),
)         # Total: 78_209 parameters,
          #  plus 0 states.
Now, we set up the model. This function returns two things: a set of parameters and a set of states. Since the operator is "stateless", the states are empty and will remain so. The parameters are the weights of the neural network, and we will be modifying them in the training loop.
ps, st = Lux.setup(rng, fno) |> xdev;
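If you want to double-check the model size against the summary printed above, Lux provides parameterlength (an optional check, not part of the original example):
# Should match the total reported in the model summary above (78_209).
Lux.parameterlength(fno)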
We construct data as a vector of (input, output) tuples. Here the data is pre-batched into a single batch; if we had a lot of training data, we could instead load it dynamically or split it into multiple batches.
u_data = u_data |> xdev;
v_data = v_data |> xdev;
data = [(u_data, v_data)];
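As a sketch of what multiple batches could look like (assuming MLUtils.jl; not used in this tutorial), the CPU arrays could be split into minibatches and each minibatch moved to the device inside the training loop:
using MLUtils  # illustration only

# DataLoader batches along the last (observation) dimension.
loader = DataLoader((Array(u_data), Array(v_data)); batchsize=32, shuffle=true)
for (x, y) in loader
    x, y = xdev(x), xdev(y)  # move each minibatch to the device
    # ... training step goes here ...
end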
Training the model
Now, we create a function to train the model. An "epoch" is one full pass over the training data; running more epochs generally drives the training loss lower.
function train!(model, ps, st, data; epochs=10)
    # The `losses` array is used only for visualization,
    # you don't actually need it to train.
    losses = []
    # Initialize a training state and an optimizer (Adam, in this case).
    tstate = Training.TrainState(model, ps, st, Adam(0.003f0))
    # Loop over epochs, then loop over each batch of training data, and step into the
    # training:
    for _ in 1:epochs
        for (x, y) in data
            (_, loss, _, tstate) = Training.single_train_step!(
                AutoEnzyme(), MSELoss(), (x, y), tstate; return_gradients=Val(false)
            )
            push!(losses, Float32(loss))
        end
    end
    return losses, tstate.parameters, tstate.states
end
train! (generic function with 1 method)
Now we train our model!
losses, ps, st = @time train!(fno, ps, st, data; epochs=500)
(Any[1.2563593f0, 0.90409446f0, 0.75699973f0, 0.46846542f0, 0.27422562f0, 0.23793976f0, 0.30596852f0, 0.32843375f0, 0.24045478f0, 0.15608808f0 … 0.00018922474f0, 0.0001889902f0, 0.00018901388f0, 0.00018787259f0, 0.0001862567f0, 0.00018580753f0, 0.00018616868f0, 0.00018565395f0, 0.00018488613f0, 0.00018340204f0], <trained parameters (NamedTuple of Reactant arrays)>, <states (empty NamedTuples)>)
Applying the model
Let's try to actually apply this model using some input data.
input_data = u_data[:, 1, 1]
32-element Reactant.ConcretePJRTArray{Float32,1}:
0.0
0.11259386
0.22375578
0.33207196
0.436165
0.5347109
0.62645644
0.7102349
0.7849807
0.84974325
⋮
0.51964104
0.4201385
0.31529233
0.20643656
0.09495513
-0.017733686
-0.13019721
-0.2410049
-0.3487473
This is our input data. It's currently one-dimensional, but our neural network expects input in batched form, so we simply reshape it (a no-cost operation) into a 3D array with singleton dimensions.
reshaped_input = reshape(input_data, length(input_data), 1, 1)
32×1×1 reshape(::Reactant.ConcretePJRTArray{Float32,1}, 32, 1, 1) with eltype Float32:
[:, :, 1] =
0.0
0.11259386
0.22375578
0.33207196
0.436165
0.5347109
0.62645644
0.7102349
0.7849807
0.84974325
⋮
0.51964104
0.4201385
0.31529233
0.20643656
0.09495513
-0.017733686
-0.13019721
-0.2410049
-0.3487473
Now we can pass this to Lux.apply (@jit is used to run the function with Reactant.jl):
output_data, st = @jit Lux.apply(fno, reshaped_input, ps, st)
(Reactant.ConcretePJRTArray{Float32, 3, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[-1.7649521; -1.7795782; … ; 1.7657027; 1.6976722;;;]), (layer_1 = NamedTuple(), layer_2 = (layer_1 = (layer_1 = NamedTuple(), layer_2 = (stabilizer = NamedTuple(), conv_layer = NamedTuple())),), layer_3 = (layer_1 = NamedTuple(), layer_2 = NamedTuple())))
and plot it:
using CairoMakie, AlgebraOfGraphics
const AoG = AlgebraOfGraphics
AoG.set_aog_theme!()
f, a, p = lines(dropdims(Array(reshaped_input); dims=(2, 3)); label="u")
lines!(a, dropdims(Array(output_data); dims=(2, 3)); label="Predicted")
lines!(a, Array(v_data)[:, 1, 1]; label="Expected")
axislegend(a)
# Compute the pointwise error and plot that too,
# on a separate axis.
pointwise_error = Array(v_data)[:, 1, 1] .- dropdims(Array(output_data); dims=(2, 3))
a2, p2 = lines(f[2, 1], pointwise_error; axis=(; ylabel="Error"))
rowsize!(f.layout, 2, Aspect(1, 1 / 8))
linkxaxes!(a, a2)
f
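If a single summary number is useful, the largest pointwise deviation on this sample can be computed directly (an optional check, not part of the original example):
maximum(abs, pointwise_error)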
