using DataDeps, MAT, MLUtils
using PythonCall, CondaPkg # For `gdown`
using Printf
const gdown = pyimport("gdown")

# Register the `Burgers` DataDep. The archive is hosted on Google Drive, so it is fetched with
# Python's `gdown` (through PythonCall) instead of DataDeps' default downloader, checked
# against the checksum given below, and extracted with `unpack` once downloaded.
let
    # DataDeps calls this with the remote URL and the local target directory; the value of
    # `gdown.download` (per gdown, the written file's path) is converted to a Julia `String`.
    fetch_with_gdown(url, local_dir) =
        pyconvert(String, gdown.download(url, joinpath(local_dir, "Burgers_R10.zip")))

    register(DataDep(
        "Burgers",
        """
Burgers' equation dataset from
[fourier_neural_operator](https://github.com/zongyi-li/fourier_neural_operator)
mapping between initial conditions to the solutions at the last point of time \
evolution in some function space.
u(x,0) -> u(x, time_end):
* `a`: initial conditions u(x,0)
* `u`: solutions u(x,t_end)
""",
        "https://drive.google.com/uc?id=16a8od4vidbiNR3WtaBPCSZ0T3moxjhYe",
        "9cbbe5070556c777b1ba3bacd49da5c36ea8ed138ba51b6ee76a24b971066ecd";
        fetch_method=fetch_with_gdown,
        post_fetch_method=unpack,
    ))
end
filepath = joinpath(datadep"Burgers", "burgers_data_R10.mat")

# Number of samples kept, spatial subsampling stride, and the resulting grid size
# (`grid_size` is the raw resolution 2^13 divided by the stride).
const N = 2048
const Δsamples = 2^3
const grid_size = div(2^13, Δsamples)
const T = Float32

# Load the initial conditions (`"a"`) and the final-time solutions (`"u"`).
# The do-block form of `matopen` closes the file even if a `read` throws.
x_data, y_data = matopen(filepath) do file
    map(("a", "u")) do name
        # Keep the first `N` rows and every `Δsamples`-th column, converted to `T`.
        # Indexing already produces a fresh array, so no extra `collect` is needed.
        field = T.(read(file, name)[1:N, 1:Δsamples:end])
        # Rows are samples and columns are grid points; reorder to
        # (grid point, channel, sample) so samples lie along the last (batch) dimension.
        reshape(permutedims(field, (2, 1)), grid_size, 1, N)
    end
