Burgers Equation using Fourier Neural Operator
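
This tutorial trains a Fourier Neural Operator (FNO) to learn the solution operator of the one-dimensional viscous Burgers' equation

    ∂u/∂t + u ∂u/∂x = ν ∂²u/∂x²

which maps the initial condition u(x, 0) to the solution u(x, t_end) at the final time.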

Data Loading

using DataDeps, MAT, MLUtils
using PythonCall, CondaPkg # For `gdown`
using Printf

const gdown = pyimport("gdown")

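# Google Drive links need special handling (confirmation pages, etc.), so the
# archive is fetched with `gdown` via PythonCall instead of DataDeps' default
# fetch method.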
register(
    DataDep(
        "Burgers",
        """
        Burgers' equation dataset from
        [fourier_neural_operator](https://github.com/zongyi-li/fourier_neural_operator)

        mapping initial conditions to the solutions at the final time of the \
        evolution in some function space.

        u(x, 0) -> u(x, t_end):

          * `a`: initial conditions u(x, 0)
          * `u`: solutions u(x, t_end)
        """,
        "https://drive.google.com/uc?id=16a8od4vidbiNR3WtaBPCSZ0T3moxjhYe",
        "9cbbe5070556c777b1ba3bacd49da5c36ea8ed138ba51b6ee76a24b971066ecd";
        fetch_method=(url, local_dir) -> begin
            pyconvert(String, gdown.download(url, joinpath(local_dir, "Burgers_R10.zip")))
        end,
        post_fetch_method=unpack,
    ),
)

filepath = joinpath(datadep"Burgers", "burgers_data_R10.mat")

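# Keep all 2048 samples and subsample the original 2^13 = 8192-point spatial
# grid by a factor of 2^3 = 8, leaving 1024 grid points.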
const N = 2048
const Δsamples = 2^3
const grid_size = div(2^13, Δsamples)
const T = Float32

file = matopen(filepath)
x_data = reshape(T.(collect(read(file, "a")[1:N, 1:Δsamples:end])), N, :)
y_data = reshape(T.(collect(read(file, "u")[1:N, 1:Δsamples:end])), N, :)
close(file)

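# Rearrange to (spatial points, channels, samples): a singleton channel
# dimension with the batch dimension last, the layout Lux convolution-style
# layers expect.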
x_data = reshape(permutedims(x_data, (2, 1)), grid_size, 1, N);
y_data = reshape(permutedims(y_data, (2, 1)), grid_size, 1, N);
1024×1×2048 Array{Float32, 3} (element dump truncated)
Model

using Lux, NeuralOperators, Optimisers, Random, Reactant

const cdev = cpu_device()
const xdev = reactant_device(; force=true)

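# Positional arguments: 16 Fourier modes per dimension, 2 input channels (the
# initial condition plus the grid coordinate appended by the positional
# embedding), 1 output channel, and a hidden width of 32.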
fno = FourierNeuralOperator(
    (16,), 2, 1, 32; activation=gelu, stabilizer=tanh
)
ps, st = Lux.setup(Random.default_rng(), fno) |> xdev;
(nested NamedTuples of parameters and states for the positional embedding, lifting, four FNO blocks, and projection layers; dump truncated)
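
The parameter and state containers mirror the model structure. As a quick sanity check (a small aside, not part of the original pipeline), the total number of trainable parameters can be queried with the standard Lux utility:

@printf("Parameter count: %d\n", Lux.parameterlength(fno))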

Training

dataloader = DataLoader((x_data, y_data); batchsize=128, shuffle=true) |> xdev;

function train_model!(model, ps, st, dataloader; epochs=1000)
    train_state = Training.TrainState(model, ps, st, Adam(0.0001f0))

    for epoch in 1:epochs
        loss = -Inf
        for data in dataloader
            (_, loss, _, train_state) = Training.single_train_step!(
                AutoEnzyme(), MAELoss(), data, train_state; return_gradients=Val(false)
            )
        end
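        # Note: `loss` now holds the last batch's loss from this epoch, not an
        # epoch average.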

        if epoch % 100 == 1 || epoch == epochs
            @printf("Epoch %d: loss = %.6e\n", epoch, loss)
        end
    end

    return train_state.parameters, train_state.states
end

ps_trained, st_trained = train_model!(fno, ps, st, dataloader)
Epoch 1: loss = 5.989288e-01
Epoch 101: loss = 1.178395e-02
Epoch 201: loss = 9.333905e-03
Epoch 301: loss = 1.293517e-02
Epoch 401: loss = 6.285025e-03
Epoch 501: loss = 6.247825e-03
Epoch 601: loss = 1.078401e-02
Epoch 701: loss = 5.986111e-03
Epoch 801: loss = 5.296726e-03
Epoch 901: loss = 4.756196e-03
Epoch 1000: loss = 6.025658e-03

Plotting

using CairoMakie, AlgebraOfGraphics
const AoG = AlgebraOfGraphics
AoG.set_aog_theme!()

x_data_dev = x_data |> xdev;
y_data_dev = y_data |> xdev;

grid = range(0, 1; length=grid_size)
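# Run inference with high-precision matmul/convolution settings, since lower
# default precision on some backends may visibly affect the predictions.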
pred = first(
    Reactant.with_config(;
        convolution_precision=PrecisionConfig.HIGH,
        dot_general_precision=PrecisionConfig.HIGH,
    ) do
        @jit(fno(x_data_dev, ps_trained, st_trained))
    end
) |> cdev
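
As a quick aggregate check of the fit (a minimal sketch, not part of the original tutorial; it assumes `pred` and `y_data` are both CPU arrays at this point):

mean_abs_err = sum(abs, pred .- y_data) / length(y_data)
@printf("Mean absolute error over all samples: %.6e\n", mean_abs_err)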

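# Build a long-format table for AlgebraOfGraphics: for each of the first 16
# samples, stack the ground truth and the prediction over the spatial grid,
# tagged with a sample index and a curve label.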
data_sequence, sequence, repeated_grid, label = Float32[], Int[], Float32[], String[]
for i in 1:16
    append!(repeated_grid, repeat(grid, 2))
    append!(sequence, repeat([i], grid_size * 2))
    append!(label, repeat(["Ground Truth"], grid_size))
    append!(label, repeat(["Predictions"], grid_size))
    append!(data_sequence, vec(y_data[:, 1, i]))
    append!(data_sequence, vec(pred[:, 1, i]))
end
plot_data = (; data_sequence, sequence, repeated_grid, label)

draw(
    AoG.data(plot_data) *
    mapping(
        :repeated_grid => L"x",
        :data_sequence => L"u(x)";
        color=:label => "",
        layout=:sequence => nonnumeric,
        linestyle=:label => "",
    ) *
    visual(Lines; linewidth=4),
    scales(; Color=(; palette=:tab10), LineStyle=(; palette=[:solid, :dash]));
    figure=(;
        size=(1024, 1024),
        title="Using FNO to solve the Burgers equation",
        titlesize=25,
    ),
    axis=(; xlabelsize=25, ylabelsize=25),
    legend=(; label=L"u(x)", position=:bottom, labelsize=20),
)
[Figure: "Using FNO to solve the Burgers equation", ground truth vs. predicted u(x) for the first 16 samples]