Direct Sensitivity Analysis Functionality

While sensitivity analysis tooling can be used implicitly via integration with automatic differentiation libraries, the direct sensitivity analysis interfaces often offer more speed and flexibility. This tutorial demonstrates some of those functions.

Example using an ODEForwardSensitivityProblem

Forward sensitivity analysis is performed by defining and solving an augmented ODE. To define this augmented ODE, use the ODEForwardSensitivityProblem type instead of an ODE type. For example, we generate an ODE with the sensitivity equations attached for the Lotka-Volterra equations by:

using OrdinaryDiffEq, SciMLSensitivity

function f(du,u,p,t)
  du[1] = dx = p[1]*u[1] - p[2]*u[1]*u[2]
  du[2] = dy = -p[3]*u[2] + u[1]*u[2]
end

p = [1.5,1.0,3.0]
prob = ODEForwardSensitivityProblem(f,[1.0;1.0],(0.0,10.0),p)
ODEProblem with uType Vector{Float64} and tType Float64. In-place: true
timespan: (0.0, 10.0)
u0: 8-element Vector{Float64}:
 1.0
 1.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
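
The eight-element u0 reflects the size of the augmented system: two states plus two sensitivities for each of the three parameters. As a sketch in standard notation (the symbol $s_j = \partial u / \partial p_j$ is introduced here for illustration and is not part of the tutorial's code), the augmented system reads:

\[\begin{aligned} \frac{du}{dt} &= f(u,p,t) \\ \frac{ds_{j}}{dt} &= \frac{\partial f}{\partial u}s_{j}+\frac{\partial f}{\partial p_{j}}, \qquad s_{j}(0)=0, \quad j=1,2,3 \end{aligned}\]

The zero initial sensitivities correspond to the six trailing zeros in u0, under the assumption that the initial condition does not depend on the parameters.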

The resulting problem can be solved with the standard ODE solvers:

sol = solve(prob,DP8())
retcode: Success
Interpolation: specialized 7th order interpolation
t: 29-element Vector{Float64}:
  0.0
  0.0008156803234081559
  0.005709762263857092
  0.0350742539065507
  0.21126120376271237
  0.7310736576107115
  1.540222712617339
  1.8813610466661779
  2.152579392464959
  2.4063311841696478
  ⋮
  7.063355055637446
  7.725935939669195
  8.248979441432317
  8.558003582514011
  8.826370842927059
  9.171011754187145
  9.493946497917326
  9.834929283472757
 10.0
u: 29-element Vector{Vector{Float64}}:
 [1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
 [1.000408588538797, 0.9983701355642965, 0.000816013510644824, 3.322154010506493e-7, -0.0008153483114736799, -3.320348233838799e-7, 3.3244140163739216e-7, -0.0008143507848015046]
 [1.0028915156848317, 0.9886535575433295, 0.005726241237749595, 1.6146661657253676e-5, -0.005693685298331232, -1.6085381411877104e-5, 1.622392391840203e-5, -0.005644946211107334]
 [1.0189121274128703, 0.9325571368196612, 0.035730565614057214, 0.0005806610224735032, -0.034509633035320036, -0.0005673264615114402, 0.0005982168608307318, -0.03270217921302392]
 [1.1547444753248808, 0.6650474487943194, 0.24252997273953153, 0.01620058797923608, -0.198498945739975, -0.014142604007354084, 0.019606496433572655, -0.13955624899729338]
 [1.9959026411822718, 0.30725094819300236, 1.3911332658550328, 0.12370946057242008, -0.7599570019860626, -0.07996649485520504, 0.22938885567774317, -0.20775306589835096]
 [5.30194567946053, 0.4275038176301537, 6.007173244666811, 1.378441778030569, -2.283870425183431, -0.623462248893163, 1.6987114297214083, -0.3708646913283792]
 [6.8702723310901295, 1.2712401591810842, 2.482449626359105, 6.439151700037799, -1.3710505647874744, -2.7849768648683777, 3.2715017511539424, -0.46731359361629526]
 [5.679446513843018, 3.3399703081159786, -12.310270279196507, 12.730860582111236, 2.90563205459914, -6.735838408409488, 2.6735629496567994, 0.8629101364446949]
 [2.8838282002507136, 4.57403857987998, -11.50192475754792, 1.5722474815418457, 3.149560282389645, -5.129125042632567, 0.2620888410181659, 1.5914280086423986]
 ⋮
 [1.4188559373191922, 0.45747117721990466, 4.0345730575991805, -1.6292340104788547, -0.02328531538568726, -0.22674039853297237, 0.9917652543963326, -0.638337264202254]
 [3.105286408605591, 0.25723901264027416, 12.030706505671914, 0.3657288649741977, -0.3594670814358994, -0.1562962149538488, 2.9860724103159155, -0.2135072543485681]
 [5.721950106631197, 0.5147791449780601, 19.322713832885665, 5.150407774561523, -0.8979267203135737, -0.4767887879979121, 5.4868266280623645, 0.4525830529452884]
 [6.9045644791116585, 1.4959711323293163, 0.8111460438513347, 21.257023017125366, -0.885529900291283, -1.837556476410295, 3.431317057907666, 3.210786998858488]
 [5.251408633416557, 3.677081894045012, -40.30711600763158, 31.383646964911545, 0.415505603414129, -4.822192363404264, -4.808988602487445, 6.282699534622149]
 [1.9684993733166822, 4.224446642146764, -19.72010623548457, -13.790594696972407, 0.6970930173849572, -4.432962960251689, -3.4673013009932574, -1.767994652510742]
 [1.0719720028424078, 2.5266939022673727, -4.294296713681639, -16.729797806817487, 0.3407022900162246, -2.2485501409033666, -0.8127688495258327, -3.42797696441485]
 [0.9574127494597999, 1.2680817248527774, 0.6626271434785644, -9.021542819507147, 0.21249816908491037, -1.0143392771168123, 0.21400382117613845, -2.2552162974026566]
 [1.026505547286025, 0.9095251254958595, 2.1626974566892803, -6.256489916332188, 0.1883893214244694, -0.6976152811241049, 0.5638188536930522, -1.7090441864606456]

Note that the solution is the standard ODE system and the sensitivity system combined. We can use the following helper functions to extract the sensitivity information:

x,dp = extract_local_sensitivities(sol)
x,dp = extract_local_sensitivities(sol,i)
x,dp = extract_local_sensitivities(sol,t)

In each case, x is the ODE values and dp is the matrix of sensitivities. The first form gives the full timeseries of values, and dp[i] contains the timeseries of the sensitivities of all components of the ODE with respect to the ith parameter. The second returns the ith time step, while the third interpolates to calculate the sensitivities at time t. For example, if we do:

x,dp = extract_local_sensitivities(sol)
da = dp[1]
2×29 Matrix{Float64}:
 0.0  0.000816014  0.00572624  0.0357306    …   -4.2943   0.662627   2.1627
 0.0  3.32215e-7   1.61467e-5  0.000580661     -16.7298  -9.02154   -6.25649

then da is the timeseries for $\frac{\partial u(t)}{\partial p_1}$, the sensitivity of both states with respect to the first parameter. We can plot this

using Plots
plot(sol.t,da',lw=3)

transposing so that the rows (the timeseries) are plotted.

Local Sensitivity Solution
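
The combined layout of the solution can also be unpacked by hand. The sketch below takes the final saved state printed earlier and splits it into the ODE values and a sensitivity matrix. The layout (states first, then one block of sensitivities per parameter) is inferred from the printed output, where entries 3 and 4 match the last column of da; the names z_end, u_end, and S_end are chosen here for illustration.

```julia
# Final saved state copied from the solver output above:
# 2 ODE values followed by 6 sensitivity values.
z_end = [1.026505547286025, 0.9095251254958595,
         2.1626974566892803, -6.256489916332188,
         0.1883893214244694, -0.6976152811241049,
         0.5638188536930522, -1.7090441864606456]

n_states, n_params = 2, 3

# The leading entries are the ODE solution itself.
u_end = z_end[1:n_states]

# Assumed layout: the remaining entries are stacked parameter by parameter,
# so a column-major reshape puts the sensitivities for parameter j in column j.
S_end = reshape(z_end[n_states+1:end], n_states, n_params)

# First column: sensitivities of both states with respect to p[1].
S_end[:, 1]
```

The first column of S_end agrees with the last column of da shown above, which is what suggests this layout.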

For more information on the internal representation of the ODEForwardSensitivityProblem solution, see the direct forward sensitivity analysis manual page.
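
To make the idea of an augmented ODE concrete, here is a minimal, dependency-free sketch that builds the sensitivity equations for the Lotka-Volterra model by hand and checks one column against central finite differences. It is not how the library constructs the problem internally: the hand-written Jacobians, the fixed-step RK4 integrator, and the names lv, lv_aug, and rk4 are all assumptions made for illustration.

```julia
# Lotka-Volterra right-hand side, written out-of-place for brevity.
lv(u, p) = [p[1]*u[1] - p[2]*u[1]*u[2], -p[3]*u[2] + u[1]*u[2]]

# Analytic Jacobian of lv with respect to the state u.
function dfdu(u, p)
    J = zeros(2, 2)
    J[1, 1] = p[1] - p[2]*u[2]
    J[1, 2] = -p[2]*u[1]
    J[2, 1] = u[2]
    J[2, 2] = u[1] - p[3]
    return J
end

# Analytic Jacobian of lv with respect to the parameters p.
function dfdp(u, p)
    J = zeros(2, 3)
    J[1, 1] = u[1]
    J[1, 2] = -u[1]*u[2]
    J[2, 3] = -u[2]
    return J
end

# Augmented right-hand side for z = [u; vec(S)], where S = ∂u/∂p is 2×3.
function lv_aug(z, p)
    u = z[1:2]
    S = reshape(z[3:8], 2, 3)
    dS = dfdu(u, p) * S + dfdp(u, p)
    return vcat(lv(u, p), vec(dS))
end

# Classical fixed-step RK4 from t = 0 to t_end with n steps.
function rk4(rhs, z0, p, t_end, n)
    h = t_end / n
    z = copy(z0)
    for _ in 1:n
        k1 = rhs(z, p)
        k2 = rhs(z .+ (h / 2) .* k1, p)
        k3 = rhs(z .+ (h / 2) .* k2, p)
        k4 = rhs(z .+ h .* k3, p)
        z = z .+ (h / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
    end
    return z
end

p = [1.5, 1.0, 3.0]
z0 = vcat([1.0, 1.0], zeros(6))        # sensitivities start at zero
z_end = rk4(lv_aug, z0, p, 1.0, 1000)
S_end = reshape(z_end[3:8], 2, 3)      # column j holds ∂u/∂p[j] at t = 1

# Central finite differences in p[1] for comparison with S_end[:, 1].
delta = 1e-6
u_hi = rk4(lv, [1.0, 1.0], p .+ [delta, 0.0, 0.0], 1.0, 1000)
u_lo = rk4(lv, [1.0, 1.0], p .- [delta, 0.0, 0.0], 1.0, 1000)
fd_col1 = (u_hi .- u_lo) ./ (2 * delta)

# Largest discrepancy between the two estimates.
maximum(abs.(S_end[:, 1] .- fd_col1))
```

The finite-difference route needs two extra solves per parameter and loses accuracy to roundoff, whereas the augmented system yields all sensitivities in a single solve.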

Example using adjoint_sensitivities for discrete adjoints

In this example, we solve for the adjoint sensitivities of a discrete cost functional. First, let's solve the ODE and get a high-quality continuous solution:

function f(du,u,p,t)
  du[1] = dx = p[1]*u[1] - p[2]*u[1]*u[2]
  du[2] = dy = -p[3]*u[2] + u[1]*u[2]
end

p = [1.5,1.0,3.0]
prob = ODEProblem(f,[1.0;1.0],(0.0,10.0),p)
sol = solve(prob,Vern9(),abstol=1e-10,reltol=1e-10)
retcode: Success
Interpolation: specialized 9th order lazy interpolation
t: 90-element Vector{Float64}:
  0.0
  0.04354162443296621
  0.12389885142933849
  0.22523353218791048
  0.33843373358494755
  0.4690720464707251
  0.6176857089056866
  0.7795187598378055
  0.9428422295179912
  1.1037394204463045
  ⋮
  9.244737893689699
  9.335074241332878
  9.415255247041264
  9.50510551498179
  9.597534392550964
  9.701299487412248
  9.82106560483051
  9.944260157048916
 10.0
u: 90-element Vector{Vector{Float64}}:
 [1.0, 1.0]
 [1.0238868996864754, 0.9170635457394599]
 [1.078916507168589, 0.7840700680684657]
 [1.1684333040439097, 0.648181661797101]
 [1.2957042221353816, 0.5304694456892483]
 [1.4809165029418272, 0.42955350448196544]
 [1.7473500412998426, 0.3493258329777825]
 [2.1153534688823297, 0.29350101163419345]
 [2.5832559744848433, 0.26353346256414195]
 [3.1543771993181107, 0.25762589850470435]
 ⋮
 [1.6314324936468194, 3.864812724469416]
 [1.3472326751123798, 3.3687280315526102]
 [1.1805111281005654, 2.929661941303766]
 [1.0600904013068375, 2.473305341614175]
 [0.9880439518210773, 2.05974515342669]
 [0.9520591757932619, 1.6679455857549355]
 [0.9544664757660277, 1.304800623866498]
 [0.9960186134110254, 1.0163709023183924]
 [1.026344767484708, 0.9096910781836054]

Now let's calculate the sensitivity of the $\ell_2$ error against 1 at evenly spaced points in time, that is:

\[L(u,p,t)=\sum_{i=1}^{n}\frac{\Vert1-u(t_{i},p)\Vert^{2}}{2}\]

for $t_i = 0.5i$. This corresponds to assuming that the data is data[i]=1.0 at every $t_i$. For this function, notice we have that:

\[\begin{aligned} dg_{1}&=1-u_{1} \\ dg_{2}&=1-u_{2} \\ & \quad \vdots \end{aligned}\]

and thus:

dg(out,u,p,t,i) = (out.=1.0.-u)
dg (generic function with 1 method)
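
For reference, a single term of the cost can be differentiated with respect to the state directly. Writing $g_i$ for the $i$th summand (a symbol introduced here for illustration), one gets:

\[g_{i}(u)=\frac{\Vert1-u\Vert^{2}}{2} \quad\Longrightarrow\quad \frac{\partial g_{i}}{\partial u_{k}}=-(1-u_{k})=u_{k}-1\]

The dg above is written as $1-u$, which is the negative of this derivative. Whether dgdu_discrete expects the derivative itself or its negative is not stated in this tutorial, so it is worth confirming the sign for your installed version against the automatic differentiation comparison at the end of this example.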

We can omit dgdp because the cost function doesn't depend on p. If we had data, we'd just replace 1.0 with data[i]. To get the adjoint sensitivities, call:

ts = 0:0.5:10
res = adjoint_sensitivities(sol,Vern9(),t=ts,dgdu_discrete=dg,abstol=1e-14,
                            reltol=1e-14)
([87.94877760237878, 22.48841733753836], [-25.50065435744049 77.25507872062931 -93.53213951322394])

With tolerances of 1e-14, this result is highly accurate; as always, there is a tradeoff between accuracy and computation time. We can check that it almost exactly matches the automatic differentiation and numerical differentiation results:

using ForwardDiff,Calculus,ReverseDiff,Tracker
function G(p)
  tmp_prob = remake(prob,u0=convert.(eltype(p),prob.u0),p=p)
  sol = solve(tmp_prob,Vern9(),abstol=1e-14,reltol=1e-14,saveat=ts,
              sensealg=SensitivityADPassThrough())
  A = convert(Array,sol)
  sum(((1 .- A).^2)./2)
end
G([1.5,1.0,3.0])
res2 = ForwardDiff.gradient(G,[1.5,1.0,3.0])
res3 = Calculus.gradient(G,[1.5,1.0,3.0])
res4 = Tracker.gradient(G,[1.5,1.0,3.0])
res5 = ReverseDiff.gradient(G,[1.5,1.0,3.0])

and see this gives the same values.
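
As one more cross-check, the gradient of the same cost can be assembled from forward sensitivities via the chain rule $\frac{dL}{dp_j}=\sum_i (u(t_i)-1)\cdot\frac{\partial u(t_i)}{\partial p_j}$. The sketch below assumes that extract_local_sensitivities(sol,t) behaves as described in the first example (x is the state and dp[j] holds the sensitivities with respect to the jth parameter). The names lotka!, fprob, fsol, and grad_forward are chosen here for illustration.

```julia
using OrdinaryDiffEq, SciMLSensitivity

# Same Lotka-Volterra right-hand side as above, under a different name.
function lotka!(du, u, p, t)
    du[1] = p[1]*u[1] - p[2]*u[1]*u[2]
    du[2] = -p[3]*u[2] + u[1]*u[2]
    return nothing
end

p = [1.5, 1.0, 3.0]
fprob = ODEForwardSensitivityProblem(lotka!, [1.0, 1.0], (0.0, 10.0), p)
fsol = solve(fprob, Vern9(), abstol=1e-12, reltol=1e-12)

# Chain rule: accumulate (u(t_i) - 1) ⋅ ∂u(t_i)/∂p[j] over the sample times.
grad_forward = zeros(length(p))
for t in 0.0:0.5:10.0
    x, dp = extract_local_sensitivities(fsol, t)  # interpolated at time t
    for j in eachindex(p)
        grad_forward[j] += sum((x .- 1.0) .* dp[j])
    end
end

grad_forward
```

grad_forward can then be compared with res and res2. Agreement up to the solver tolerances, and up to the sign convention used for dg, would be expected. For a problem with only three parameters the forward approach is cheap, while adjoints tend to pay off as the number of parameters grows.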