end
1024×1×2048 Array{Float32, 3}:
[:, :, 1] =
0.6834662
0.6850582
0.68663555
0.68819803
0.6897452
0.6912767
0.6927921
0.6942911
0.6957732
0.697238
0.69868517
0.70011413
0.70152456
0.7029159
0.7042878
0.7056398
0.70697135
0.70828193
0.7095711
0.7108384
0.71208316
0.713305
0.71450317
0.7156773
0.7168267
0.71795076
0.719049
0.7201206
0.721165
0.7221816
0.7231697
0.7241285
0.72505736
0.72595555
0.72682226
0.7276568
0.7284582
0.7292259
0.7299588
0.7306561
0.731317
0.7319404
0.73252547
0.73307127
0.73357666
0.7340406
0.73446214
0.7348401
⋮
0.5978408
0.59978867
0.6017345
0.60367817
0.6056194
0.6075579
0.6094936
0.6114262
0.6133555
0.6152813
0.61720335
0.61912143
0.62103534
0.62294483
0.62484974
0.62674975
0.62864465
0.63053423
0.6324182
0.6342964
0.6361686
0.63803446
0.6398938
0.6417463
0.6435918
0.64542997
0.6472606
0.64908344
0.6508982
0.6527046
0.6545024
0.6562913
0.65807104
0.6598413
0.66160184
0.6633524
0.66509265
0.66682225
0.66854095
0.6702484
0.6719444
0.6736285
0.6753004
0.6769599
0.67860657
0.68024004
0.68186
[:, :, 2] =
0.7511075
0.75031614
0.74948686
0.74862
0.7477157
0.7467743
0.7457962
0.7447817
0.74373096
0.7426446
0.74152285
0.74036616
0.7391749
0.7379496
0.73669064
0.73539853
0.7340737
0.7327167
0.731328
0.7299083
0.728458
0.72697765
0.725468
0.7239295
0.72236294
0.7207688
0.7191478
0.7175006
0.7158279
0.7141304
0.7124088
0.7106638
0.70889616
0.7071066
0.7052959
0.70346487
0.70161426
0.6997448
0.69785744
0.69595283
0.69403183
0.6920953
0.690144
0.6881789
0.6862007
0.68421024
0.6822084
0.6801961
⋮
0.74433285
0.7453439
0.7463201
0.7472613
0.74816704
0.74903715
0.74987143
0.7506696
0.7514314
0.7521566
0.75284505
0.7534964
0.7541106
0.75468725
0.75522625
0.7557274
0.7561905
0.7566154
0.7570019
0.7573498
0.757659
0.7579294
0.75816077
0.75835305
0.75850606
0.7586197
0.758694
0.7587287
0.7587238
0.7586793
0.75859505
0.75847113
0.7583074
0.7581039
0.75786066
0.7575776
0.7572549
0.7568925
0.75649047
0.756049
0.75556797
0.7550476
0.7544881
0.75388944
0.7532519
0.75257564
0.75186074
[:, :, 3] =
-0.41533458
-0.4158413
-0.41631535
-0.416757
-0.4171665
-0.41754422
-0.41789043
-0.41820538
-0.4184894
-0.41874272
-0.41896564
-0.4191584
-0.4193213
-0.41945457
-0.41955847
-0.41963327
-0.41967916
-0.41969648
-0.41968536
-0.4196461
-0.41957894
-0.41948405
-0.4193617
-0.41921213
-0.4190355
-0.4188321
-0.41860205
-0.41834566
-0.41806307
-0.4177545
-0.41742018
-0.4170603
-0.416675
-0.41626456
-0.41582915
-0.41536897
-0.41488418
-0.414375
-0.4138416
-0.41328418
-0.41270295
-0.41209805
-0.41146967
-0.41081804
-0.4101433
-0.4094456
-0.4087252
-0.40798223
⋮
-0.34818122
-0.35062543
-0.3530205
-0.35536665
-0.3576641
-0.3599132
-0.36211416
-0.36426732
-0.36637297
-0.36843148
-0.3704431
-0.37240824
-0.37432718
-0.37620035
-0.37802806
-0.3798107
-0.3815486
-0.3832422
-0.38489184
-0.3864979
-0.3880608
-0.3895809
-0.3910586
-0.39249426
-0.39388832
-0.39524114
-0.3965531
-0.3978246
-0.39905605
-0.40024778
-0.4014002
-0.40251374
-0.4035887
-0.40462554
-0.40562454
-0.40658614
-0.40751064
-0.40839848
-0.40924993
-0.4100654
-0.41084528
-0.4115898
-0.4122994
-0.41297436
-0.41361505
-0.41422176
-0.41479483
;;; …
[:, :, 2046] =
0.016046494
0.008860671
0.0016381989
-0.005616474
-0.012898825
-0.020204268
-0.02752816
-0.034865808
-0.04221249
-0.04956345
-0.056913916
-0.06425911
-0.07159427
-0.07891461
-0.08621542
-0.09349197
-0.1007396
-0.107953705
-0.11512973
-0.12226319
-0.1293497
-0.13638492
-0.14336464
-0.15028474
-0.15714122
-0.1639302
-0.17064787
-0.17729065
-0.18385498
-0.19033752
-0.19673505
-0.20304449
-0.2092629
-0.2153875
-0.22141568
-0.22734497
-0.23317306
-0.2388978
-0.24451719
-0.25002939
-0.25543272
-0.26072565
-0.2659068
-0.27097496
-0.27592903
-0.28076813
-0.2854914
-0.29009825
⋮
0.25936982
0.25655848
0.25364703
0.25063428
0.2475191
0.24430047
0.24097733
0.23754878
0.23401396
0.23037204
0.22662233
0.22276421
0.21879712
0.21472062
0.21053438
0.20623814
0.20183176
0.19731523
0.19268864
0.1879522
0.18310627
0.1781513
0.17308788
0.16791677
0.16263887
0.15725517
0.15176687
0.14617528
0.14048186
0.13468826
0.12879623
0.12280775
0.11672488
0.110549875
0.10428515
0.097933255
0.091496915
0.08497899
0.07838251
0.071710624
0.06496665
0.05815404
0.051276375
0.044337366
0.03734086
0.030290816
0.0231913
[:, :, 2047] =
0.36310026
0.36462003
0.36611378
0.36758128
0.36902225
0.3704365
0.37182376
0.37318385
0.37451655
0.37582165
0.3770989
0.37834814
0.37956914
0.38076174
0.3819257
0.3830609
0.38416716
0.38524425
0.38629207
0.38731045
0.38829923
0.3892583
0.39018747
0.39108667
0.39195576
0.39279464
0.3936032
0.39438137
0.39512908
0.39584625
0.3965328
0.3971887
0.3978139
0.39840838
0.39897212
0.39950514
0.40000743
0.40047896
0.40091982
0.40133002
0.40170965
0.40205875
0.40237737
0.40266564
0.40292367
0.40315154
0.40334943
0.40351745
⋮
0.26702473
0.2694785
0.27191815
0.27434352
0.27675432
0.27915034
0.28153133
0.28389704
0.28624728
0.28858176
0.29090023
0.2932025
0.29548824
0.29775727
0.30000934
0.30224413
0.30446145
0.30666104
0.30884263
0.31100595
0.31315073
0.31527677
0.31738377
0.31947145
0.32153958
0.32358786
0.3256161
0.32762393
0.32961118
0.33157754
0.33352274
0.33544654
0.33734864
0.3392288
0.34108678
0.34292224
0.344735
0.34652475
0.34829122
0.35003418
0.35175335
0.3534485
0.35511935
0.35676563
0.3583871
0.35998356
0.36155468
[:, :, 2048] =
1.0363716
1.0398366
1.0432957
1.0467485
1.0501951
1.0536352
1.0570688
1.0604957
1.0639158
1.067329
1.0707351
1.074134
1.0775255
1.0809096
1.084286
1.0876546
1.0910151
1.0943676
1.0977118
1.1010476
1.1043746
1.107693
1.1110022
1.1143023
1.1175929
1.1208739
1.1241452
1.1274062
1.1306571
1.1338974
1.137127
1.1403457
1.143553
1.1467489
1.1499329
1.1531048
1.1562642
1.159411
1.1625447
1.1656651
1.1687717
1.1718643
1.1749424
1.1780056
1.1810535
1.1840858
1.187102
1.1901015
⋮
0.86836725
0.8720194
0.8756691
0.8793164
0.8829611
0.88660324
0.8902427
0.89387953
0.89751357
0.9011448
0.9047731
0.90839857
0.912021
0.91564035
0.91925657
0.9228697
0.92647946
0.930086
0.9336891
0.9372888
0.94088495
0.94447756
0.9480665
0.95165163
0.95523304
0.95881057
0.96238416
0.9659537
0.9695192
0.97308046
0.9766375
0.98019016
0.9837384
0.9872822
0.99082136
0.9943559
0.99788564
1.0014105
1.0049305
1.0084454
1.0119551
1.0154597
1.0189589
1.0224527
1.025941
1.0294236
1.0329006
using Lux, NeuralOperators, Optimisers, Random, Reactant

# Device movers: `cdev` copies data back to the host CPU, `xdev` sends it to the Reactant
# device. NOTE(review): `force=true` presumably makes `reactant_device` error instead of
# silently falling back to the CPU when Reactant is unusable — confirm in MLDataDevices docs.
const cdev = cpu_device()
const xdev = reactant_device(; force=true)

# Fourier neural operator; `(16,)` is presumably the number of retained Fourier modes along
# the single spatial axis. See the NeuralOperators.jl docs for the positional `2, 1, 32`.
fno = FourierNeuralOperator((16,), 2, 1, 32; activation=gelu, stabilizer=tanh)

