Using ahmc_bayesian_pinn_pde with the BayesianPINN Discretizer for the Kuramoto–Sivashinsky equation
Consider the Kuramoto–Sivashinsky equation:
\[∂_t u(x, t) + u(x, t) ∂_x u(x, t) + \alpha ∂^2_x u(x, t) + \beta ∂^3_x u(x, t) + \gamma ∂^4_x u(x, t) = 0 \, ,\]
where $\alpha = \gamma = 1$ and $\beta = 4$ (in the example below, $\alpha$ is treated as unknown and estimated from noisy data). The exact solution is:
\[u_e(x, t) = 11 + 15 \tanh \theta - 15 \tanh^2 \theta - 15 \tanh^3 \theta \, ,\]
where $\theta = t - x/2$, with the initial and boundary conditions:
\[\begin{align*} u( x, 0) &= u_e( x, 0) \, ,\\ u( 10, t) &= u_e( 10, t) \, ,\\ u(-10, t) &= u_e(-10, t) \, ,\\ ∂_x u( 10, t) &= ∂_x u_e( 10, t) \, ,\\ ∂_x u(-10, t) &= ∂_x u_e(-10, t) \, . \end{align*}\]
Here is an example of solving this problem with Bayesian Physics-Informed Neural Networks, using the BayesianPINN discretization with ahmc_bayesian_pinn_pde:
using NeuralPDE, Lux, ModelingToolkit, LinearAlgebra, AdvancedHMC
import ModelingToolkit: Interval, infimum, supremum, Distributions
using Plots, MonteCarloMeasurements
# Independent variables x, t and the PDE coefficient α, which is kept symbolic
# so that it can be estimated from data (see `param_estim = true` below).
@parameters x, t, α
# Dependent variable u(x, t).
@variables u(..)
# Differential operators: first derivative in t, first to fourth derivatives in x.
Dt = Differential(t)
Dx = Differential(x)
Dx2 = Differential(x)^2
Dx3 = Differential(x)^3
Dx4 = Differential(x)^4
# α is intentionally not assigned a number here: it is an unknown parameter of the
# PDE system (exact value 1 in the reference solution) inferred by the sampler.
β = 4
γ = 1
# Kuramoto–Sivashinsky equation: u_t + u u_x + α u_xx + β u_xxx + γ u_xxxx = 0.
eq = Dt(u(x, t)) + u(x, t) * Dx(u(x, t)) + α * Dx2(u(x, t)) + β * Dx3(u(x, t)) + γ * Dx4(u(x, t)) ~ 0
# Exact solution u_e(x, t) = 11 + 15 tanh θ - 15 tanh²θ - 15 tanh³θ, where the
# travelling-wave coordinate θ defaults to t - x/2 and may be overridden via `z`.
function u_analytic(x, t; z = -x / 2 + t)
    th = tanh(z)
    return 11 + 15 * th - 15 * th^2 - 15 * th^3
end
# Spatial derivative ∂u_e/∂x of the exact solution, used for the Neumann conditions
# at x = ±10; θ defaults to t - x/2 and may be overridden via `z`.
function du(x, t; z = -x / 2 + t)
    th = tanh(z)
    return 15 / 2 * (th + 1) * (3 * th - 1) * sech(z)^2
end
# Initial condition at t = 0, plus Dirichlet (value) and Neumann (x-derivative)
# conditions at x = -10 and x = 10, all taken from the exact solution.
bcs = [u(x, 0) ~ u_analytic(x, 0),
u(-10, t) ~ u_analytic(-10, t),
u(10, t) ~ u_analytic(10, t),
Dx(u(-10, t)) ~ du(-10, t),
Dx(u(10, t)) ~ du(10, t)]
# Space and time domains: x ∈ [-10, 10], t ∈ [0, 1]
domains = [x ∈ Interval(-10.0, 10.0),
t ∈ Interval(0.0, 1.0)]
# Grid spacing in x and t, used both for the synthetic dataset and for GridTraining.
dx = 0.4;
dt = 0.2;
# Exact solution evaluated at a single point (x, t), with θ = t - x/2.
# Gives the same value as `u_analytic(x, t)`.
function u_analytic_point(x, t)
    θ = -x / 2 + t
    τ = tanh(θ)
    return 11 + 15 * τ - 15 * τ^2 - 15 * τ^3
