```julia
NNODE(chain, opt = OptimizationPolyalgorithms.PolyOpt(), init_params = nothing;
      autodiff = false, batch = 0, kwargs...)
```
Algorithm for solving ordinary differential equations using a neural network. This is a specialization of the physics-informed neural network approach, used as a solver for a standard `ODEProblem`.
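As a hedged summary of the underlying idea, assuming `NNODE` follows the classic trial-solution formulation of Lagaris et al. (listed at the end of this section): for $u' = f(u,p,t)$ on $[t_0, t_1]$ with $u(t_0) = u_0$, the network $N(t;\theta)$ is wrapped so that the initial condition holds by construction, and training minimizes the squared ODE residual over the time span.

```math
\phi(t;\theta) = u_0 + (t - t_0)\,N(t;\theta), \qquad \phi(t_0;\theta) = u_0,
\qquad
L(\theta) = \int_{t_0}^{t_1} \left( \frac{\partial \phi}{\partial t}(t;\theta) - f\big(\phi(t;\theta), p, t\big) \right)^{2} dt .
```

With a grid-based strategy the integral would be replaced by an average over the grid points.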
Note that `NNODE` only supports ODEs written in the out-of-place form, i.e. `du = f(u,p,t)`, and not the in-place form `f(du,u,p,t)`. If the problem is not declared out-of-place, `NNODE` will exit with an error.
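To make the distinction concrete, here is a plain-Julia sketch of the same right-hand side written both ways. The names `f_oop` and `f_iip!` are hypothetical; only the first form is the one `NNODE` accepts.

```julia
# Out-of-place form (supported by NNODE): the derivative is returned as a value.
f_oop(u, p, t) = cos(2pi * t)

# In-place form (not supported by NNODE): the derivative is written into the
# preallocated array `du`, and the function itself returns `nothing`.
function f_iip!(du, u, p, t)
    du[1] = cos(2pi * t)
    return nothing
end

# Both encode the same ODE right-hand side; evaluate them at t = 0.25.
du_oop = f_oop(0.0, nothing, 0.25)

du_buf = zeros(1)
f_iip!(du_buf, [0.0], nothing, 0.25)
du_iip = du_buf[1]

println((du_oop, du_iip))
```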
`chain`: The neural network architecture, for example a `Flux.Chain` as in the example below.
`opt`: The optimizer used to train the neural network. Defaults to `OptimizationPolyalgorithms.PolyOpt()`.
`init_params`: The initial parameters of the neural network. By default this is `nothing`, which uses the random initialization provided by the neural network library.
`autodiff`: Switches between automatic and numerical differentiation for the differential operator in the loss function, i.e. the derivative of the network output with respect to time. The gradient of the loss with respect to the network parameters is always computed by reverse-mode automatic differentiation (via Zygote); this option affects only the time derivative.
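The sketch below illustrates the kind of quantity the numerical option approximates. It is dependency-free and is not NeuralPDE's code: it assumes a Lagaris-style trial solution `phi(t) = u0 + (t - t0) * net(t)`, uses a hypothetical one-unit "network" `net` with fixed weights, and takes a forward difference with step `sqrt(eps)`, comparing it to the exact product-rule derivative.

```julia
# Hypothetical stand-in for a neural network: one tanh unit with fixed weights.
net(t) = 0.7 * tanh(1.3 * t - 0.2) + 0.1
# Its exact derivative, written out by hand for comparison.
dnet(t) = 0.7 * 1.3 * (1 - tanh(1.3 * t - 0.2)^2)

# Trial solution that satisfies phi(t0) = u0 by construction (assumed form).
const t0 = 0.0
const u0 = 0.0
phi(t) = u0 + (t - t0) * net(t)

# Exact time derivative via the product rule: net(t) + (t - t0) * net'(t).
dphi_exact(t) = net(t) + (t - t0) * dnet(t)

# Numerical alternative: forward difference with step h = sqrt(eps).
function dphi_fd(t; h = sqrt(eps(typeof(t))))
    return (phi(t + h) - phi(t)) / h
end

tq = 0.4
println("exact: ", dphi_exact(tq), "  forward difference: ", dphi_fd(tq))
```

The forward difference agrees with the exact derivative to roughly single-precision accuracy, which is the usual trade-off against the cost and robustness of differentiating through the network.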
`batch`: The batch size to use for the internal quadrature. Defaults to `0`, which means the neural network is applied at individual time points, one at a time. `batch > 0` means the neural network is applied to a row vector of values `t` simultaneously, i.e. it is the batch size for the neural network evaluations. This requires a neural network compatible with batched data.
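The two modes differ only in how the time points are shaped before they reach the network. This plain-Julia sketch uses a hand-written 1 → 3 → 1 dense network with hypothetical weights (not a Flux model) and treats each column of the input as one sample, the same features × batch convention Flux uses. Evaluating points one at a time and as a single 1×N row vector gives the same values; the batched form is what requires the network to accept matrix input.

```julia
# Hypothetical weights for a tiny 1 -> 3 -> 1 network.
W1 = reshape([0.5, -1.0, 2.0], 3, 1)   # 3×1
b1 = [0.1, 0.0, -0.3]                 # length 3
W2 = [1.0 0.5 -0.25]                  # 1×3
b2 = [0.05]                           # length 1

# Dense network: each column of `x` is one input sample (one time point).
mlp(x) = W2 * tanh.(W1 * x .+ b1) .+ b2

ts = collect(range(0.0, 1.0; length = 5))

# `batch = 0` style: one time point per call, passed as a 1×1 matrix.
one_at_a_time = [mlp(reshape([t], 1, 1))[1] for t in ts]

# `batch > 0` style: all time points in a single call, as a 1×N row vector.
batched = vec(mlp(reshape(ts, 1, :)))

println("max difference: ", maximum(abs.(one_at_a_time .- batched)))
```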
`strategy`: The training strategy used to choose the points at which the loss is evaluated. By default, `QuadratureTraining` with QuadGK is used if no `dt` is given, and `GridTraining` is used with `dt` if it is given.
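As a rough illustration of grid-based training, the sketch below evaluates the squared ODE residual of a trial solution on the fixed grid `tspan[1]:dt:tspan[2]` and averages it; a quadrature strategy would instead approximate the integral of the same residual adaptively. The trial-solution form, the one-parameter "network" `net`, and the use of a plain mean are assumptions for illustration only, not NeuralPDE's implementation.

```julia
# Right-hand side from the example in this section: du/dt = cos(2πt).
linear(u, p, t) = cos(2pi * t)

const t0 = 0.0
const u0 = 0.0

# Hypothetical one-parameter "network" N(t; θ) = θ[1] * t.
net(t, θ) = θ[1] * t
# Trial solution with phi(t0) = u0 built in (assumed Lagaris-style form).
phi(t, θ) = u0 + (t - t0) * net(t, θ)
# Exact d(phi)/dt for this net: net(t, θ) + (t - t0) * θ[1].
dphi(t, θ) = net(t, θ) + (t - t0) * θ[1]

# Grid-training loss: mean squared residual at the points tspan[1]:dt:tspan[2].
function grid_loss(θ; tspan = (0.0, 1.0), dt = 1 / 20)
    ts = tspan[1]:dt:tspan[2]
    residuals = [dphi(t, θ) - linear(phi(t, θ), nothing, t) for t in ts]
    return sum(abs2, residuals) / length(ts)
end

println(grid_loss([0.0]), "  ", grid_loss([1.0]))
```

An optimizer would then adjust `θ` to drive this loss down; smaller `dt` gives a denser grid at a higher cost per loss evaluation.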
`kwargs`: Extra keyword arguments are splatted to the Optimization.jl `solve` call.
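```julia
using NeuralPDE, Flux   # Flux provides Chain, Dense, σ and ADAM (renamed Adam in newer Flux releases)

linear(u, p, t) = cos(2pi * t)   # out-of-place right-hand side, as NNODE requires
tspan = (0.0f0, 1.0f0)
u0 = 0.0f0
prob = ODEProblem(linear, u0, tspan)

chain = Flux.Chain(Dense(1, 5, σ), Dense(5, 1))
opt = Flux.ADAM(0.1)
sol = solve(prob, NeuralPDE.NNODE(chain, opt), dt = 1 / 20f0, verbose = true,
            abstol = 1e-10, maxiters = 200)
```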
Note that the solution is evaluated at fixed time points according to standard output handlers such as `dt`. However, the neural network is a fully continuous solution, so `sol(t)` is an accurate interpolation (up to the quality of the neural network training). In addition, the `OptimizationSolution` is returned as `sol.k` for further analysis.
Lagaris, Isaac E., Aristidis Likas, and Dimitrios I. Fotiadis. "Artificial neural networks for solving ordinary and partial differential equations." IEEE Transactions on Neural Networks 9, no. 5 (1998): 987-1000.