# Initialise parameters and states with the default RNG, then move both to the Reactant device.
ps, st = xdev(Lux.setup(Random.default_rng(), fno));
((positional_embedding = NamedTuple(), lifting = (layer_1 = (weight = Reactant.ConcretePJRTArray{Float32, 3, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[-0.45048642 0.8215421;;; -0.6469188 0.5172023;;; -0.50672746 -0.2251062;;; … ;;; 0.56619805 -0.24967276;;; 0.9645864 1.0885009;;; -0.4685063 1.1935192]), bias = Reactant.ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[0.159703, -0.3509462, -0.10126159, 0.39295563, 0.6632423, -0.49430946, -0.21302897, 0.3522459, -0.08633116, -0.12322027 … 0.55145764, -0.27602097, 0.05406876, -0.5174525, -0.5605669, -0.39052242, -0.57962865, 0.31826562, 0.29890722, -0.2686216])), layer_2 = (weight = Reactant.ConcretePJRTArray{Float32, 3, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[0.11259408 0.07066951 … 0.20654172 0.16918825;;; -0.15999952 0.115958184 … 0.08281645 -0.024991611;;; 0.16204435 0.074184105 … -0.21040425 -0.051126767;;; … ;;; 0.15478508 0.19775072 … -0.027014513 -0.041541196;;; 0.20725302 -0.07181448 … 0.19796564 -0.02747813;;; 0.105362296 -0.019454014 … -0.07874177 -0.15154172]), bias = Reactant.ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[-0.07033114, 0.07001881, 0.053533047, -0.085907385, -0.05988024, -0.104506835, 0.10887393, 0.09060454, 0.054366186, 0.029574394 … -0.11592382, 0.05252181, -0.026414692, 0.03713721, 0.07035078, 0.1113888, 0.024390548, -0.053051412, 0.006304696, 0.03983806]))), fno_blocks = (layer_1 = (layer_1 = (layer_1 = (layer_1 = (weight = Reactant.ConcretePJRTArray{Float32, 3, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[-0.15519148 0.03921681 … -0.12586261 -0.04448573;;; 0.12898631 -0.10303989 … -0.10187342 -0.27491853;;; 0.02485181 -0.1949206 … -0.1056301 0.15868787;;; … ;;; -0.2679895 0.18788193 … 0.2659741 0.14570174;;; 0.29236 0.304122 … 0.084096484 0.1277344;;; 
-0.28771406 0.18056881 … -0.09567626 0.14510007]),), layer_2 = (stabilizer = NamedTuple(), conv_layer = (weight = Reactant.ConcretePJRTArray{ComplexF32, 3, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(ComplexF32[1.6706108f-6 + 3.7705686f-6im 5.393684f-5 + 5.405541f-5im … -5.6601595f-5 + 1.6648817f-5im 4.5352645f-6 - 0.00011626167f0im; 2.862091f-5 - 4.164547f-5im 6.337899f-5 + 2.637622f-5im … 4.6356254f-5 - 4.597543f-5im 6.231068f-5 + 7.540344f-6im; … ; 5.0573588f-5 - 2.8344606f-5im -4.3950982f-5 + 7.612103f-5im … -1.1755968f-5 - 1.8446735f-6im 8.267337f-5 - 1.4497426f-5im; -3.1217132f-6 - 4.641679f-5im 9.024399f-6 + 5.1104173f-5im … -6.281353f-5 - 4.097029f-5im 4.883649f-5 - 1.0310134f-5im;;; 4.067016f-5 + 4.6861256f-5im 5.3112322f-5 + 2.088728f-5im … 3.1648378f-6 + 5.6358316f-5im -5.4386008f-5 - 3.1795455f-5im; 3.5243887f-5 + 3.7235637f-5im -3.2949734f-5 - 3.5800193f-5im … -5.9887505f-5 - 3.5702877f-5im -4.1863525f-5 - 2.3692563f-5im; … ; 6.912688f-5 - 2.210999f-5im 4.446139f-5 + 8.5168795f-6im … 6.2987885f-5 + 4.1943254f-5im 1.4032827f-5 - 8.796392f-5im; -7.19161f-5 + 1.0448595f-5im -5.253045f-5 - 2.7749134f-5im … 1.0235315f-5 + 6.2976986f-5im -9.4385505f-5 + 1.1969496f-5im;;; 8.3935905f-5 + 3.2144402f-5im 6.890841f-6 - 6.885703f-5im … -3.4504912f-5 - 1.483575f-6im -8.80132f-5 - 1.5442638f-5im; 4.2643573f-5 + 7.7698525f-5im -2.3793255f-5 - 9.143041f-5im … 3.699751f-5 - 1.5827543f-5im 2.767871f-5 - 5.1062605f-5im; … ; 5.1341958f-5 - 4.6251087f-5im 7.125769f-5 + 4.257042f-5im … -1.0617194f-5 - 8.4749394f-5im -9.078296f-5 - 2.4763845f-5im; 4.392465f-5 - 2.1993168f-5im -7.678762f-6 + 2.0259955f-5im … 4.7975336f-5 - 4.2107218f-5im -2.7629132f-5 + 2.4924178f-5im;;; … ;;; 4.8837966f-5 - 6.969894f-5im -4.2340835f-6 - 0.000104790684f0im … 4.47099f-5 + 3.5927515f-6im 4.099971f-5 - 2.5889727f-5im; 7.722521f-5 - 2.335808f-5im 2.1036707f-5 - 6.242976f-5im … -2.2125707f-5 + 4.0439438f-5im -7.802892f-5 - 3.6993944f-5im; … ; 2.9934658f-5 - 7.948678f-5im 
-4.263506f-5 - 7.528477f-5im … 2.556599f-5 - 4.1384956f-5im -7.529857f-5 - 1.76541f-5im; -7.943425f-6 - 0.00010962754f0im 8.18513f-5 - 3.109182f-5im … 6.5086424f-6 + 5.241695f-5im -8.856964f-5 - 3.145697f-5im;;; 0.000100476216f0 + 1.5000318f-5im 9.4583025f-5 + 2.4858236f-5im … -7.4586096f-5 - 4.0343926f-5im -1.2343749f-5 - 1.7625076f-5im; -6.634201f-5 - 5.4357646f-5im -1.4078752f-5 + 9.528179f-6im … -4.9549635f-6 + 3.409622f-5im -1.0734155f-5 + 5.438907f-5im; … ; -6.5526547f-6 - 0.0001052922f0im 4.068148f-5 - 5.2158663f-5im … 3.4408673f-5 - 8.155139f-5im -6.4049535f-5 - 2.0215586f-5im; 4.1494786f-5 + 2.4782115f-5im 4.4294313f-5 + 6.355632f-5im … 1.06440275f-5 - 3.8804545f-5im -4.2079046f-7 + 7.769853f-5im;;; -2.8860217f-5 - 3.465412f-5im 8.955141f-5 + 4.2313914f-6im … -4.0177067f-5 + 4.196704f-5im 3.345056f-6 - 0.000113107824f0im; -7.778642f-5 - 4.989706f-7im -3.0086718f-5 - 1.3005316f-5im … -8.830706f-6 + 4.730277f-5im 7.641059f-5 + 3.0003139f-5im; … ; -4.0510437f-5 - 7.608402f-5im -9.651915f-5 + 2.6000416f-6im … 5.6149947f-6 + 9.620583f-5im 7.7865305f-5 - 1.9123298f-5im; -2.7168651f-5 - 9.119702f-5im -9.407774f-5 + 2.0953885f-6im … -4.9984563f-5 - 6.707529f-5im 9.547421f-5 + 8.83879f-6im]),))), layer_2 = (layer_1 = (weight = Reactant.ConcretePJRTArray{Float32, 3, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[0.11306852 -0.09494071 … 0.00458673 0.2769081;;; -0.021850213 -0.052134648 … 0.1691181 -0.19033004;;; -0.18775013 0.28516403 … 0.24153912 0.2923376;;; … ;;; 0.08753696 -0.0086831525 … 0.12428314 0.15478724;;; -0.03052154 -0.15229201 … 0.28737718 -0.08045661;;; 0.27243385 0.29067954 … 0.18422611 -0.068228625]), bias = Reactant.ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[0.0031418367, 0.10188453, -0.12537697, 0.06675587, -0.013257144, -0.13640673, 0.15879893, 0.054921497, -0.1348415, 0.16932376, -0.12173626, 0.06348767, -0.03423934, -0.082781404, 0.13630433, 
-0.108549275])), layer_2 = (weight = Reactant.ConcretePJRTArray{Float32, 3, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[-0.22665338 -0.3400378 … 0.08706145 0.4081696;;; 0.28968003 -0.08522588 … 0.038577355 -0.041299798;;; 0.19068548 -0.28463152 … 0.33406883 0.04204296;;; … ;;; -0.3132675 0.24250282 … -0.32430583 0.1269693;;; 0.22419181 0.35890624 … 0.18207164 -0.17319195;;; 0.22345763 0.38342932 … 0.2773201 -0.09895574]), bias = Reactant.ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[0.15528032, -0.12397516, 0.08258507, 0.008148372, 0.09858009, -0.004576862, 0.13227305, -0.2141046, 0.21508905, -0.23343095 … 0.17101729, 0.044703335, -0.23205146, 0.22060686, -0.1241993, 0.09057075, -0.002218455, -0.16072816, -0.021555781, -0.18396485])))), layer_2 = (weight = Reactant.ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[-0.26722625 0.09803289 … 0.23878174 0.24928363]),)), layer_2 = (layer_1 = (layer_1 = (layer_1 = (weight = Reactant.ConcretePJRTArray{Float32, 3, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[0.08004587 -0.2757149 … 0.05531754 0.107396856;;; -0.15095755 0.16446148 … 0.15787834 0.10056594;;; -0.12946928 0.30604163 … 0.018036047 -0.078015506;;; … ;;; 0.2436746 -0.11337698 … -0.09482865 -0.0037036065;;; -0.21400957 0.17087924 … -0.024958976 -0.2102636;;; 0.024612296 -0.24273801 … -0.021588214 0.058255482]),), layer_2 = (stabilizer = NamedTuple(), conv_layer = (weight = Reactant.ConcretePJRTArray{ComplexF32, 3, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(ComplexF32[-6.944917f-5 + 1.9736479f-5im 1.100923f-5 - 7.263282f-5im … 4.892848f-5 - 5.77065f-5im 3.1812597f-6 + 5.6345794f-5im; -2.1512416f-5 - 7.467682f-5im 8.529729f-6 - 3.2645527f-5im … -2.5659378f-5 - 6.368989f-5im -5.6208548f-5 + 3.708458f-5im; … ; 4.167106f-5 + 6.52251f-5im 1.4755933f-6 + 
0.00011084194f0im … -3.392622f-5 - 4.461022f-5im 1.684027f-5 + 5.7103847f-5im; -2.5454276f-5 - 6.556705f-5im -2.1250962f-5 + 8.60897f-5im … 5.2762276f-5 - 3.165503f-5im 4.685324f-6 - 5.2981966f-5im;;; -9.145959f-5 + 3.0509764f-5im 2.9528528f-5 + 8.036611f-5im … 1.2142962f-5 - 4.471757f-5im -3.213906f-5 - 4.1843603f-5im; -1.7887585f-5 + 3.0196657f-5im 1.0977812f-5 - 4.0769773f-5im … -0.0001118833f0 - 6.5970526f-6im -6.9403985f-5 + 8.4880085f-6im; … ; -4.4718734f-5 - 1.3996221f-5im 3.1580166f-5 - 2.8086048f-5im … 3.6186539f-6 - 3.0274678f-6im -0.00011013642f0 + 9.86649f-7im; 5.3858108f-5 - 1.0358344f-6im -9.969797f-5 + 1.311252f-5im … 5.0294708f-5 - 1.9420258f-5im -1.6821134f-5 - 9.639828f-5im;;; -7.424115f-5 + 3.778863f-5im -7.055879f-5 - 2.5559959f-5im … 1.0271251f-6 - 4.866036f-5im 4.6190078f-5 - 4.0562518f-5im; 9.8097764f-5 - 8.1190956f-7im 2.1967346f-5 - 1.7651364f-5im … -8.8720844f-7 - 1.602708f-5im 4.256406f-6 + 6.1615356f-5im; … ; -2.622712f-5 - 7.103561f-6im 2.4150067f-6 - 5.8742953f-5im … 2.1835142f-5 + 2.6665657f-7im 9.4419665f-5 - 2.4973546f-5im; -1.6155216f-5 - 8.3839055f-5im -1.2156954f-5 - 6.2221014f-5im … 3.0401636f-5 - 3.0999967f-5im -5.4792974f-5 + 1.0731754f-5im;;; … ;;; -2.9812618f-5 - 2.2633736f-5im 0.0001046964f0 + 1.51716085f-5im … -1.4597063f-5 + 5.918934f-5im -5.440328f-6 - 6.4585547f-6im; -7.352892f-5 - 4.8792135f-6im 5.1517854f-6 - 5.8969395f-5im … 6.47573f-5 + 1.2595454f-5im -3.0900854f-5 + 1.9970386f-5im; … ; -1.5903555f-5 + 5.9727012f-5im 1.5637932f-5 + 5.170123f-5im … 1.3866971f-5 + 2.208515f-5im -8.954601f-6 + 4.614864f-5im; -8.066736f-5 + 4.034399f-5im 2.9015537f-5 - 4.5085944f-5im … -1.7889834f-5 - 9.325097f-5im -0.00011028525f0 - 7.174116f-6im;;; -7.428425f-5 - 4.179818f-5im -2.0538653f-5 - 2.777356f-5im … -0.00011137294f0 + 9.2450355f-6im -2.3159453f-5 - 9.7808625f-5im; 9.113123f-5 - 1.1671502f-5im -3.3895078f-5 + 1.8190782f-5im … -6.801301f-5 + 6.7740402f-6im -5.4041448f-6 - 0.00010489095f0im; … ; -7.238901f-5 - 4.187695f-5im 
-0.00010864493f0 - 5.194488f-6im … 7.316885f-6 + 2.087767f-5im -1.3714547f-5 - 5.94768f-5im; -8.087518f-6 + 5.403768f-5im -2.006667f-5 - 4.474941f-5im … -4.4717308f-5 - 5.766889f-5im 4.5228902f-5 - 3.321001f-5im;;; 8.740259f-6 - 6.448863f-5im 4.0860257f-5 + 5.829089f-5im … -1.7944229f-5 - 2.4127672f-5im -9.2310074f-8 + 1.8880943f-5im; -4.6131398f-5 + 7.192234f-5im -0.000108625354f0 + 7.0528986f-6im … -2.4623252f-5 + 8.8807225f-5im 9.830481f-6 - 2.5502544f-5im; … ; -7.030695f-5 - 3.4626515f-5im 0.00011120639f0 + 3.884197f-6im … -6.2605155f-5 + 5.1724048f-5im -1.0957519f-5 + 2.2385648f-5im; -8.973876f-5 - 7.745657f-6im -6.805112f-5 + 1.1435353f-5im … -6.841988f-5 + 3.8024664f-6im 9.946275f-5 - 1.3466561f-5im]),))), layer_2 = (layer_1 = (weight = Reactant.ConcretePJRTArray{Float32, 3, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[-0.15283357 0.23583993 … -0.16798583 0.005778974;;; -0.30492485 0.08329239 … -0.027095117 0.1116096;;; -0.20075685 -0.1971108 … -0.045551717 0.17903098;;; … ;;; -0.08855032 -0.17230253 … -0.19003019 0.017790072;;; -0.14727323 0.24199045 … -0.007359106 -0.26266348;;; -0.30109015 0.05572032 … 0.054540124 0.17849751]), bias = Reactant.ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[-0.17645743, 0.096673004, 0.11225674, 0.11067301, 0.17134006, 0.123067744, -0.10854124, 0.017893909, -0.13743314, -0.09237544, -0.017989434, 0.09168463, 0.015791897, -0.16021779, 0.11211679, -0.122612745])), layer_2 = (weight = Reactant.ConcretePJRTArray{Float32, 3, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[0.09433882 0.42526373 … 0.33682358 -0.31651193;;; -0.15883683 0.038021676 … 0.057269365 -0.35315865;;; 0.40136573 -0.103728294 … -0.28516972 -0.2486246;;; … ;;; -0.30100456 -0.34606373 … 0.20537008 -0.290084;;; 0.31450155 -0.37467518 … -0.13156639 0.22070329;;; 0.22435245 0.15107821 … 0.14122795 0.1747459]), bias = 
Reactant.ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[-0.0072006583, 0.14925969, 0.0390113, -0.10223663, 0.2381213, -0.1569022, 0.115956366, -0.13106886, -0.23001942, 0.059086323 … 0.06210494, 0.13298455, 0.1716491, -0.21588913, 0.03526342, 0.105377704, -0.18904746, 0.0028511584, -0.1472426, -0.24608746])))), layer_2 = (weight = Reactant.ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[-0.28504488 0.27378538 … 0.19929509 0.17291225]),)), layer_3 = (layer_1 = (layer_1 = (layer_1 = (weight = Reactant.ConcretePJRTArray{Float32, 3, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[0.20224383 0.30429968 … 0.11955541 0.23232342;;; -0.031039223 0.118525155 … -0.2963487 -0.27195686;;; -0.16330275 0.105522566 … -0.30596742 -0.031536978;;; … ;;; -0.16481528 0.23382318 … 0.26429492 0.29160932;;; 0.090146 0.012240174 … -0.15922567 -0.28099555;;; 0.15669882 -0.29040548 … -0.27592134 0.07088909]),), layer_2 = (stabilizer = NamedTuple(), conv_layer = (weight = Reactant.ConcretePJRTArray{ComplexF32, 3, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(ComplexF32[0.000110564826f0 - 5.454509f-6im 6.8007255f-5 - 4.482572f-5im … 9.805285f-5 + 2.0983738f-5im 6.4375236f-5 - 4.5411325f-5im; -5.99966f-6 + 7.274222f-5im -3.9460792f-5 - 8.252176f-5im … 5.5667988f-5 - 3.0002047f-5im -2.5727059f-6 + 2.7192873f-5im; … ; 4.553733f-5 + 5.1999086f-7im -8.190516f-6 + 0.000103613274f0im … -7.285188f-5 - 8.910996f-6im 5.6809688f-5 + 2.8335111f-5im; -1.7999693f-5 + 2.690676f-5im 6.407429f-5 + 8.353025f-6im … 2.0611325f-5 + 4.370806f-6im -1.033298f-5 + 8.7410735f-5im;;; -5.3302574f-6 - 0.00011619046f0im -5.1635798f-6 - 7.52321f-5im … -6.758214f-5 + 1.0702897f-5im 2.5685411f-5 - 8.625889f-5im; 9.2755356f-5 - 2.6642847f-5im -4.120909f-5 - 2.4502107f-5im … 7.367511f-6 - 8.628042f-5im 3.270516f-5 - 2.8237388f-5im; … ; 4.7264097f-5 - 
9.324198f-6im 2.8225666f-5 + 2.0643245f-5im … 6.1808096f-6 + 2.5297937f-5im 3.7629652f-6 - 9.44284f-6im; 1.737709f-5 + 2.645469f-5im 2.4473164f-5 - 5.6456483f-6im … -4.5040848f-5 - 1.3286866f-5im -5.5510995f-5 - 2.767775f-5im;;; 4.751976f-5 + 4.64248f-5im -7.800649f-5 + 4.2124957f-5im … 3.419249f-6 + 0.000110278284f0im -3.7994323f-7 - 0.00010614693f0im; 1.1264485f-5 + 9.3096874f-5im 5.5009223f-6 - 9.3283095f-5im … -7.6077777f-6 + 0.00010306079f0im -1.8614897f-5 + 9.4745425f-5im; … ; 5.43923f-5 - 5.1435127f-6im -1.4985562f-5 - 2.5298388f-5im … -5.598792f-5 + 3.6237907f-7im 1.7029379f-7 + 8.274646f-5im; -2.0616892f-5 + 3.2566764f-5im 2.8910275f-5 + 6.002275f-5im … 2.3169465f-5 - 3.3153796f-5im 7.020317f-6 + 5.9708975f-5im;;; … ;;; -8.149517f-5 - 7.5255157f-6im 7.801067f-5 + 3.5232435f-5im … -4.239713f-5 - 2.9810566f-5im 2.7750888f-5 + 8.308316f-5im; 9.205469f-6 + 1.6117032f-5im 1.3856625f-6 + 5.4089833f-5im … 7.382971f-5 - 7.198978f-6im 5.692315f-5 + 5.2508403f-6im; … ; -8.808769f-6 + 2.7394446f-5im -4.8174443f-5 + 6.780092f-5im … 0.000107043255f0 - 5.4330303f-7im 7.5891876f-6 + 5.503303f-5im; 4.0324638f-5 - 5.3116994f-5im -2.9108072f-5 + 3.8146398f-5im … -5.876609f-5 - 5.7509722f-5im 0.00011674235f0 - 4.0791056f-6im;;; -3.4542936f-5 - 5.704088f-5im 9.698466f-6 - 6.894921f-5im … -5.406275f-5 - 6.186083f-5im -8.450297f-6 + 7.340987f-5im; -4.4177803f-5 + 7.087823f-6im -2.472286f-5 - 1.66237f-5im … -2.5827037f-5 + 8.171716f-5im 6.8489564f-5 - 2.6545807f-5im; … ; 3.7620484f-6 - 3.772964f-5im -4.5426503f-5 - 7.17048f-5im … 5.178241f-6 + 5.938711f-5im -5.5727818f-5 + 3.0303454f-5im; -3.896108f-5 + 8.126517f-5im -8.541544f-5 - 2.42667f-6im … 1.5115911f-5 - 3.8225866f-5im -8.545947f-6 - 9.2782895f-5im;;; 4.9035312f-5 + 4.0607978f-5im -4.2508327f-7 - 8.378937f-5im … -5.6288744f-5 + 6.292472f-5im -2.7021488f-5 - 6.83504f-5im; -3.231327f-5 - 8.942472f-6im -3.252503f-5 + 4.4713233f-5im … -5.9218568f-5 + 6.903283f-7im -9.047391f-5 - 2.1144093f-5im; … ; 2.273421f-5 + 
2.1825253f-5im -5.030379f-6 - 7.719529f-6im … -4.8301226f-6 - 8.000287f-5im 3.0207448f-6 - 5.5184733f-5im; 4.6934467f-5 + 6.223717f-5im 9.089505f-6 + 5.284071f-5im … -1.8044091f-5 - 7.769596f-5im 7.343585f-5 - 3.1110176f-5im]),))), layer_2 = (layer_1 = (weight = Reactant.ConcretePJRTArray{Float32, 3, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[-0.03983848 0.19607632 … -0.008741114 0.29659045;;; 0.19948426 0.05835575 … 0.18548733 -0.025122898;;; 0.122004874 0.05456425 … 0.2369841 0.20411663;;; … ;;; -0.2062294 -0.29174805 … -0.12341856 0.09417242;;; 0.265215 0.24679036 … 0.24224252 -0.2040529;;; -0.25798193 0.30339515 … -0.10822388 -0.04247606]), bias = Reactant.ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[0.17644188, -0.015626429, 0.0051075025, -0.13385877, 0.1387758, -0.054379024, -0.13689122, 0.08329212, -0.120442666, 0.13259022, -0.1731755, -0.1393216, -0.020477047, -0.07313665, -0.122828096, -0.07236175])), layer_2 = (weight = Reactant.ConcretePJRTArray{Float32, 3, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[-0.116968036 -0.41951877 … 0.014586019 0.2112655;;; -0.2294085 -0.38947496 … 0.07447402 -0.4047335;;; -0.2931729 0.02447795 … -0.24075744 0.3872833;;; … ;;; -0.3733018 0.24538384 … 0.12744063 0.4234627;;; 0.23478365 0.00589124 … -0.1704055 0.31815884;;; 0.08882342 0.36547762 … 0.27109545 -0.2221812]), bias = Reactant.ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[-0.034669995, 0.09541288, 0.127505, -0.24526429, 0.23563218, -0.16610855, 0.17962897, -0.13025218, 0.08740431, 0.14996877 … 0.1947349, -0.16076627, -0.17180067, 0.049119025, -0.09261748, -0.029668152, 0.07338387, -0.015861958, -0.05979511, -0.027300924])))), layer_2 = (weight = Reactant.ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[-0.30119717 
-0.13142028 … -0.3045292 -0.3208088]),)), layer_4 = (layer_1 = (layer_1 = (layer_1 = (weight = Reactant.ConcretePJRTArray{Float32, 3, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[-0.24239708 -0.12292884 … -0.26763082 -0.15369464;;; 0.051489104 -0.18802202 … -0.19783135 0.3035818;;; -0.14299795 -0.18846948 … 0.10985664 -0.072139956;;; … ;;; 0.27763632 -0.1436563 … -0.2264389 -0.13123137;;; -0.18255487 -0.2677296 … -0.04587657 0.15047394;;; 0.18303156 -0.058275923 … 0.05013783 0.053699046]),), layer_2 = (stabilizer = NamedTuple(), conv_layer = (weight = Reactant.ConcretePJRTArray{ComplexF32, 3, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(ComplexF32[3.6989957f-5 + 7.3359544f-5im -6.2819374f-5 - 1.2109311f-5im … 6.4810876f-5 + 1.1637727f-5im -1.3403987f-6 - 8.33921f-5im; -2.3674904f-5 + 5.0757968f-5im -4.910922f-5 + 5.314623f-5im … -9.470064f-5 - 1.7860853f-5im -6.984829f-5 + 5.071309f-5im; … ; 5.5318073f-5 + 4.5203487f-5im 2.5687077f-5 + 2.704918f-5im … 4.754956f-5 + 3.0287527f-5im -3.398097f-5 + 3.095447f-6im; -3.435991f-5 - 1.5760408f-5im 3.2983604f-5 - 7.135105f-5im … 2.6372509f-5 + 5.422435f-5im -1.616434f-6 + 5.34008f-6im;;; -6.1624734f-5 - 3.732944f-5im 6.1810475f-5 - 3.1218115f-5im … -4.9335787f-5 + 4.8775437f-5im -1.1586824f-5 + 9.819348f-5im; 9.4498835f-5 + 9.1513575f-6im 5.4773263f-6 + 0.00010452657f0im … -7.656076f-5 + 4.135589f-6im 1.1565084f-5 - 3.6884543f-5im; … ; -9.850406f-5 - 6.4259584f-6im -1.1872624f-5 + 3.61919f-5im … 6.889389f-5 + 1.2036486f-5im -8.455465f-5 - 1.6414699f-5im; -2.387818f-5 - 4.640175f-5im 7.730818f-5 + 2.9390576f-6im … 3.372025f-5 + 5.0800707f-5im 2.708711f-5 - 7.482132f-5im;;; 7.3260315f-5 + 1.6199345f-5im 6.52274f-5 - 4.186813f-5im … 3.0154442f-5 + 8.638988f-5im 4.5611414f-5 - 5.257235f-5im; -6.227766f-5 - 4.8867805f-6im -5.459399f-5 + 4.9015922f-5im … 4.7976457f-5 + 6.379897f-5im 2.8019793f-5 + 6.299835f-5im; … ; 2.589092f-5 - 7.8796795f-5im 7.633735f-5 - 
2.6362904f-6im … -1.6249389f-5 - 2.5283814f-5im 8.861713f-5 + 2.1144064f-5im; 6.8642374f-5 + 3.893205f-6im 8.798143f-5 - 2.5031564f-5im … -2.5459434f-5 + 1.8995692f-5im 5.3492993f-5 - 6.0477447f-5im;;; … ;;; 6.50922f-5 - 3.326501f-5im -7.6603574f-5 - 1.6923244f-5im … 5.7626894f-5 - 6.0852326f-6im -3.5414945f-5 - 5.3792923f-5im; -5.3198863f-5 + 5.9266225f-5im -2.8894989f-5 + 8.2308245f-5im … -7.4408163f-6 - 4.9736773f-5im 3.2399796f-5 - 5.4662116f-5im; … ; -6.440678f-8 + 3.6589336f-6im -1.421768f-5 + 6.055217f-5im … -3.280826f-5 - 7.446019f-5im 5.879155f-5 - 1.2513199f-5im; 6.867684f-5 - 3.9667728f-5im -4.3705062f-5 + 7.728471f-5im … 8.244382f-5 - 3.0898707f-5im 4.9038623f-5 + 3.1231502f-5im;;; -1.458577f-5 + 9.9281424f-5im 6.4253836f-5 + 1.0780408f-5im … -2.3170694f-5 + 6.782026f-5im 8.061423f-5 + 1.6654543f-5im; -3.867699f-5 + 6.688157f-5im -6.730421f-6 + 8.791595f-5im … 4.267528f-5 + 5.4691904f-5im 2.8168943f-6 - 6.925724f-5im; … ; -2.6061913f-5 + 6.141426f-5im 4.1995263f-5 + 1.1248376f-5im … 6.0645943f-5 - 5.481925f-7im -4.5920227f-5 - 4.934687f-5im; 6.592287f-6 + 4.9551636f-5im -4.7090594f-5 + 1.9397645f-5im … 1.0037176f-5 - 6.684519f-5im 1.1523247f-5 + 3.614545f-5im;;; 3.6600177f-6 + 7.074899f-5im -2.4387176f-5 + 4.959978f-5im … -1.1466909f-6 - 8.30537f-5im 2.849409f-5 + 7.538601f-5im; -3.9580067f-5 + 3.676177f-5im -5.508896f-5 - 6.24335f-5im … 5.646004f-5 + 5.7345744f-5im 2.5321126f-5 + 4.316189f-5im; … ; 8.607062f-5 - 3.142534f-5im 1.9495797f-5 + 6.745763f-5im … -1.4650599f-5 + 0.00010065971f0im -5.7075667f-6 - 5.4397722f-5im; 3.6714016f-5 + 6.6609384f-5im -9.440959f-5 + 7.6085416f-7im … 6.2779363f-6 + 5.432056f-5im -5.9384f-5 - 3.5526748f-5im]),))), layer_2 = (layer_1 = (weight = Reactant.ConcretePJRTArray{Float32, 3, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[-0.17263411 0.043085508 … 0.03260647 0.19657122;;; 0.06888782 -0.10384348 … -0.28587458 0.20805295;;; 0.0020898944 0.18421161 … -0.14756325 -0.21304844;;; … ;;; 
-0.23145278 -0.23087166 … -0.25777978 0.26293257;;; 0.06696138 0.15810463 … -0.047206238 -0.19466682;;; 0.11592762 -0.045648187 … -0.09385318 0.022985846]), bias = Reactant.ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[0.106933236, -0.16910672, 0.14011627, -0.1304469, 0.12474412, 0.033216204, -0.08576646, -0.023677994, -0.047915425, -0.11245662, 0.09184422, 0.016340585, 0.022177124, -0.1632185, 0.048346188, 0.040631626])), layer_2 = (weight = Reactant.ConcretePJRTArray{Float32, 3, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[-0.22070241 0.30142298 … 0.35748792 0.41164687;;; -0.38684866 -0.38444337 … -0.3122665 -0.2486598;;; -0.044328347 0.10726601 … 0.20341718 0.39745396;;; … ;;; 0.32402688 0.0043083397 … 0.102944665 -0.21082422;;; 0.4282386 0.22701642 … -0.2626183 -0.31005237;;; -0.36259598 -0.17238592 … 0.07023377 -0.14860287]), bias = Reactant.ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[-0.20234892, -0.017347097, -0.040958226, -0.20707354, 0.073229045, -0.09003842, -0.016992807, 0.12758362, 0.029405802, 0.22385916 … -0.07426581, -0.13049698, -0.12669426, -0.10751408, 0.24238503, -0.10711083, 0.13574186, 0.12271884, 0.12006426, 0.13558432])))), layer_2 = (weight = Reactant.ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[0.100846544 0.14981501 … -0.34763628 -0.39044058]),))), projection = (layer_1 = (weight = Reactant.ConcretePJRTArray{Float32, 3, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[0.26094767 0.11135319 … -0.18765104 -0.10855315;;; 0.016792558 -0.00013950393 … 0.07354708 -0.17105886]), bias = Reactant.ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[0.08591374, 0.08166262])), layer_2 = (weight = Reactant.ConcretePJRTArray{Float32, 3, 1, 
Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[-1.1565499 -0.25906837;;;]), bias = Reactant.ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[0.5660475])))), (positional_embedding = NamedTuple(), lifting = (layer_1 = NamedTuple(), layer_2 = NamedTuple()), fno_blocks = (layer_1 = (layer_1 = (layer_1 = (layer_1 = NamedTuple(), layer_2 = (stabilizer = NamedTuple(), conv_layer = NamedTuple())), layer_2 = (layer_1 = NamedTuple(), layer_2 = NamedTuple())), layer_2 = NamedTuple()), layer_2 = (layer_1 = (layer_1 = (layer_1 = NamedTuple(), layer_2 = (stabilizer = NamedTuple(), conv_layer = NamedTuple())), layer_2 = (layer_1 = NamedTuple(), layer_2 = NamedTuple())), layer_2 = NamedTuple()), layer_3 = (layer_1 = (layer_1 = (layer_1 = NamedTuple(), layer_2 = (stabilizer = NamedTuple(), conv_layer = NamedTuple())), layer_2 = (layer_1 = NamedTuple(), layer_2 = NamedTuple())), layer_2 = NamedTuple()), layer_4 = (layer_1 = (layer_1 = (layer_1 = NamedTuple(), layer_2 = (stabilizer = NamedTuple(), conv_layer = NamedTuple())), layer_2 = (layer_1 = NamedTuple(), layer_2 = NamedTuple())), layer_2 = NamedTuple())), projection = (layer_1 = NamedTuple(), layer_2 = NamedTuple())))
dataloader = DataLoader((x_data, y_data); batchsize=128, shuffle=true) |> xdev;