end
"""
    generate_dataset_matrix(domains, dx, dt)

Sample the exact solution `u_analytic_point` on a regular (x, t) grid and return a
matrix with one row `[u, x, t]` per grid point.

Assumes `domains[1]` is the x-domain and `domains[2]` the t-domain. Each is an
`x ∈ Interval(a, b)` pairing whose `.domain` has an `infimum`/`supremum`. The grid
steps by `dx` in x and `dt` in t.

Rows are ordered with x varying fastest: all x values for the first t, then all x
values for the next t, and so on.
"""
function generate_dataset_matrix(domains, dx, dt)
    # Take the sampling ranges from the domains instead of hard-coding them.
    x_domain = domains[1].domain
    t_domain = domains[2].domain
    x_values = infimum(x_domain):dx:supremum(x_domain)
    t_values = infimum(t_domain):dt:supremum(t_domain)
    # Column-major grid: x runs along dimension 1, so `vec` lists every x for each t in turn.
    grid = [(x, t) for x in x_values, t in t_values]
    xs = vec(first.(grid))
    ts = vec(last.(grid))
    # Build the (npoints × 3) matrix directly; no Vector{Any} and no large splat.
    return hcat(u_analytic_point.(xs, ts), xs, ts)
end
datasetpde = [generate_dataset_matrix(domains, dx, dt)]
# Add 5% multiplicative Gaussian noise to the u column (column 1) of a copy of the data:
# u_noisy = u + ε * 5 / 100 * u, with ε ~ N(0, 1) drawn independently for every row.
noisydataset = deepcopy(datasetpde)
let u_clean = noisydataset[1][:, 1]
    scaled_eps = randn(size(u_clean)) .* 5 / 100
    noisydataset[1][:, 1] = u_clean .+ scaled_eps .* u_clean
end
306-element Vector{Float64}:
-3.952969340549449
-4.027440303058448
-3.9772268433729177
-3.98435035477215
-4.261773310984447
-4.473685896714991
-4.138419476382187
-3.8928847478138517
-3.828137107251296
-3.6077618736423287
⋮
-4.293977601399509
-3.9328815253543454
-4.199393713791813
-3.9176156473989603
-4.093502227189913
-4.1826946772166185
-3.9692979276193583
-3.8401965870814023
-4.231556742906129
Plot the dataset; the added noise is multiplicative, at 5% of the value of u.
# Overlay the clean analytical data and its noisy copy: u (column 1) against x (column 2),
# all time slices plotted as one series each. Labels make the legend meaningful.
plot(datasetpde[1][:, 2], datasetpde[1][:, 1], title = "Dataset from Analytical Solution",
     label = "analytical", xlabel = "x", ylabel = "u")
