Differentiating an ODE Solution with Automatic Differentiation

Note

This tutorial assumes familiarity with DifferentialEquations.jl. If you are not familiar with DifferentialEquations.jl, please consult the DifferentialEquations.jl documentation.

In this tutorial we introduce local sensitivity analysis via automatic differentiation (AD). The AD interfaces are the most common way to do local sensitivity analysis. They are fairly fast and flexible. Most notably, they are a very small, natural extension of normal differential equation solving code, which makes them the easiest option for most tasks.
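For reference, the derivatives computed in this tutorial are the local sensitivities of the solution. They are evaluated at one specific choice of initial condition and parameters. Write the ODE as u' = f(u, p, t) with u(t_0) = u_0. The symbols S_p and S_{u_0} below are introduced here only for illustration. The sensitivities and the classical forward sensitivity equations they satisfy are:

```latex
\begin{aligned}
S_p(t) &= \frac{\partial u(t)}{\partial p}, &
S_{u_0}(t) &= \frac{\partial u(t)}{\partial u_0}, \\
\frac{dS_p}{dt} &= \frac{\partial f}{\partial u} S_p + \frac{\partial f}{\partial p}, \; S_p(t_0) = 0, &
\frac{dS_{u_0}}{dt} &= \frac{\partial f}{\partial u} S_{u_0}, \; S_{u_0}(t_0) = I.
\end{aligned}
```

Automatic differentiation differentiates the numerical solver itself, so it returns the derivative of the computed solution rather than integrating these equations explicitly.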

Setup

Let's first define the differential equation we wish to solve. We choose the Lotka-Volterra equations, which are defined in DifferentialEquations.jl as follows:

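using DifferentialEquations

# In-place Lotka-Volterra right-hand side: u = [prey, predator]
function lotka_volterra!(du,u,p,t)
  du[1] = p[1]*u[1] - p[2]*u[1]*u[2]   # prey: growth minus predation
  du[2] = -p[3]*u[2] + p[4]*u[1]*u[2]  # predator: decay plus gain from predation
  return nothing
end
p = [1.5,1.0,3.0,1.0]; u0 = [1.0;1.0]
prob = ODEProblem(lotka_volterra!,u0,(0.0,10.0),p)
# Tighter tolerances than the defaults, for accurate sensitivities (see the note below)
sol = solve(prob,Tsit5(),reltol=1e-6,abstol=1e-6)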
retcode: Success
Interpolation: specialized 4th order "free" interpolation
t: 104-element Vector{Float64}:
  0.0
  0.02238867177415836
  0.06688455734042167
  0.12204057917794937
  0.19017391008388496
  0.2700958857748864
  0.3624899576395024
  0.4663498927011971
  0.5804932271654899
  0.7035670605083028
  ⋮
  9.363458964140799
  9.43825400337173
  9.514924336844068
  9.594877367879974
  9.679331592925825
  9.769895515116275
  9.868269593442548
  9.975570677585456
 10.0
u: 104-element Vector{Vector{Float64}}:
 [1.0, 1.0]
 [1.0117558257818347, 0.9563342092954508]
 [1.0384182069752141, 0.8758683256119992]
 [1.0774848852893388, 0.786875166302778]
 [1.1349057840697077, 0.6915813145776512]
 [1.2153494328034518, 0.5976695389806336]
 [1.3266197077630641, 0.5093485173645297]
 [1.4766110917363853, 0.4313359873619578]
 [1.67467230359698, 0.36648541839701415]
 [1.93171526930104, 0.31613609766341083]
 ⋮
 [1.2804916562483988, 3.2111899168482525]
 [1.1439254396494953, 2.8083593976895913]
 [1.050250712325505, 2.4264943975879096]
 [0.9895322478849293, 2.070758237526672]
 [0.9563824177295706, 1.744601626564735]
 [0.9484176857475457, 1.4490308503325875]
 [0.9660834273481814, 1.1849910640094092]
 [1.0122117036820808, 0.9547854523135569]
 [1.0263542618072086, 0.9096831916611162]
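Written out, lotka_volterra! encodes the system below. Here u = (u_1, u_2) are the prey and predator populations and p = (p_1, p_2, p_3, p_4). The tutorial uses p = (1.5, 1.0, 3.0, 1.0), u_0 = (1, 1), and t ∈ [0, 10]:

```latex
\frac{du_1}{dt} = p_1 u_1 - p_2 u_1 u_2, \qquad
\frac{du_2}{dt} = -p_3 u_2 + p_4 u_1 u_2
```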

Now let's differentiate the solution to this ODE using a few different automatic differentiation methods.

Forward-Mode Automatic Differentiation with ForwardDiff.jl

Let's say we need the derivative of the solution with respect to the initial condition u0 and the parameters p. One of the simplest ways to get it is via ForwardDiff.jl: all one needs to do is use ForwardDiff.jl to differentiate some function f that calls solve inside of it. For example, suppose we want the derivative of the first component of the ODE solution with respect to these quantities. We want it at the evenly spaced time points t = 0, 1, ..., 10 (i.e. dt = 1). We can compute this via:

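using ForwardDiff

# x = [u0; p]: entries 1:2 are the initial condition, entries 3:6 the parameters
function f(x)
    _prob = remake(prob,u0=x[1:2],p=x[3:end])
    # first state component at the saved times t = 0, 1, ..., 10
    solve(_prob,Tsit5(),reltol=1e-6,abstol=1e-6,saveat=1)[1,:]
end
x = [u0;p]
dx = ForwardDiff.jacobian(f,x)  # 11×6 Jacobian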
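ForwardDiff.jl works by pushing dual numbers through ordinary Julia code. A dual number carries a value plus a derivative part, which is why a solver written generically can be differentiated without changes. The toy below only illustrates that idea and rests on several assumptions:

- The Dual type is a hypothetical minimal version; ForwardDiff's own type is far more complete.
- A fixed-step RK4 loop (rk4, 1000 steps) stands in for Tsit5.
- The helper names lv and rk4 are illustrative.

It computes ∂u_1(1)/∂p_1 by seeding the derivative part of p[1] with 1.

```julia
import Base: +, -, *, convert, promote_rule

# Hypothetical minimal dual number: `val` is the value, `der` one derivative.
struct Dual <: Number
    val::Float64
    der::Float64
end
+(a::Dual, b::Dual) = Dual(a.val + b.val, a.der + b.der)
-(a::Dual, b::Dual) = Dual(a.val - b.val, a.der - b.der)
-(a::Dual) = Dual(-a.val, -a.der)
*(a::Dual, b::Dual) = Dual(a.val * b.val, a.der * b.val + a.val * b.der)  # product rule
# Plain reals mixed with Dual become duals with zero derivative part.
convert(::Type{Dual}, x::Real) = Dual(Float64(x), 0.0)
promote_rule(::Type{Dual}, ::Type{<:Real}) = Dual

# Out-of-place Lotka-Volterra right-hand side, generic in the number type.
lv(u, p) = [p[1]*u[1] - p[2]*u[1]*u[2], -p[3]*u[2] + p[4]*u[1]*u[2]]

# Fixed-step classical RK4 (a stand-in for Tsit5); works for Float64 or Dual.
function rk4(f, u0, p, tspan, n)
    h = (tspan[2] - tspan[1]) / n
    u = u0
    for _ in 1:n
        k1 = f(u, p)
        k2 = f(u .+ (h / 2) .* k1, p)
        k3 = f(u .+ (h / 2) .* k2, p)
        k4 = f(u .+ h .* k3, p)
        u = u .+ (h / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
    end
    return u
end

u0 = [1.0, 1.0]
p = [1.5, 1.0, 3.0, 1.0]
# Seed: derivative part 1 for p[1], 0 for the other parameters.
pd = [Dual(p[1], 1.0), Dual(p[2], 0.0), Dual(p[3], 0.0), Dual(p[4], 0.0)]
u_end = rk4(lv, u0, pd, (0.0, 1.0), 1000)
du1_dp1 = u_end[1].der  # ∂u_1(1)/∂p_1 of the RK4 solution
```

The value parts follow exactly the same arithmetic as a plain Float64 solve. The derivative parts are the exact derivative of the discrete RK4 map, so they should agree with a central finite difference to within its truncation error.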

Let's dig into what this is saying. x is a vector that concatenates the initial condition and parameters: the first 2 values are the initial conditions and the last 4 are the parameters. Inside f(x), the remake function builds a copy of prob with these new initial conditions and parameters. That copy is then solved, and f returns the time series of the first component.

Then ForwardDiff.jacobian(f,x) computes the Jacobian of f with respect to x. f maps the 6 entries of x to the 11 saved values of the first component (t = 0, 1, ..., 10), so dx is an 11×6 matrix. Its entry dx[i,j] is the derivative of the first component of the solution at time t=i-1 with respect to x[j]. For example, dx[2,3] is the derivative of the first component of the solution at time t=1 with respect to p[1].
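The Jacobian layout (rows index outputs, columns index inputs) can be checked without any packages. The sketch below rests on these assumptions:

- A fixed-step RK4 integrator with 100 steps per unit time stands in for Tsit5.
- A central-difference Jacobian stands in for ForwardDiff.jacobian.
- The helper names lv, rk4_step, first_component_series, and fd_jacobian are hypothetical.

Row 1 corresponds to t = 0, where the first component is just x[1]. That row should therefore be (1, 0, 0, 0, 0, 0).

```julia
# Out-of-place Lotka-Volterra right-hand side (same equations as lotka_volterra!).
lv(u, p) = [p[1]*u[1] - p[2]*u[1]*u[2], -p[3]*u[2] + p[4]*u[1]*u[2]]

# One classical RK4 step of size h (a fixed-step stand-in for Tsit5).
function rk4_step(u, p, h)
    k1 = lv(u, p)
    k2 = lv(u .+ (h / 2) .* k1, p)
    k3 = lv(u .+ (h / 2) .* k2, p)
    k4 = lv(u .+ h .* k3, p)
    return u .+ (h / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
end

# Toy analogue of f(x): first component at t = 0, 1, ..., 10, with x = [u0; p].
function first_component_series(x; steps_per_unit = 100)
    u, p = x[1:2], x[3:end]
    h = 1.0 / steps_per_unit
    out = [u[1]]
    for _ in 1:10
        for _ in 1:steps_per_unit
            u = rk4_step(u, p, h)
        end
        push!(out, u[1])
    end
    return out
end

# Central-difference Jacobian with the same layout as ForwardDiff.jacobian:
# J[i, j] = ∂f_i/∂x_j, so rows index outputs (time points), columns index inputs.
function fd_jacobian(f, x; ε = 1e-6)
    y = f(x)
    J = zeros(length(y), length(x))
    for j in eachindex(x)
        xp = copy(x); xp[j] += ε
        xm = copy(x); xm[j] -= ε
        J[:, j] = (f(xp) .- f(xm)) ./ (2ε)
    end
    return J
end

x = [1.0, 1.0, 1.5, 1.0, 3.0, 1.0]  # [u0; p]
J = fd_jacobian(first_component_series, x)
size(J)  # (11, 6): 11 time points, 6 inputs
J[2, 3]  # ∂u_1(t = 1)/∂p_1
```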

Note

Since the global error is 1-2 orders of magnitude higher than the local error, we tighten the tolerances to reltol=abstol=1e-6 (the default reltol is 1e-3) to get reasonably accurate sensitivities.

Reverse-Mode Automatic Differentiation

Once SciMLSensitivity.jl is loaded, the solve function is automatically compatible with reverse-mode AD systems like Zygote.jl. No extra machinery is needed beyond calling solve inside a function that Zygote differentiates. For example, the following code solves the ODE and computes the gradient of a loss function via the adjoint method. The loss is the sum of the ODE's output at each time point, with dt=0.1:

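using Zygote, SciMLSensitivity

# Loss: sum of both solution components at t = 0, 0.1, ..., 10
function sum_of_solution(u0,p)
  _prob = remake(prob,u0=u0,p=p)
  sum(solve(_prob,Tsit5(),reltol=1e-6,abstol=1e-6,saveat=0.1))
end
# du01: gradient w.r.t. the initial condition, dp1: gradient w.r.t. the parameters
du01,dp1 = Zygote.gradient(sum_of_solution,u0,p)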
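The following self-contained toy makes "reverse mode" concrete. It uses the same kind of loss (the sum of both components at t = 0, 0.1, ..., 10), with the following assumptions:

- The solver is replaced by fixed-step forward Euler with h = 0.01 instead of Tsit5.
- The gradient is computed by a hand-written discrete adjoint: one forward pass that stores the states, then one backward pass.
- The functions euler_loss and euler_loss_gradient are hypothetical.

This is not how SciMLSensitivity is implemented. It only shows why a single backward sweep yields the gradient with respect to all of u0 and p at once.

```julia
# Out-of-place Lotka-Volterra right-hand side (same equations as lotka_volterra!).
lv(u, p) = [p[1]*u[1] - p[2]*u[1]*u[2], -p[3]*u[2] + p[4]*u[1]*u[2]]

# Hypothetical toy loss: forward Euler with step h, summing both components
# every `every` steps (t = 0, 0.1, ..., 10 with the defaults).
function euler_loss(u0, p; h = 0.01, nsteps = 1000, every = 10)
    u = copy(u0)
    L = sum(u)  # contribution of t = 0
    for k in 1:nsteps
        u = u .+ h .* lv(u, p)
        k % every == 0 && (L += sum(u))
    end
    return L
end

# Discrete adjoint (reverse mode) of euler_loss: returns (∂L/∂u0, ∂L/∂p).
function euler_loss_gradient(u0, p; h = 0.01, nsteps = 1000, every = 10)
    # Forward pass: store every state; us[m+1] holds u_m.
    us = [copy(u0)]
    for _ in 1:nsteps
        push!(us, us[end] .+ h .* lv(us[end], p))
    end
    # λ holds ∂L/∂u_m, starting at m = nsteps with the direct loss term.
    λ = nsteps % every == 0 ? [1.0, 1.0] : [0.0, 0.0]
    gp = zeros(4)
    for m in nsteps-1:-1:0
        x, y = us[m+1]
        λ1, λ2 = λ
        # Parameter gradient gains h * (∂f/∂p)' * λ_{m+1} from step m -> m+1.
        gp .+= h .* [x*λ1, -x*y*λ1, -y*λ2, x*y*λ2]
        # λ_m = (I + h * ∂f/∂u)' * λ_{m+1}, plus the direct term if u_m is saved.
        λ = [λ1 + h * ((p[1] - p[2]*y)*λ1 + p[4]*y*λ2),
             λ2 + h * (-p[2]*x*λ1 + (-p[3] + p[4]*x)*λ2)]
        m % every == 0 && (λ = λ .+ 1.0)
    end
    return λ, gp
end

u0 = [1.0, 1.0]
p = [1.5, 1.0, 3.0, 1.0]
du0, dp = euler_loss_gradient(u0, p)
```

Comparing du0 and dp against central finite differences of euler_loss is a quick way to validate a hand-written adjoint like this one.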

Zygote.jl's automatic differentiation system is overloaded so that SciMLSensitivity.jl can redefine how the derivatives of solve are computed. This allows trade-offs between numerical stability, memory, and compute performance, similar to how ODE solver algorithms are chosen. The algorithms for the derivative calculation are called AbstractSensitivityAlgorithms, or sensealgs for short. They are chosen by passing the sensealg keyword argument to solve.

Let's demonstrate this by choosing the QuadratureAdjoint sensealg for the differentiation of this system:

# Same loss as before, but the gradient is computed with the QuadratureAdjoint sensealg
function sum_of_solution(u0,p)
  _prob = remake(prob,u0=u0,p=p)
  sum(solve(_prob,Tsit5(),reltol=1e-6,abstol=1e-6,saveat=0.1,sensealg=QuadratureAdjoint()))
end
du01,dp1 = Zygote.gradient(sum_of_solution,u0,p)

This computes the derivative of the output with respect to the initial condition (du01) and with respect to the parameters (dp1), using QuadratureAdjoint(). For more information on the choice of sensitivity algorithm, see the reference documentation on choosing sensitivity algorithms.
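As background for what adjoint sensealgs compute, one standard textbook form of the continuous adjoint equations is sketched below, for an integral cost G. SciMLSensitivity's internal conventions may differ in detail, as may its handling of discrete losses such as the saveat sum above:

```latex
\begin{aligned}
G(u_0, p) &= \int_{t_0}^{T} g(u(t), p, t)\,dt, \\
\frac{d\lambda}{dt} &= -\left(\frac{\partial f}{\partial u}\right)^{\top}\lambda - \left(\frac{\partial g}{\partial u}\right)^{\top}, \qquad \lambda(T) = 0, \\
\nabla_{u_0} G &= \lambda(t_0), \qquad
\nabla_{p} G = \int_{t_0}^{T} \left[\left(\frac{\partial g}{\partial p}\right)^{\top} + \left(\frac{\partial f}{\partial p}\right)^{\top}\lambda\right] dt.
\end{aligned}
```

The adjoint λ is integrated backward from T to t_0, which requires the forward solution u(t) along the way. As its name suggests, QuadratureAdjoint evaluates the parameter integral with a quadrature rule after that backward solve. For a loss that sums the solution at discrete save points t_i, the integral over g is replaced by point evaluations g_i. In that case λ instead jumps by (∂g_i/∂u)^T at each save time.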

When Should You Use Forward or Reverse Mode?

Good question! The simple answer is: if you are differentiating a system of 100 or fewer equations, use forward mode; otherwise, use reverse mode. But it can be a lot more complicated than that! For more information, see the reference documentation on choosing sensitivity algorithms.