"""
    train_model!(model, ps, st, dataloader; epochs=1000, learning_rate=0.0001f0,
                 log_every=100)

Train `model` on the minibatches of `dataloader` with the Adam optimiser (step size
`learning_rate`) and the mean absolute error loss, using `AutoEnzyme()` as the AD backend.

The loss of the last minibatch of an epoch is printed for epoch 1, then every `log_every`
epochs after it (101, 201, … with the default), and for the final epoch.

Returns the trained `(parameters, states)` held by the final `Training.TrainState`.

# Throws
- `ArgumentError` if `log_every < 1`.
"""
function train_model!(
    model, ps, st, dataloader; epochs=1000, learning_rate=0.0001f0, log_every=100
)
    log_every ≥ 1 || throw(ArgumentError("log_every must be ≥ 1, got $log_every"))
    train_state = Training.TrainState(model, ps, st, Adam(learning_rate))
    for epoch in 1:epochs
        # Only the last minibatch's loss is kept for logging; -Inf is reported if the
        # dataloader yields no batches.
        loss = -Inf
        for data in dataloader
            # Gradients are not used after the update, so don't ask for them back.
            (_, loss, _, train_state) = Training.single_train_step!(
                AutoEnzyme(), MAELoss(), data, train_state; return_gradients=Val(false)
            )
        end
        # `(epoch - 1) % log_every == 0` selects epochs 1, 1 + log_every, …; unlike
        # `epoch % log_every == 1` it also works for `log_every == 1`.
        if (epoch - 1) % log_every == 0 || epoch == epochs
            @printf("Epoch %d: loss = %.6e\n", epoch, loss)
        end
    end
    return train_state.parameters, train_state.states
