Newton and Hessian-Free Newton-Krylov with Second Order Adjoint Sensitivity Analysis

In many cases, fitting with second-order, Newton-based optimization techniques can be faster or more stable than first-order methods. Since SciMLSensitivity.jl provides second-order sensitivity analysis for fast Hessians and Hessian-vector products (via forward-over-reverse automatic differentiation), we can use these in our neural/universal differential equation training processes.

Optimization.jl automatically uses second-order sensitivity analysis when a second-order optimizer from Optim.jl (through OptimizationOptimJL) is requested. Thus the `Newton` and `NewtonTrustRegion` optimizers perform a Hessian-based optimization, while `KrylovTrustRegion` uses a Krylov-based method with Hessian-vector products (never forming the full Hessian), which makes it suitable for problems with many parameters.

import SciMLSensitivity as SMS
import Lux
import ComponentArrays as CA
import Optimization as OPT
import OptimizationOptimisers as OPO
import OrdinaryDiffEq as ODE
import Plots
import Random
import OptimizationOptimJL as OOJ

# Initial state and sampling grid shared by the ground-truth and neural ODEs.
u0 = Float32[2.0, 0.0]
datasize = 30
tspan = (0.0f0, 1.5f0)
tsteps = range(first(tspan), last(tspan); length = datasize)

# Coefficient matrix of the ground-truth cubic system. Kept as a `const` global so the
# right-hand side below does not allocate a fresh matrix on every solver call.
const TRUE_A = [-0.1 2.0; -2.0 -0.1]

"""
    trueODEfunc(du, u, p, t)

In-place right-hand side of the ground-truth system used to generate the training data:
writes `((u .^ 3)' * TRUE_A)'` into `du`. The arguments `p` and `t` are unused.
Returns `nothing`.
"""
function trueODEfunc(du, u, p, t)
    du .= ((u .^ 3)'TRUE_A)'
    return nothing
end

# Reference trajectory sampled at `tsteps`; this is the training target.
prob_trueode = ODE.ODEProblem(trueODEfunc, u0, tspan)
ode_data = ODE.solve(prob_trueode, ODE.Tsit5(); saveat = tsteps) |> Array

# Neural network standing in for the unknown right-hand side: cube the state
# (mirroring the `u .^ 3` structure of the true system), then a 2 -> 50 -> 2 MLP.
dudt2 = Lux.Chain(x -> x .^ 3, Lux.Dense(2, 50, tanh), Lux.Dense(50, 2))
ps, st = Lux.setup(Random.default_rng(), dudt2)

# Out-of-place neural ODE right-hand side. The Lux call yields a tuple whose first
# element is the network output; only that is returned (`t` is unused).
neuralodefunc(u, p, t) = first(dudt2(u, p, st))
"""
    prob_neuralode(u0, p)

Solve the neural ODE starting from `u0` with network parameters `p` over `tspan`
using `Tsit5`, saving at `tsteps`, and return the solution object.
"""
function prob_neuralode(u0, p)
    problem = ODE.ODEProblem(neuralodefunc, u0, tspan, p)
    return ODE.solve(problem, ODE.Tsit5(); saveat = tsteps)
end
# Wrap the nested Lux parameters in a ComponentArray so the optimizers can treat
# them as one flat vector.
ps = ps |> CA.ComponentArray

# Neural ODE prediction at `tsteps` for parameters `p`, converted to a plain array
# (rows are state components, columns are save points, matching `ode_data`).
predict_neuralode(p) = Array(prob_neuralode(u0, p))

"""
    loss_neuralode(p)

Sum-of-squares error between `ode_data` and the neural ODE prediction for
parameters `p`.

If the prediction does not have the same size as `ode_data` (presumably because the
solver stopped before reaching the end of `tspan`), return `Inf32`. This avoids a
`DimensionMismatch` from the broadcast, so the optimizer can reject the step.
"""
function loss_neuralode(p)
    pred = predict_neuralode(p)
    # A truncated trajectory cannot be compared point-by-point; treat it as infinitely bad.
    size(pred) == size(ode_data) || return Inf32
    return sum(abs2, ode_data .- pred)
end

# Callback function to observe training: prints the loss, records a plot of the
# current fit for every iteration, and signals a stop once the loss is small.
list_plots = []
iter = 0
callback = function (state, l; doplot = false)
    global list_plots, iter

    # While `iter` is still 0, start from a fresh, empty frame history.
    iter == 0 && (list_plots = [])
    iter += 1

    display(l)

    # Compare the first state component of the data with the current prediction.
    prediction = predict_neuralode(state.u)
    fig = Plots.scatter(tsteps, ode_data[1, :], label = "data")
    Plots.scatter!(fig, tsteps, prediction[1, :], label = "prediction")
    push!(list_plots, fig)
    doplot && display(Plots.plot(fig))

    # Optimization.jl halts when the callback returns `true`.
    return l < 0.01
end

# Zygote (reverse mode) supplies gradients; second-order methods build on it
# (forward-over-reverse). The problem-parameter argument is unused by the loss.
adtype = OPT.AutoZygote()
optf = OPT.OptimizationFunction((x, _) -> loss_neuralode(x), adtype)

# First-order warm start with Adam before switching to the second-order method.
optprob1 = OPT.OptimizationProblem(optf, ps)
pstart = OPT.solve(optprob1, OPO.Adam(0.01); callback, maxiters = 100).u

# Second-order refinement with a Hessian-based trust-region Newton method.
optprob2 = OPT.OptimizationProblem(optf, pstart)
pmin = OPT.solve(optprob2, OOJ.NewtonTrustRegion(); callback, maxiters = 200)
retcode: Failure
u: ComponentVector{Float32}(layer_1 = Float32[], layer_2 = (weight = Float32[-0.047799613 -1.6199665; 1.4023503 1.1084177; … ; -0.60066813 0.31390938; 1.8269526 1.0590281], bias = Float32[-0.7488036, 0.026995381, -0.38620678, 0.8672354, -1.0798581, 1.1549765, -0.039827425, 0.286886, 0.56903076, 0.7742217  …  0.22491257, 0.19492589, -0.5798681, -1.0526286, -0.15026408, 0.35852608, -0.6233194, 0.12951235, -0.5886757, -0.2196346]), layer_3 = (weight = Float32[0.53876525 -0.24971408 … 0.124115005 -0.20762716; 0.32070014 -0.0866468 … -0.21768445 0.3279887], bias = Float32[-0.5826545, -0.07100343]))

Note that we do not demonstrate `Newton()` because we have not found a single case where it is competitive with the other two methods. `KrylovTrustRegion()` is generally the fastest, because it works only with Hessian-vector products; it can be used in place of `NewtonTrustRegion()` in the example above.