Complex Equations with PINNs

NeuralPDE supports training PINNs on complex-valued differential equations. This example demonstrates how to do so with NNODE, using a system of Bloch equations [1]. Note that QuadratureTraining cannot be used with complex-valued equations because of current limitations in computing quadratures.
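Written out, the system implemented by `bloch_equations` below is the following, with u = (ρ₁₁, ρ₂₂, ρ₁₂, ρ₂₁), parameters p = (Ω, Δ, Γ) and γ = Γ/2. It is transcribed term by term from the code. Adding the first two equations shows that ρ₁₁ + ρ₂₂ is conserved.

```latex
\begin{aligned}
\dot\rho_{11} &= i\Omega(\rho_{12} - \rho_{21}) + \Gamma\rho_{22},\\
\dot\rho_{22} &= -i\Omega(\rho_{12} - \rho_{21}) - \Gamma\rho_{22},\\
\dot\rho_{12} &= -(\gamma + i\Delta)\rho_{12} - i\Omega(\rho_{22} - \rho_{11}),\\
\dot\rho_{21} &= \bigl[-(\gamma + i\Delta)\rho_{12} - i\Omega(\rho_{22} - \rho_{11})\bigr]^{*},
\qquad \gamma = \Gamma/2.
\end{aligned}
```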

Since the input to this neural network is time, which is real-valued, the network parameters must be initialized with complex values so that the network produces, and can be trained on, complex-valued outputs.
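To see why the complex initialization matters, here is a minimal sketch, assuming a recent Lux release that re-exports `kaiming_normal` from WeightInitializers (the example below relies on this too). A single `Dense` layer maps a real time input to real outputs with its default initializer, but to `ComplexF64` outputs once its weights are complex. The layer sizes, seed and input value are arbitrary choices for illustration.

```julia
using Lux, Random

rng = Random.default_rng()
Random.seed!(rng, 0)

t = reshape([0.5], 1, 1)  # one real time point as a 1×1 (features × batch) matrix

# Default (real) weights: the output stays real
real_layer = Dense(1, 4)
ps_r, st_r = Lux.setup(rng, real_layer)
y_real, _ = real_layer(t, ps_r, st_r)

# Complex weights, as in the chain below: the output becomes complex
complex_layer = Dense(1, 4; init_weight = kaiming_normal(ComplexF64))
ps_c, st_c = Lux.setup(rng, complex_layer)
y_complex, _ = complex_layer(t, ps_c, st_c)

println(eltype(y_real))     # a real floating-point type
println(eltype(y_complex))  # ComplexF64
```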

using Random, NeuralPDE, OrdinaryDiffEq, Lux, OptimizationOptimisers, Plots
rng = Random.default_rng()
Random.seed!(100)

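# Bloch equations for a two-level system; u = [ρ₁₁, ρ₂₂, ρ₁₂, ρ₂₁], p = [Ω, Δ, Γ]
function bloch_equations(u, p, t)
    Ω, Δ, Γ = p
    γ = Γ / 2
    ρ₁₁, ρ₂₂, ρ₁₂, ρ₂₁ = u
    # The last entry is the conjugate of the ρ₁₂ derivative, since ρ₂₁ = conj(ρ₁₂)
    dρ = [im * Ω * (ρ₁₂ - ρ₂₁) + Γ * ρ₂₂;
          -im * Ω * (ρ₁₂ - ρ₂₁) - Γ * ρ₂₂;
          -(γ + im * Δ) * ρ₁₂ - im * Ω * (ρ₂₂ - ρ₁₁);
          conj(-(γ + im * Δ) * ρ₁₂ - im * Ω * (ρ₂₂ - ρ₁₁))]
    return dρ
end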

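# Complex-valued initial state with all population in ρ₁₁
u0 = zeros(ComplexF64, 4)
u0[1] = 1.0
time_span = (0.0, 2.0)
parameters = [2.0, 0.0, 1.0]  # Ω, Δ, Γ

problem = ODEProblem(bloch_equations, u0, time_span, parameters)

# Complex-valued weights let the network map a real time input to complex outputs
chain = Chain(
    Dense(1, 16, tanh; init_weight = kaiming_normal(ComplexF64)),
    Dense(16, 4; init_weight = kaiming_normal(ComplexF64))
)
ps, st = Lux.setup(rng, chain)

opt = OptimizationOptimisers.Adam(0.01)
# Reference solution from a classical solver, saved on the same grid as the NNODE solution
ground_truth = solve(problem, Tsit5(), saveat = 0.01)
# StochasticTraining rather than QuadratureTraining, since the equations are complex-valued
alg = NNODE(chain, opt, ps; strategy = StochasticTraining(500))
sol = solve(problem, alg, verbose = false, maxiters = 5000, saveat = 0.01)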
retcode: Success
Interpolation: Trained neural network interpolation
t: 0.0:0.01:2.0
u: 201-element Vector{Vector{ComplexF64}}:
 [1.0 + 0.0im, 0.0 + 0.0im, 0.0 + 0.0im, 0.0 + 0.0im]
 [0.9992014155090517 + 0.0004810355914992892im, 0.0014576395640607549 - 0.0003193502226617151im, 0.00019476751383726754 + 0.020969313356579162im, -1.2127096858569609e-5 - 0.020640419712341013im]
 [0.9976127126259406 + 0.0007908625130735874im, 0.0035151579091619346 - 0.0005889698325035069im, 0.00031582145904003145 + 0.04144398214211857im, 2.7578171998721148e-6 - 0.04085302201463986im]
 [0.9952661117678354 + 0.0009562676193313618im, 0.006167999337208153 - 0.0008100116980966771im, 0.00037424722922642373 + 0.061445760703410665im, 3.725808038913869e-5 - 0.06065310455558278im]
 [0.9921911500339773 + 0.0010017441672715488im, 0.009411641922989702 - 0.000984444337626656im, 0.00038063061196346023 + 0.08099276111655768im, 8.439469815888501e-5 - 0.08005285022528119im]
 [0.9884150870674546 + 0.0009496824150481032im, 0.01324151342387643 - 0.0011149680201877627im, 0.000344996606096044 + 0.10009986203887283im, 0.00013766909582791365 - 0.09906167104165753im]
 [0.9839632516830503 + 0.0008205283127897722im, 0.01765290966233909 - 0.0012049281503999068im, 0.00027674894154080203 + 0.11877905906828201im, 0.00019112590785434812 - 0.11768650129247737im]
 [0.9788593379586287 + 0.0006329150026966804im, 0.022640915223697534 - 0.001258224931182687im, 0.00018461041749279563 + 0.13703976616484848im, 0.00023941568445568276 - 0.13593204799407713im]
 [0.9731256581116422 + 0.00040377083005052275im, 0.028200326328872226 - 0.0012792183527823632im, 7.65642253766445e-5 + 0.15488907572728486im, 0.0002778572926195766 - 0.1538010050675072im]
 [0.9667833584334504 + 0.0001484067327172492im, 0.0343255756906335 - 0.0012726276137490185im, -4.0203504350371736e-5 + 0.17233198327989183im, 0.0003024996514841122 - 0.17129423624116383im]
 ⋮
 [0.5704935959648827 - 0.03558792555314699im, 0.42830878062778327 + 0.027036618889326153im, 0.053754100770739464 + 0.20704285095872807im, 0.0791630353190704 - 0.2046823410637153im]
 [0.5666684828155846 - 0.035809247841719835im, 0.43245817213716337 + 0.027424507783848417im, 0.05483248807644218 + 0.20909281187488774im, 0.08016143326965831 - 0.2068276963835629im]
 [0.5627908889217776 - 0.03603601496628896im, 0.4367063858221525 + 0.0278137884348381im, 0.05594342574485589 + 0.21103187412079036im, 0.0811664249621754 - 0.2088812562022315im]
 [0.5588639022658681 - 0.03626863882128039im, 0.44105137164062874 + 0.028203486749690223im, 0.057088084280775 + 0.21285921405661873im, 0.08217871689056531 - 0.21084363566058598im]
 [0.5548905591801315 - 0.03650751852752212im, 0.4454910544694372 + 0.028592613556887633im, 0.05826759141378946 + 0.21457404008868067im, 0.08319906165065145 - 0.21271552760850623im]
 [0.5508738454983864 - 0.03675303979375349im, 0.45002333334533307 + 0.028980167969485392im, 0.059483029812662416 + 0.2161755906989216im, 0.08422825681557555 - 0.21449769948116992im]
 [0.5468166978099153 - 0.03700557432799695im, 0.4546460808371846 + 0.029365140638563794im, 0.06073543495053737 + 0.21766313269611587im, 0.08526714372923402 - 0.2161909903584259im]
 [0.542722004796405 - 0.03726547929889833im, 0.45935714255495097 + 0.029746516888073646im, 0.062025793114162966 + 0.21903595968855333im, 0.08631660622812168 - 0.2177963082058925im]
 [0.5385926086331264 - 0.03753309684718681im, 0.46415433680007967 + 0.03012327972355827im, 0.06335503955036563 + 0.2202933907773199im, 0.08737756930190399 - 0.21931462729523055im]
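Roughly speaking, `StochasticTraining(500)` evaluates the ODE residual at 500 randomly sampled times instead of computing a quadrature. Because an optimizer needs a real-valued objective, a complex residual has to be reduced to a real number, for example with `abs2`. The plain-Julia sketch below illustrates that idea on the scalar test equation du/dt = iωu. It is not NeuralPDE's internal implementation: the helper `stochastic_residual_loss`, the central-difference derivative and the trial form u(t) = u0 + (t - t0)·φ(t), which satisfies the initial condition by construction, are assumptions made for illustration.

```julia
using Random

# Hypothetical helper (not part of NeuralPDE): mean squared modulus of the ODE
# residual du/dt - f(u, p, t) at n randomly sampled times, for the trial solution
# u(t) = u0 + (t - t0) * φ(t), which satisfies u(t0) = u0 by construction.
function stochastic_residual_loss(φ, f, u0, p, tspan, n;
                                  rng = Random.default_rng(), ε = 1e-6)
    t0, t1 = tspan
    trial(t) = u0 .+ (t - t0) .* φ(t)
    ts = t0 .+ (t1 - t0) .* rand(rng, n)  # random collocation points in [t0, t1)
    total = 0.0
    for t in ts
        dudt = (trial(t + ε) .- trial(t - ε)) ./ (2ε)  # central-difference derivative
        r = dudt .- f(trial(t), p, t)                  # complex-valued residual
        total += sum(abs2, r)                          # |r|² is real, so the loss is real
    end
    return total / n
end

# Test problem: du/dt = iωu with u(0) = 1, whose exact solution is exp(iωt)
f(u, p, t) = im * p[1] .* u
ω = 2.0
u0 = [1.0 + 0.0im]

φ_exact(t) = [(exp(im * ω * t) - 1) / t]  # makes the trial solution equal exp(iωt)
φ_zero(t) = [0.0 + 0.0im]                 # trial solution stuck at u0

loss_exact = stochastic_residual_loss(φ_exact, f, u0, [ω], (0.0, 2.0), 500)
loss_zero = stochastic_residual_loss(φ_zero, f, u0, [ω], (0.0, 2.0), 500)
```

With the exact trial function the loss is at round-off level, while the constant trial function gives exactly ω² = 4, the squared modulus of the missing derivative iω·u0.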

Now, let's plot the real and imaginary parts of each component of the NNODE prediction against the Tsit5 ground truth.

u1:

plot(sol.t, real.(reduce(hcat, sol.u)[1, :]); label = "NNODE");
plot!(ground_truth.t, real.(reduce(hcat, ground_truth.u)[1, :]); label = "ground truth")
[Plot: real part of u1, NNODE prediction and ground truth]
plot(sol.t, imag.(reduce(hcat, sol.u)[1, :]); label = "NNODE");
plot!(ground_truth.t, imag.(reduce(hcat, ground_truth.u)[1, :]); label = "ground truth")
[Plot: imaginary part of u1, NNODE prediction and ground truth]

u2:

plot(sol.t, real.(reduce(hcat, sol.u)[2, :]); label = "NNODE");
plot!(ground_truth.t, real.(reduce(hcat, ground_truth.u)[2, :]); label = "ground truth")
[Plot: real part of u2, NNODE prediction and ground truth]
plot(sol.t, imag.(reduce(hcat, sol.u)[2, :]); label = "NNODE");
plot!(ground_truth.t, imag.(reduce(hcat, ground_truth.u)[2, :]); label = "ground truth")
[Plot: imaginary part of u2, NNODE prediction and ground truth]

u3:

plot(sol.t, real.(reduce(hcat, sol.u)[3, :]); label = "NNODE");
plot!(ground_truth.t, real.(reduce(hcat, ground_truth.u)[3, :]); label = "ground truth")
[Plot: real part of u3, NNODE prediction and ground truth]
plot(sol.t, imag.(reduce(hcat, sol.u)[3, :]); label = "NNODE");
plot!(ground_truth.t, imag.(reduce(hcat, ground_truth.u)[3, :]); label = "ground truth")
[Plot: imaginary part of u3, NNODE prediction and ground truth]

u4:

plot(sol.t, real.(reduce(hcat, sol.u)[4, :]); label = "NNODE");
plot!(ground_truth.t, real.(reduce(hcat, ground_truth.u)[4, :]); label = "ground truth")
[Plot: real part of u4, NNODE prediction and ground truth]
plot(sol.t, imag.(reduce(hcat, sol.u)[4, :]); label = "NNODE");
plot!(ground_truth.t, imag.(reduce(hcat, ground_truth.u)[4, :]); label = "ground truth")
[Plot: imaginary part of u4, NNODE prediction and ground truth]
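The eight plots above repeat one pattern per component and per part. A sketch that draws them all into a single 4×2 grid might look like the following. `plot_components` is a hypothetical helper; for this problem it would be called as `plot_components(sol.t, sol.u, ground_truth.t, ground_truth.u)`, and the synthetic data at the end only stands in for those solutions.

```julia
using Plots

# Hypothetical helper: one panel per (component, part), prediction vs. reference
function plot_components(t_pred, u_pred, t_true, u_true)
    P = reduce(hcat, u_pred)  # components × time points
    T = reduce(hcat, u_true)
    panels = []
    for i in 1:size(P, 1), (name, part) in (("Re", real), ("Im", imag))
        p = plot(t_pred, part.(P[i, :]); label = "NNODE", title = "$name(u$i)")
        plot!(p, t_true, part.(T[i, :]); label = "ground truth")
        push!(panels, p)
    end
    # Rows are components, columns are the real and imaginary parts
    return plot(panels...; layout = (size(P, 1), 2), size = (800, 1000))
end

# Synthetic stand-in data shaped like sol.u (a vector of 4-element complex vectors)
ts = 0.0:0.01:2.0
u_ref = [ComplexF64[cos(t)^2, sin(t)^2, 0.5im * sin(2t), -0.5im * sin(2t)] for t in ts]
fig = plot_components(ts, u_ref, ts, u_ref)
```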

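We can see that the network learns the real parts of u1 and u2 and the imaginary parts of u3 and u4. These are the only nonzero parts of the exact solution here: with Δ = 0 and all population initially in ρ₁₁, ρ₁₁ and ρ₂₂ stay real, while ρ₁₂ and ρ₂₁ = conj(ρ₁₂) stay purely imaginary.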
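The visual comparison can be complemented by a number. Since `sol` and `ground_truth` were both saved with `saveat = 0.01` over the same time span, their states can be compared point by point. The helper below is a hypothetical sketch, not a NeuralPDE function; for this problem it would be called as `component_errors(sol.u, ground_truth.u)`, and the small hand-made vectors only stand in for those solutions.

```julia
# Hypothetical helper: maximum absolute error of the real and imaginary parts of
# each state component, for two solutions saved on the same time grid
function component_errors(u_pred, u_true)
    P = reduce(hcat, u_pred)  # components × time points
    T = reduce(hcat, u_true)
    real_err = vec(maximum(abs.(real.(P) .- real.(T)); dims = 2))
    imag_err = vec(maximum(abs.(imag.(P) .- imag.(T)); dims = 2))
    return (real = real_err, imag = imag_err)
end

# Two components at two time points, as stand-in data
u_true = [[1.0 + 0.0im, 0.5im], [0.9 + 0.1im, 0.4im]]
u_pred = [[1.0 + 0.0im, 0.5im], [0.8 + 0.1im, 0.1 + 0.4im]]
errs = component_errors(u_pred, u_true)
```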

[1] https://steck.us/alkalidata/