end

ps_trained, st_trained = train_model!(fno, ps, st, dataloader)
Epoch 1: loss = 5.989288e-01
Epoch 101: loss = 1.178395e-02
Epoch 201: loss = 9.333905e-03
Epoch 301: loss = 1.293517e-02
Epoch 401: loss = 6.285025e-03
Epoch 501: loss = 6.247825e-03
Epoch 601: loss = 1.078401e-02
Epoch 701: loss = 5.986111e-03
Epoch 801: loss = 5.296726e-03
Epoch 901: loss = 4.756196e-03
Epoch 1000: loss = 6.025658e-03
using CairoMakie, AlgebraOfGraphics
const AoG = AlgebraOfGraphics
AoG.set_aog_theme!()

# Device copies of the full input/target arrays (these are the same samples used in training).
x_data_dev = xdev(x_data);
y_data_dev = xdev(y_data);
grid = range(0, 1; length=grid_size)

# Run the trained FNO once over every sample, with HIGH precision for convolutions and
# dot_general ops, keep only the model output (not the returned state), and copy it to the host.
pred = let
    high = PrecisionConfig.HIGH
    result = Reactant.with_config(;
        convolution_precision=high, dot_general_precision=high
    ) do
        @jit fno(x_data_dev, ps_trained, st_trained)
    end
    cdev(first(result))
