using DataDeps, MAT, MLUtils
using PythonCall, CondaPkg # For `gdown`
using Printf
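PythonCall resolves `gdown` through CondaPkg. If it is not already declared in this project's CondaPkg.toml, one hedged way to make it available is the call below; this assumes the `gdown` package is resolvable from your configured conda channels.

CondaPkg.add("gdown")  # assumption: `gdown` is available from the active conda channels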
const gdown = pyimport("gdown")
register(
    DataDep(
        "Burgers",
        """
        Burgers' equation dataset from
        [fourier_neural_operator](https://github.com/zongyi-li/fourier_neural_operator),
        mapping initial conditions to the solution at the final time of the \
        evolution in some function space.

        u(x,0) -> u(x, time_end):

        * `a`: initial conditions u(x,0)
        * `u`: solutions u(x,t_end)
        """,
        "https://drive.google.com/uc?id=16a8od4vidbiNR3WtaBPCSZ0T3moxjhYe",
        "9cbbe5070556c777b1ba3bacd49da5c36ea8ed138ba51b6ee76a24b971066ecd";
        fetch_method=(url, local_dir) -> begin
            pyconvert(String, gdown.download(url, joinpath(local_dir, "Burgers_R10.zip")))
        end,
        post_fetch_method=unpack,
    ),
)
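DataDeps resolves the path lazily: the first `datadep"Burgers"` access triggers the `gdown` download, verifies the SHA-256 hash registered above, and unpacks the archive. A small sanity check, not part of the original tutorial:

@assert "burgers_data_R10.mat" in readdir(datadep"Burgers")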
filepath = joinpath(datadep"Burgers", "burgers_data_R10.mat")

const N = 2048                         # number of trajectories to load
const Δsamples = 2^3                   # spatial subsampling factor
const grid_size = div(2^13, Δsamples)  # points per trajectory after subsampling
const T = Float32

file = matopen(filepath)
x_data = reshape(T.(collect(read(file, "a")[1:N, 1:Δsamples:end])), N, :)
y_data = reshape(T.(collect(read(file, "u")[1:N, 1:Δsamples:end])), N, :)
close(file)

# Rearrange to the (spatial points, channels, samples) layout expected by the model.
x_data = reshape(permutedims(x_data, (2, 1)), grid_size, 1, N);
y_data = reshape(permutedims(y_data, (2, 1)), grid_size, 1, N);
1024×1×2048 Array{Float32, 3} (element-wise printout of the reshaped data omitted)
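The tutorial trains on the full dataset, but if a held-out split is wanted, MLUtils (loaded above) can partition the observations along the last dimension. A sketch, not part of the original workflow:

(x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at=0.9)
size(x_train, 3), size(x_test, 3)  # views holding roughly 90% / 10% of the 2048 samples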
using Lux, NeuralOperators, Optimisers, Random, Reactant
const cdev = cpu_device()
const xdev = reactant_device(; force=true)
fno = FourierNeuralOperator((16,), 2, 1, 32; activation=gelu, stabilizer=tanh)
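A quick sanity check on the constructed operator (not in the original tutorial): Lux can report the number of trainable parameters before setup.

Lux.parameterlength(fno)  # total trainable parameter count of the FNO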
ps, st = Lux.setup(Random.default_rng(), fno) |> xdev;
dataloader = DataLoader((x_data, y_data); batchsize=128, shuffle=true) |> xdev;
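Each element yielded by the device-wrapped dataloader is an `(x, y)` batch already resident on the Reactant device; a hedged shape check:

x_batch, y_batch = first(dataloader)
size(x_batch), size(y_batch)  # expected: (1024, 1, 128) for both, since 2048 / 128 = 16 full batches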
function train_model!(model, ps, st, dataloader; epochs=1000)
    train_state = Training.TrainState(model, ps, st, Adam(0.0001f0))

    for epoch in 1:epochs
        loss = -Inf
        for data in dataloader
            # Single Reactant-compiled training step; gradients are not returned.
            (_, loss, _, train_state) = Training.single_train_step!(
                AutoEnzyme(), MAELoss(), data, train_state; return_gradients=Val(false)
            )
        end

        if epoch % 100 == 1 || epoch == epochs
            @printf("Epoch %d: loss = %.6e\n", epoch, loss)
        end
    end

    return train_state.parameters, train_state.states
end
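`MAELoss()` is the mean absolute (L1) error. A hedged equivalence check on dummy data, independent of the training loop above:

ŷ_dummy, y_dummy = rand(Float32, 8), rand(Float32, 8)
MAELoss()(ŷ_dummy, y_dummy) ≈ sum(abs, ŷ_dummy .- y_dummy) / length(y_dummy)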
ps_trained, st_trained = train_model!(fno, ps, st, dataloader)
Epoch 1: loss = 4.494606e-01
Epoch 101: loss = 2.216523e-02
Epoch 201: loss = 1.673523e-02
Epoch 301: loss = 1.283166e-02
Epoch 401: loss = 9.943001e-03
Epoch 501: loss = 6.147047e-03
Epoch 601: loss = 6.629670e-03
Epoch 701: loss = 6.406870e-03
Epoch 801: loss = 9.038369e-03
Epoch 901: loss = 6.250579e-03
Epoch 1000: loss = 5.358015e-03
using CairoMakie, AlgebraOfGraphics
const AoG = AlgebraOfGraphics
AoG.set_aog_theme!()
x_data_dev = x_data |> xdev;
y_data_dev = y_data |> xdev;
grid = range(0, 1; length=grid_size)
pred = first(
    Reactant.with_config(;
        convolution_precision=PrecisionConfig.HIGH,
        dot_general_precision=PrecisionConfig.HIGH,
    ) do
        @jit(fno(x_data_dev, ps_trained, st_trained))
    end
) |> cdev
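A hedged aggregate check, not part of the original tutorial: the mean absolute error of the predictions against the reference solutions, computed on the CPU copies.

using Statistics
mean(abs, pred .- y_data)  # dataset-wide MAE of the trained FNO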
data_sequence, sequence, repeated_grid, label = Float32[], Int[], Float32[], String[]
for i in 1:16
    append!(repeated_grid, repeat(grid, 2))
    append!(sequence, repeat([i], grid_size * 2))
    append!(label, repeat(["Ground Truth"], grid_size))
    append!(label, repeat(["Predictions"], grid_size))
    append!(data_sequence, vec(y_data[:, 1, i]))
    append!(data_sequence, vec(pred[:, 1, i]))
end
plot_data = (; data_sequence, sequence, repeated_grid, label)
draw(
    AoG.data(plot_data) *
    mapping(
        :repeated_grid => L"x",
        :data_sequence => L"u(x)";
        color=:label => "",
        layout=:sequence => nonnumeric,
        linestyle=:label => "",
    ) *
    visual(Lines; linewidth=4),
    scales(; Color=(; palette=:tab10), LineStyle=(; palette=[:solid, :dash]));
    figure=(;
        size=(1024, 1024),
        title="Using FNO to solve the Burgers equation",
        titlesize=25,
    ),
    axis=(; xlabelsize=25, ylabelsize=25),
    legend=(; label=L"u(x)", position=:bottom, labelsize=20),
)