plot!(noisydataset[1][:, 2], noisydataset[1][:, 1], label = "noisy (5%)")
# Neural network
# Fully connected surrogate for u: 2 inputs (x, t) -> 8 -> 8 -> 1 output.
chain = Chain(Dense(2, 8, tanh), Dense(8, 8, tanh), Dense(8, 1))
# Bayesian PINN discretization on a grid with spacing [dx, dt].
# `param_estim = true` presumably makes the PDE parameter α part of the sampled state.
# The noisy observations are passed as the first `dataset` entry.
# NOTE(review): the second `dataset` entry is `nothing` — confirm its role in the
# BayesianPINN docs.
discretization = NeuralPDE.BayesianPINN([chain],
GridTraining([dx, dt]), param_estim = true, dataset = [noisydataset, nothing])
# PDE system: equation `eq`, conditions `bcs` on `domains`, independent variables (x, t),
# dependent variable u(x, t), and parameter α with default value 0.5 (presumably the
# starting value for its estimation).
@named pde_system = PDESystem(eq,
bcs,
domains,
[x, t],
[u(x, t)],
[α],
defaults = Dict([α => 0.5]))
# Sample the posterior with NUTS. Argument meanings below are read from names and values;
# NOTE(review): confirm them against the `ahmc_bayesian_pinn_pde` docstring.
#   draw_samples  number of MCMC samples (100)
#   Kernel        NUTS; 0.8 is presumably the target acceptance ratio
#   bcstd         one standard deviation per condition in `bcs` (5 conditions, 5 entries)
#   phystd        standard deviation for the PDE-residual term
#   l2std         standard deviation for the data term (matches the ~5% noise level)
#   param         prior on α
#   priorsNNw     (mean, std) prior for the network weights
#   saveats       x and t spacing of the grid stored with the solution
sol1 = ahmc_bayesian_pinn_pde(pde_system,
discretization;
draw_samples = 100, Kernel = AdvancedHMC.NUTS(0.8),
bcstd = [0.2, 0.2, 0.2, 0.2, 0.2],
phystd = [1.0], l2std = [0.05], param = [Distributions.LogNormal(0.5, 2)],
priorsNNw = (0.0, 10.0),
saveats = [1 / 100.0, 1 / 100.0], progress = true)
BPINNsolution{NeuralPDE.BPINNstats{MCMCChains.Chains{Float64, AxisArrays.AxisArray{Float64, 3, Base.ReshapedArray{Float64, 3, LinearAlgebra.Adjoint{Float64, Matrix{Float64}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, Tuple{AxisArrays.Axis{:iter, StepRange{Int64, Int64}}, AxisArrays.Axis{:var, Vector{Symbol}}, AxisArrays.Axis{:chain, UnitRange{Int64}}}}, Missing, @NamedTuple{parameters::Vector{Symbol}}, @NamedTuple{}}, Vector{Vector{Float64}}, Vector{NamedTuple}}, Vector{Vector{MonteCarloMeasurements.Particles{Float64, 33}}}, Vector{ComponentArrays.ComponentVector{MonteCarloMeasurements.Particles{Float64, 34}, Vector{MonteCarloMeasurements.Particles{Float64, 34}}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:24, Axis(weight = ViewAxis(1:16, ShapedAxis((8, 2))), bias = ViewAxis(17:24, Shaped1DAxis((8,))))), layer_2 = ViewAxis(25:96, Axis(weight = ViewAxis(1:64, ShapedAxis((8, 8))), bias = ViewAxis(65:72, Shaped1DAxis((8,))))), layer_3 = ViewAxis(97:105, Axis(weight = ViewAxis(1:8, ShapedAxis((1, 8))), bias = ViewAxis(9:9, Shaped1DAxis((1,))))))}}}}, Vector{MonteCarloMeasurements.Particles{Float64, 34}}, Vector{Matrix{Float64}}}(NeuralPDE.BPINNstats{MCMCChains.Chains{Float64, AxisArrays.AxisArray{Float64, 3, Base.ReshapedArray{Float64, 3, LinearAlgebra.Adjoint{Float64, Matrix{Float64}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, Tuple{AxisArrays.Axis{:iter, StepRange{Int64, Int64}}, AxisArrays.Axis{:var, Vector{Symbol}}, AxisArrays.Axis{:chain, UnitRange{Int64}}}}, Missing, @NamedTuple{parameters::Vector{Symbol}}, @NamedTuple{}}, Vector{Vector{Float64}}, Vector{NamedTuple}}(MCMC chain (100×106×1 reshape(adjoint(::Matrix{Float64}), 100, 106, 1) with eltype Float64), [[0.1879015162287654, 1.5391239481908574, -1.5230363682719197, 0.6533298366477994, -0.9180775115831763, 1.7613906471220102, 1.237318438554521, 0.048705894221346635, -0.5392239868591001, -1.4795151647721214 … 0.28141386775104205, 
0.4031614498496271, -0.3257727479525368, 0.2269822143171018, -0.044998845898146345, -0.45779743100269166, 0.14937836478207817, -0.2809783421736808, -0.037442961848344294, 0.5011137057570769], [0.1879015162287654, 1.5391239481908574, -1.5230363682719197, 0.6533298366477994, -0.9180775115831763, 1.7613906471220102, 1.237318438554521, 0.048705894221346635, -0.5392239868591001, -1.4795151647721214 … 0.28141386775104205, 0.4031614498496271, -0.3257727479525368, 0.2269822143171018, -0.044998845898146345, -0.45779743100269166, 0.14937836478207817, -0.2809783421736808, -0.037442961848344294, 0.5011137057570769], [0.1879015162287654, 1.5391239481908574, -1.5230363682719197, 0.6533298366477994, -0.9180775115831763, 1.7613906471220102, 1.237318438554521, 0.048705894221346635, -0.5392239868591001, -1.4795151647721214 … 0.28141386775104205, 0.4031614498496271, -0.3257727479525368, 0.2269822143171018, -0.044998845898146345, -0.45779743100269166, 0.14937836478207817, -0.2809783421736808, -0.037442961848344294, 0.5011137057570769], [0.22100112268168387, 1.5362311146164156, -1.5197693346169288, 0.6539569986113387, -0.9167865791450668, 1.7604681589704039, 1.2376957457452085, 0.041109685697428885, -0.5385778893179917, -1.4791251582774305 … 0.29250668877028785, 0.4045220121725932, -0.3265649092062184, 0.2195759551202105, -0.04723737809634741, -0.46622807397814486, 0.138940095389079, -0.27687905116164657, -0.04819302612277583, 0.501771531719308], [0.2660753082768163, 1.5313913026825199, -1.5144710859684678, 0.6545074979962257, -0.9134738933948876, 1.7591305023569193, 1.2367183725396307, 0.029579910651780457, -0.5372487565644427, -1.4798026425965856 … 0.3113104642635632, 0.4087269892208584, -0.3307761860082548, 0.20566315336995092, -0.05327538158451864, -0.4815462171636051, 0.12132172711845811, -0.2707968609719371, -0.06686049941606231, 0.5016713592818758], [0.3302055462678435, 1.5219425976254604, -1.5011755892923586, 0.6587256036303228, -0.9016509388110477, 1.7574140303449293, 
1.232731800099624, 0.011138367321693187, -0.5398885867788636, -1.4811236772456031 … 0.35121162160906083, 0.42910273256062786, -0.351007140478781, 0.17640099041551058, -0.06851464960088509, -0.5178643138551928, 0.08406131858643377, -0.2606690290655661, -0.11035819679127823, 0.5014433643251667], [0.3505808501685124, 1.5163998025548744, -1.4909625493302685, 0.6611763398236898, -0.8908605233493393, 1.7553912588135479, 1.2268197876686235, 0.0029634409772288667, -0.5459937330709294, -1.4797722500180943 … 0.3758281243226544, 0.448359799837744, -0.37081220123451286, 0.1609803016152501, -0.07900105086828867, -0.5404981655283363, 0.06378790334131867, -0.2571828877014439, -0.139232934855709, 0.5022863183396976], [0.39266766892748983, 1.4993171416771118, -1.456163862636083, 0.6673206329178349, -0.8578551454080757, 1.7488600037931787, 1.2064095005498774, -0.00687764946226244, -0.5717976082559553, -1.479479311772998 … 0.4474857678604334, 0.5181182339021225, -0.44546102473469, 0.1219796976495412, -0.11293686013619317, -0.6124618881875011, 0.0014901869393174824, -0.2532561457060627, -0.22980652750201758, 0.5041799307974785], [0.39266766892748983, 1.4993171416771118, -1.456163862636083, 0.6673206329178349, -0.8578551454080757, 1.7488600037931787, 1.2064095005498774, -0.00687764946226244, -0.5717976082559553, -1.479479311772998 … 0.4474857678604334, 0.5181182339021225, -0.44546102473469, 0.1219796976495412, -0.11293686013619317, -0.6124618881875011, 0.0014901869393174824, -0.2532561457060627, -0.22980652750201758, 0.5041799307974785], [0.39264769596447696, 1.4990416749586206, -1.4556672812468172, 0.6674096106681757, -0.857471029415313, 1.7489651888339761, 1.2062051290039009, -0.0071588901728342905, -0.572029168393249, -1.4796733687380839 … 0.4480826949296566, 0.518754214564979, -0.44622093988631784, 0.12166587659315158, -0.11340163996870514, -0.6129932327615244, 0.0012704834957404025, -0.25357240155373056, -0.23054810166019568, 0.5041829725561686] … [0.7764421769968209, 
0.9490028845828785, -0.46933173826909086, 0.6414495928263213, -0.4561822505576936, 0.5265975587822237, 0.8766192487800508, -0.43912010639383664, -1.740576454756957, -1.0800794378013947 … 4.150191172405242, 3.911196021286904, -3.5287689409159873, -1.005620617523413, -3.0144261651127153, -5.2366387924298285, -2.4012340617364965, -4.103992986271481, -0.010191790788061138, 1.0308633825933873], [0.766802497161261, 0.9437188613500921, -0.4687584137142801, 0.6278870612650483, -0.4754396538356588, 0.5323203837721983, 0.8677986245657408, -0.4480290777299194, -1.753666612855493, -1.1225166731612757 … 4.068896662880813, 3.885181802526949, -3.5638820628511496, -1.1044281380088685, -2.872883881360817, -5.216605694635497, -2.633324807349498, -3.882018052121437, -0.08296094120756357, 1.0403969513688585], [0.7278873719499817, 0.9127986366972448, -0.47325252561589054, 0.6184372468163393, -0.4979367663424314, 0.47084760868143893, 0.851179033087677, -0.38399902348331644, -1.7269843861693266, -1.1348167278498815 … 3.9000147362838806, 4.068263537265599, -3.7079387105997257, -1.0937375655604715, -2.5244106334284053, -4.794250209647722, -2.8152556049482755, -4.2023316838915905, -0.2152770455692272, 1.0374957136004928], [0.7536359689405294, 0.880968539552613, -0.4631118321739568, 0.6177388972288828, -0.48845157440143205, 0.4848590797419498, 0.8519320642653586, -0.4042773841304164, -1.731816171480279, -1.1635286415054256 … 4.05838683297359, 4.003076529770795, -3.6631665262458077, -1.0328048165056, -2.4357864965995306, -5.0062785919168356, -2.9191803185128595, -4.2205086688936335, -0.19429265935112608, 1.0371168318159942], [0.7182534094195647, 0.8474550999847111, -0.46439196148610273, 0.5899935509372665, -0.4375174741383703, 0.4519911360885513, 0.7909735811262313, -0.388591260801855, -1.7160975572138175, -1.0612377954240755 … 4.093238847657713, 4.471853658656969, -3.7370534569787197, -0.7588021319055174, -2.7587702990606173, -5.293079702487971, -3.2098717667384817, -4.290662088053534, 
0.20492796244389336, 1.0441849776731835], [0.7194845917893533, 0.8412199028429964, -0.4640954767338214, 0.6035865424326062, -0.43890887202568424, 0.45521204912544355, 0.7742891263383453, -0.38986674090924367, -1.6760449607993924, -1.1236106680814035 … 4.04933410230552, 4.530995910442574, -3.6927178414513264, -0.7758425763687952, -2.742349429619605, -5.26604604953369, -3.2792300788820063, -4.306142969792554, 0.18049336781400457, 1.0466038018643202], [0.7943900991348435, 0.8200926146266595, -0.46832251704623695, 0.6742893753567543, -0.3732384385959507, 0.4588781775242564, 0.8336074610008897, -0.3649805585950613, -1.8093679604830775, -1.1451581429918394 … 4.021399255400682, 4.605799845410516, -3.6359458741354156, -0.7595093999622758, -2.7821677127725275, -5.192430764343883, -3.3370564537015226, -4.239889675986914, 0.15680945167225124, 1.0576943513661794], [0.7538888851713682, 0.8342291281957449, -0.4554818766825455, 0.6176509290020314, -0.3651268253211717, 0.48218821167842585, 0.8732783701194998, -0.3348511957957219, -1.6902616687502048, -1.183321736425363 … 3.8241895457527364, 4.670294488835907, -3.4022891911270214, -0.7594992899094166, -2.908063197807243, -5.433326113265443, -3.3234165989699744, -4.13569002211516, 0.08756925166913081, 1.0406846343145002], [0.7636286034041977, 0.8417502266920324, -0.4622791489903231, 0.6320661726991047, -0.3388002797079429, 0.4711931941592111, 0.8751016642524425, -0.31018383470741717, -1.6664229773079526, -1.1342993024388976 … 3.865743330093639, 4.584401661596238, -3.3729206915889502, -0.7358938654905643, -3.0298828479292834, -5.392484006688876, -3.2301158363508504, -4.185541769235919, 0.06282829653203346, 1.034740615944214], [0.7351568384509087, 0.7818673143480283, -0.4670819254792557, 0.5645907866934298, -0.30731342683952917, 0.3637541094155943, 0.862370065876331, -0.2540966526038482, -1.7264320906495991, -1.1775751657496418 … 3.785273250355198, 4.520666684958718, -2.9576165129541745, -0.6892846421490094, -2.70262266642754, 
-5.384303701752045, -3.0620040833559994, -4.461171935768222, -0.155118890009104, 1.0420034146669541]], NamedTuple[(n_steps = 3, is_accept = true, acceptance_rate = 1.0, log_density = -1.7646844910160399e6, hamiltonian_energy = 1.9460811456451332e6, hamiltonian_energy_error = -65222.86156117753, max_hamiltonian_energy_error = -65222.86156117753, tree_depth = 2, numerical_error = false, step_size = 0.00078125, nom_step_size = 0.00078125, is_adapt = true), (n_steps = 1, is_accept = true, acceptance_rate = 0.0, log_density = -1.7646844910160399e6, hamiltonian_energy = 1.7647393654108879e6, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 4.6122931835972725e23, tree_depth = 0, numerical_error = true, step_size = 0.011238679762325607, nom_step_size = 0.011238679762325607, is_adapt = true), (n_steps = 1, is_accept = true, acceptance_rate = 0.0, log_density = -1.7646844910160399e6, hamiltonian_energy = 1.7647233036830667e6, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 2.394904956626255e6, tree_depth = 0, numerical_error = true, step_size = 0.0018993494877672993, nom_step_size = 0.0018993494877672993, is_adapt = true), (n_steps = 3, is_accept = true, acceptance_rate = 1.0, log_density = -1.739003369477052e6, hamiltonian_energy = 1.7642248716506623e6, hamiltonian_energy_error = -524.859684261959, max_hamiltonian_energy_error = -524.859684261959, tree_depth = 2, numerical_error = false, step_size = 0.00018733703557689892, nom_step_size = 0.00018733703557689892, is_adapt = true), (n_steps = 3, is_accept = true, acceptance_rate = 1.0, log_density = -1.7069133713486404e6, hamiltonian_energy = 1.7382763019664192e6, hamiltonian_energy_error = -772.7463671674486, max_hamiltonian_energy_error = -772.7463671674486, tree_depth = 2, numerical_error = false, step_size = 0.00025338469449058976, nom_step_size = 0.00025338469449058976, is_adapt = true), (n_steps = 3, is_accept = true, acceptance_rate = 1.0, log_density = -1.6574696274102961e6, 
hamiltonian_energy = 1.705794068996543e6, hamiltonian_energy_error = -1179.440884117037, max_hamiltonian_energy_error = -1179.440884117037, tree_depth = 2, numerical_error = false, step_size = 0.0003962543583196465, nom_step_size = 0.0003962543583196465, is_adapt = true), (n_steps = 1, is_accept = true, acceptance_rate = 1.0, log_density = -1.6300116138676293e6, hamiltonian_energy = 1.6571134817106451e6, hamiltonian_energy_error = -410.2126577729359, max_hamiltonian_energy_error = -410.2126577729359, tree_depth = 1, numerical_error = false, step_size = 0.0006745127317215788, nom_step_size = 0.0006745127317215788, is_adapt = true), (n_steps = 2, is_accept = true, acceptance_rate = 0.5, log_density = -1.5460395128536725e6, hamiltonian_energy = 1.629498217975466e6, hamiltonian_energy_error = -561.1801491740625, max_hamiltonian_energy_error = 2033.311756824376, tree_depth = 1, numerical_error = true, step_size = 0.00120699748002278, nom_step_size = 0.00120699748002278, is_adapt = true), (n_steps = 3, is_accept = true, acceptance_rate = 2.6453147358037707e-15, log_density = -1.5460395128536725e6, hamiltonian_energy = 1.546089142983631e6, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 159.2543622259982, tree_depth = 2, numerical_error = false, step_size = 0.00046176364501528296, nom_step_size = 0.00046176364501528296, is_adapt = true), (n_steps = 7, is_accept = true, acceptance_rate = 0.9679564098101096, log_density = -1.545314803501922e6, hamiltonian_energy = 1.54608877577322e6, hamiltonian_energy_error = 0.0037657031789422035, max_hamiltonian_energy_error = 0.07396192592568696, tree_depth = 3, numerical_error = false, step_size = 3.64178573693074e-5, nom_step_size = 3.64178573693074e-5, is_adapt = true) … (n_steps = 1023, is_accept = true, acceptance_rate = 0.8994924053203376, log_density = -4044.6804483253286, hamiltonian_energy = 4092.844653187668, hamiltonian_energy_error = 0.271856519926132, max_hamiltonian_energy_error = 0.2811952331539942, 
tree_depth = 10, numerical_error = false, step_size = 0.0002655487926798032, nom_step_size = 0.0002655487926798032, is_adapt = false), (n_steps = 1023, is_accept = true, acceptance_rate = 0.5449985134698874, log_density = -4049.502270842527, hamiltonian_energy = 4104.343950779468, hamiltonian_energy_error = 1.123540809527185, max_hamiltonian_energy_error = 1.8602536979615252, tree_depth = 10, numerical_error = false, step_size = 0.0002655487926798032, nom_step_size = 0.0002655487926798032, is_adapt = false), (n_steps = 1023, is_accept = true, acceptance_rate = 0.6923927199458352, log_density = -4039.760069112025, hamiltonian_energy = 4110.809197425149, hamiltonian_energy_error = -0.9936150624771471, max_hamiltonian_energy_error = 1.8952188860957904, tree_depth = 10, numerical_error = false, step_size = 0.0002655487926798032, nom_step_size = 0.0002655487926798032, is_adapt = false), (n_steps = 1023, is_accept = true, acceptance_rate = 0.866948206008506, log_density = -4051.1046386504677, hamiltonian_energy = 4097.0028633449265, hamiltonian_energy_error = 0.11899383030140598, max_hamiltonian_energy_error = 0.501012108094983, tree_depth = 10, numerical_error = false, step_size = 0.0002655487926798032, nom_step_size = 0.0002655487926798032, is_adapt = false), (n_steps = 1023, is_accept = true, acceptance_rate = 0.6928672851463703, log_density = -4029.9364448059923, hamiltonian_energy = 4099.755093589031, hamiltonian_energy_error = 0.18971839865480433, max_hamiltonian_energy_error = 1.4112061922096473, tree_depth = 10, numerical_error = false, step_size = 0.0002655487926798032, nom_step_size = 0.0002655487926798032, is_adapt = false), (n_steps = 1023, is_accept = true, acceptance_rate = 0.44955843950247626, log_density = -4035.469384028512, hamiltonian_energy = 4084.272353917529, hamiltonian_energy_error = 0.04210268172300857, max_hamiltonian_energy_error = 3.5956495844334313, tree_depth = 10, numerical_error = false, step_size = 0.0002655487926798032, nom_step_size = 
0.0002655487926798032, is_adapt = false), (n_steps = 1023, is_accept = true, acceptance_rate = 0.5029640720511499, log_density = -4034.1077480840295, hamiltonian_energy = 4089.032157689568, hamiltonian_energy_error = -0.9262744536395076, max_hamiltonian_energy_error = 3.6218168853447423, tree_depth = 10, numerical_error = false, step_size = 0.0002655487926798032, nom_step_size = 0.0002655487926798032, is_adapt = false), (n_steps = 1023, is_accept = true, acceptance_rate = 0.8970484580942492, log_density = -4023.7317325631616, hamiltonian_energy = 4098.564338821594, hamiltonian_energy_error = 0.00983104055376316, max_hamiltonian_energy_error = 0.4392998711200562, tree_depth = 10, numerical_error = false, step_size = 0.0002655487926798032, nom_step_size = 0.0002655487926798032, is_adapt = false), (n_steps = 1023, is_accept = true, acceptance_rate = 0.4467600089662101, log_density = -4026.4473601608597, hamiltonian_energy = 4071.7360933479426, hamiltonian_energy_error = 1.2762813097288017, max_hamiltonian_energy_error = 2.251764342271599, tree_depth = 10, numerical_error = false, step_size = 0.0002655487926798032, nom_step_size = 0.0002655487926798032, is_adapt = false), (n_steps = 1023, is_accept = true, acceptance_rate = 0.783729225191403, log_density = -4000.16829046935, hamiltonian_energy = 4079.86406679174, hamiltonian_energy_error = 0.7280015213596016, max_hamiltonian_energy_error = 1.781684632557699, tree_depth = 10, numerical_error = false, step_size = 0.0002655487926798032, nom_step_size = 0.0002655487926798032, is_adapt = false)]), Vector{MonteCarloMeasurements.Particles{Float64, 33}}[[-4.01 ± 0.0058, -4.01 ± 0.0058, -4.01 ± 0.0058, -4.01 ± 0.0057, -4.01 ± 0.0057, -4.01 ± 0.0057, -4.01 ± 0.0057, -4.01 ± 0.0057, -4.01 ± 0.0057, -4.01 ± 0.0057 … -4.02 ± 0.012, -4.02 ± 0.012, -4.02 ± 0.012, -4.02 ± 0.012, -4.02 ± 0.012, -4.02 ± 0.012, -4.02 ± 0.012, -4.02 ± 0.012, -4.02 ± 0.012, -4.02 ± 0.012]], 
ComponentArrays.ComponentVector{MonteCarloMeasurements.Particles{Float64, 34}, Vector{MonteCarloMeasurements.Particles{Float64, 34}}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:24, Axis(weight = ViewAxis(1:16, ShapedAxis((8, 2))), bias = ViewAxis(17:24, Shaped1DAxis((8,))))), layer_2 = ViewAxis(25:96, Axis(weight = ViewAxis(1:64, ShapedAxis((8, 8))), bias = ViewAxis(65:72, Shaped1DAxis((8,))))), layer_3 = ViewAxis(97:105, Axis(weight = ViewAxis(1:8, ShapedAxis((1, 8))), bias = ViewAxis(9:9, Shaped1DAxis((1,))))))}}}[(layer_1 = (weight = MonteCarloMeasurements.Particles{Float64, 34}[0.787 ± 0.034 -1.69 ± 0.11; 0.86 ± 0.043 -1.24 ± 0.16; … ; 0.876 ± 0.044 -1.64 ± 0.08; -0.317 ± 0.099 -0.32 ± 0.45], bias = MonteCarloMeasurements.Particles{Float64, 34}[-0.487 ± 0.097, -0.293 ± 0.14, -0.674 ± 0.098, 0.00816 ± 0.15, 0.882 ± 0.47, -1.73 ± 0.49, 0.146 ± 0.093, 1.37 ± 0.61]), layer_2 = (weight = MonteCarloMeasurements.Particles{Float64, 34}[-0.357 ± 0.17 -0.791 ± 0.11 … -0.161 ± 0.13 0.456 ± 0.38; -0.826 ± 0.27 -0.474 ± 0.34 … -1.18 ± 0.28 -0.164 ± 0.21; … ; 0.785 ± 0.15 -0.794 ± 0.31 … 0.9 ± 0.16 -1.0 ± 0.18; 0.787 ± 0.39 0.217 ± 0.19 … -0.542 ± 0.17 0.292 ± 0.34], bias = MonteCarloMeasurements.Particles{Float64, 34}[-0.291 ± 0.34, -0.542 ± 0.2, -0.279 ± 0.22, -0.739 ± 0.44, -1.17 ± 0.65, 0.933 ± 0.27, 0.502 ± 0.27, -0.915 ± 0.19]), layer_3 = (weight = MonteCarloMeasurements.Particles{Float64, 34}[4.06 ± 0.26 4.17 ± 0.28 … -2.69 ± 0.36 -3.65 ± 0.5], bias = MonteCarloMeasurements.Particles{Float64, 34}[-0.178 ± 0.28]))], MonteCarloMeasurements.Particles{Float64, 34}[1.04 ± 0.012], [[-10.0 -9.99 … 9.99 10.0; 0.0 0.0 … 1.0 1.0]])
Finally, compare the posterior-mean prediction with the analytical solution and plot the pointwise error:
# Trained-network evaluator for u (the first, and only, chain in the discretization).
phi = discretization.phi[1]
# Evaluation grid: x with a step 10× finer than the dataset spacing, t with the dataset step.
xs, ts = [infimum(d.domain):h:supremum(d.domain)
          for (d, h) in zip(domains, [dx / 10, dt])]
# Posterior-mean prediction (mean over the parameter particles): one vector over `xs`
# per time in `ts`.
u_predict = [[first(pmean(phi([x, t], sol1.estimated_nn_params[1]))) for x in xs]
             for t in ts]
# Exact solution on the same grid.
u_real = [[u_analytic(x, t) for x in xs] for t in ts]
# Pointwise absolute error, computed from the two grids above rather than by evaluating
# the (Particles-valued, expensive) network a second time at every grid point.
diff_u = [abs.(real_t .- pred_t) for (real_t, pred_t) in zip(u_real, u_predict)]
p1 = plot(xs, u_predict, title = "predict")
p2 = plot(xs, u_real, title = "analytic")
p3 = plot(xs, diff_u, title = "error")
plot(p1, p2, p3)