end
# Long-format table for AlgebraOfGraphics. For each of the first 16 samples, the rows of the
# ground-truth curve come first, then those of the prediction, one row per grid point.
data_sequence, sequence, repeated_grid, label = Float32[], Int[], Float32[], String[]
for sample in 1:16
    for (series, field) in (("Ground Truth", y_data), ("Predictions", pred))
        append!(repeated_grid, grid)
        append!(sequence, fill(sample, grid_size))
        append!(label, fill(series, grid_size))
        append!(data_sequence, view(field, :, 1, sample))
    end
end
plot_data = (; data_sequence, sequence, repeated_grid, label)

# One facet per sample (`sequence`); ground truth and prediction are told apart by both
# colour and line style. The `let` evaluates to the value returned by `draw`.
let
    layers = AoG.data(plot_data) *
             mapping(
                 :repeated_grid => L"x",
                 :data_sequence => L"u(x)";
                 color=:label => "",
                 layout=:sequence => nonnumeric,
                 linestyle=:label => "",
             ) *
             visual(Lines; linewidth=4)
    palettes = scales(; Color=(; palette=:tab10), LineStyle=(; palette=[:solid, :dash]))
    draw(
        layers,
        palettes;
        figure=(;
            size=(1024, 1024),
            title="Using FNO to solve the Burgers equation",
            titlesize=25,
        ),
        axis=(; xlabelsize=25, ylabelsize=25),
        legend=(; label=L"u(x)", position=:bottom, labelsize=20),
    )
end
