Using the Shapley method with correlated inputs

A major limitation of many standard global sensitivity analysis methods (for example, Sobol indices) is that they assume independent inputs. The Shapley method is one of the few methods that can handle correlated inputs. It is a game-theoretic approach: the importance of each input is its marginal contribution to the output, averaged over all the orders in which the inputs can be added.
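To make the marginal-contribution idea concrete, here is a minimal sketch that computes exact Shapley effects for a hypothetical linear model Y = β'X with Gaussian inputs X ~ N(0, Σ). It assumes the explained variance Var(E[Y | X_S]) as the value of a set of inputs S, which for this model has a closed form, and weights each marginal contribution by |S|!(d − |S| − 1)!/d!. The helper names explained_variance and shapley_effects_exact are illustrative only and are not part of GlobalSensitivity.

```julia
using LinearAlgebra

# Explained variance Var(E[Y | X_S]) for the hypothetical model Y = β'X, X ~ N(0, Σ)
function explained_variance(β, Σ, S)
    isempty(S) && return 0.0
    w = Σ[S, :] * β                  # Cov(X_S, Y)
    return dot(w, Σ[S, S] \ w)       # w' Σ_SS⁻¹ w
end

# Exact Shapley effects: weighted marginal contributions over all subsets S not containing i
function shapley_effects_exact(β, Σ)
    d = length(β)
    sh = zeros(d)
    for i in 1:d
        others = [j for j in 1:d if j != i]
        for mask in 0:(2^(d - 1) - 1)
            # bits of `mask` select which of the other inputs belong to S
            S = [others[k] for k in 1:(d - 1) if (mask >> (k - 1)) & 1 == 1]
            s = length(S)
            weight = factorial(s) * factorial(d - s - 1) / factorial(d)
            sh[i] += weight * (explained_variance(β, Σ, vcat(S, i)) - explained_variance(β, Σ, S))
        end
    end
    return sh ./ dot(β, Σ * β)       # normalize by Var(Y) so the effects sum to 1
end

# X2 has no direct effect (β2 = 0) but is correlated with X1
sh = shapley_effects_exact([1.0, 0.0], [1.0 0.5; 0.5 1.0])   # ≈ [0.875, 0.125]
```

The example shows why the method suits correlated inputs: X2 never enters the model directly, yet it receives a share of the importance because knowing X2 already explains part of the variance through its correlation with X1. With independent inputs and β = [1, 2], the same function returns [0.2, 0.8], the familiar split of β_i²/Var(Y).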

Shapley values have become widely used in machine learning to explain the predictions of black-box models. Here we apply the Shapley method to a Scientific Machine Learning (SciML) model to understand how each parameter affects the output.

We will use a Neural ODE, built with the SciML ecosystem and trained on data simulated from the spiral ODE model. The Neural ODE learns to predict the state of the system at the sampled time points.

As a first step, let's generate the dataset.

using GlobalSensitivity, OrdinaryDiffEq, Flux, SciMLSensitivity, LinearAlgebra
using Optimization, OptimizationOptimisers, Distributions, Copulas, CairoMakie

u0 = [2.0f0; 0.0f0]
datasize = 30
tspan = (0.0f0, 1.5f0)

function trueODEfunc(du, u, p, t)
    true_A = [-0.1f0 2.0f0; -2.0f0 -0.1f0]
    du .= ((u .^ 3)'true_A)'
end
t = range(tspan[1], tspan[2], length = datasize)
prob = ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(solve(prob, Tsit5(), saveat = t))
2×30 Matrix{Float32}:
 2.0  1.9465    1.74178  1.23837  0.577126  …  1.40688   1.37023   1.29215
 0.0  0.798832  1.46473  1.80877  1.86465      0.451358  0.728681  0.972087
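The double transpose in trueODEfunc can be hard to read. As a small check using only base Julia, the sketch below confirms that the row-vector form used there equals the matrix-vector product A' * u.^3, and evaluates it at the initial condition u0 = [2, 0].

```julia
true_A = [-0.1f0 2.0f0; -2.0f0 -0.1f0]
u = [2.0f0, 0.0f0]                       # the tutorial's initial condition u0

# Row-vector form from trueODEfunc; (u .^ 3)'true_A there parses as (u .^ 3)' * true_A
rhs_row = ((u .^ 3)' * true_A)'
# Equivalent matrix-vector form: A' * u.^3
rhs_mat = transpose(true_A) * (u .^ 3)

@show rhs_mat                             # rhs_mat = Float32[-0.8, 16.0]
```

At u0 the derivative points almost entirely along the second component, which matches the rapid growth of the second row of ode_data over the first few time points.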

Now we define the neural network for the dynamics of the system. The network first cubes its input elementwise (mirroring the u .^ 3 term of the true dynamics), then applies a dense layer with 10 tanh hidden units and a dense output layer with 2 units. We use Flux's Chain function to assemble these layers. A more detailed tutorial is available here.
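The size of the parameter vector θ that is later analyzed follows directly from this architecture. A Dense(in, out) layer holds an out × in weight matrix plus an out-element bias, the x -> x .^ 3 layer has no parameters, and the two initial-condition entries u0 are prepended to the network weights. The dense_params helper below is hypothetical and only does this arithmetic.

```julia
# Trainable parameters of a dense layer: weight matrix (nout × nin) plus bias (nout)
dense_params(nin, nout) = nout * nin + nout

nn_params = dense_params(2, 10) + dense_params(10, 2)   # 30 + 22; x -> x .^ 3 adds none
θ_length = 2 + nn_params                                 # u0 is prepended to the NN weights
```

This gives 52 network parameters and 54 entries in θ, matching the 54-element vector returned by the optimizer and the d = 54 used in the sensitivity analysis.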

dudt2 = Flux.Chain(x -> x .^ 3,
    Flux.Dense(2, 10, tanh),
    Flux.Dense(10, 2))
p, re = Flux.destructure(dudt2) # p: flat vector of the initial NN weights; re rebuilds the network from such a vector
dudt(u, p, t) = re(p)(u) # rebuild the network from p on each call so gradients flow back to p
prob = ODEProblem(dudt, u0, tspan)

θ = [u0; p] # the parameter vector to optimize

function predict_n_ode(θ)
    Array(solve(prob, Tsit5(), u0 = θ[1:2], p = θ[3:end], saveat = t))
end

function loss_n_ode(θ)
    pred = predict_n_ode(θ)
    loss = sum(abs2, ode_data .- pred)
    loss
end

loss_n_ode(θ)

callback = function (state, l) #callback function to observe training
    display(l)
    return false
end

# Display the loss at the initial parameter values.
callback(θ, loss_n_ode(θ))

# use Optimization.jl to solve the problem
adtype = Optimization.AutoZygote()

optf = Optimization.OptimizationFunction((p, _) -> loss_n_ode(p), adtype)
optprob = Optimization.OptimizationProblem(optf, θ)

result_neuralode = Optimization.solve(optprob,
    OptimizationOptimisers.Adam(0.05),
    callback = callback,
    maxiters = 300)
retcode: Default
u: 54-element Vector{Float32}:
  1.9673309
  0.72916913
 -2.931701
 -0.55005187
  1.3469751
  3.0278177
  0.15054308
  0.49979755
 -0.07660793
 -1.7933686
  ⋮
  0.98326826
 -0.61498237
 -2.0834146
  1.6609653
  1.3511711
 -0.72405964
  0.37361556
 -0.7407782
 -0.4889444
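OptimizationOptimisers.Adam(0.05) applies the Adam update with learning rate 0.05 at each of the 300 iterations. As a rough illustration of a single update, here is a minimal scalar Adam sketch on a hypothetical 1-D quadratic loss, assuming the commonly used defaults β1 = 0.9, β2 = 0.999 and ϵ = 1e-8. It is not the Optimisers.jl implementation; in the tutorial the same kind of update is applied elementwise to all 54 entries of θ, with gradients supplied by Zygote.

```julia
# Scalar Adam sketch (hypothetical helper, not the Optimisers.jl implementation)
function adam_minimize(grad, x0; η = 0.05, β1 = 0.9, β2 = 0.999, ϵ = 1e-8, maxiters = 300)
    x, m, v = x0, 0.0, 0.0
    for t in 1:maxiters
        g = grad(x)
        m = β1 * m + (1 - β1) * g        # running mean of gradients
        v = β2 * v + (1 - β2) * g^2      # running mean of squared gradients
        mhat = m / (1 - β1^t)            # bias-corrected estimates
        vhat = v / (1 - β2^t)
        x -= η * mhat / (sqrt(vhat) + ϵ) # step size is roughly η while gradients keep their sign
    end
    return x
end

# Minimize the hypothetical loss f(x) = (x - 3)^2, whose gradient is 2(x - 3)
x_opt = adam_minimize(x -> 2 * (x - 3), 0.0)
```

Because the step is normalized by the running gradient magnitude, the learning rate 0.05 roughly bounds how far each parameter moves per iteration, which is one reason a fixed number of iterations such as 300 is a reasonable budget here.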

Now we will use the Shapley method to understand the impact of each parameter on the value of the loss function. We use the Shapley method from GlobalSensitivity to compute the so-called Shapley effects. First, we have to define distributions for the parameters; we use the standard normal distribution for each of them.
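The n_perms keyword refers to permutations of the inputs. To illustrate the permutation idea, here is a minimal Monte Carlo sketch that averages each input's marginal contribution over randomly sampled input orderings. It assumes a hypothetical linear model Y = β'X with Gaussian inputs, so that the explained variance Var(E[Y | X_S]) of each set S can be computed exactly; GlobalSensitivity has to estimate these conditional variances by sampling as well (presumably what n_var and n_outer relate to), so this is not its implementation. The helper names are illustrative.

```julia
using LinearAlgebra, Random

# Explained variance Var(E[Y | X_S]) for the hypothetical model Y = β'X, X ~ N(0, Σ)
function explained_variance(β, Σ, S)
    isempty(S) && return 0.0
    w = Σ[S, :] * β                      # Cov(X_S, Y)
    return dot(w, Σ[S, S] \ w)
end

# Monte Carlo Shapley effects from randomly sampled input orderings
function shapley_permutation_estimate(β, Σ, n_perms; rng = Random.default_rng())
    d = length(β)
    sh = zeros(d)
    for _ in 1:n_perms
        perm = randperm(rng, d)
        prev = 0.0                        # value of the empty set
        for k in 1:d
            cur = explained_variance(β, Σ, perm[1:k])
            sh[perm[k]] += cur - prev     # marginal contribution of input perm[k]
            prev = cur
        end
    end
    return sh ./ (n_perms * dot(β, Σ * β))
end

est = shapley_permutation_estimate([1.0, 0.0], [1.0 0.5; 0.5 1.0], 10_000;
                                   rng = MersenneTwister(1))
```

Along every permutation the contributions telescope to Var(Y), so the estimated effects always sum to 1; only how that total is split between inputs carries Monte Carlo error, which shrinks as n_perms grows.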

First, let's assume the parameters are uncorrelated. Hence the copula is given the identity matrix as its covariance (equivalently, correlation) matrix.
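Because the gsa call below uses batch = true, the loss is evaluated on all parameter samples at once. Judging from how batched_loss_n_ode indexes its argument (θ[1:2, i] and θ[3:end, i] for i in 1:size(θ, 2)), the samples arrive as a d × N matrix with one sample per column, and the function returns one loss per column. The toy function below is hypothetical and shows the same contract with a simple sum-of-squares objective, including the Float64 to Float32 conversion used in the tutorial.

```julia
# Hypothetical batched objective: column i of Θ is one parameter sample
function toy_batched_loss(Θ::AbstractMatrix)
    Θ32 = convert(AbstractArray{Float32}, Θ)      # mirror the Float64 -> Float32 conversion
    return [sum(abs2, Θ32[:, i]) for i in 1:size(Θ32, 2)]
end

Θ = [1.0 0.0; 2.0 3.0]      # d = 2 parameters, N = 2 samples (columns)
out = toy_batched_loss(Θ)    # Float32[5.0, 9.0]
```

Batching lets the real loss hand all samples to a single EnsembleProblem solved with EnsembleThreads, instead of calling the ODE solver once per sample from gsa.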

d = length(θ)
mu = zeros(Float32, d)
#covariance matrix for the copula
Covmat = Matrix(1.0f0 * I, d, d)
#the marginal distributions for each parameter
marginals = [Normal(mu[i]) for i in 1:d]

copula = GaussianCopula(Covmat)
input_distribution = SklarDist(copula, marginals)

function batched_loss_n_ode(θ)
    # The copula returns samples of `Float64`s
    θ = convert(AbstractArray{Float32}, θ)
    prob_func(prob, i, repeat) = remake(prob; u0 = θ[1:2, i], p = θ[3:end, i])
    ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
    sol = solve(
        ensemble_prob, Tsit5(), EnsembleThreads(); saveat = t, trajectories = size(θ, 2))
    out = zeros(size(θ, 2))
    for i in 1:size(θ, 2)
        out[i] = sum(abs2, ode_data .- sol[i])
    end
    return out
end

shapley_effects = gsa(
    batched_loss_n_ode, Shapley(; n_perms = 100, n_var = 100, n_outer = 10),
    input_distribution, batch = true)
GlobalSensitivity.ShapleyResult{Vector{Float64}, Vector{Float64}}([0.10620003015851633, 0.08077083693959107, -0.045121896924576695, 0.007617263021308582, -0.0677841791656659, 0.04325062551937232, -0.03339448992434615, -0.016198444983910373, 0.03843724905108222, 0.07414971941393096  …  0.028490642322988684, 0.06368304033418594, 0.004391883911206462, -0.005080123418526766, 0.020667293115705174, -0.05395890396694116, 0.07776667105926739, -0.09736249424206762, 0.0255355371818342, 0.025982676255445463], [0.3602885879185995, 0.35478182743820236, 0.3157144337117648, 0.32345017863062453, 0.2724277530484824, 0.3542422923833581, 0.3675301320868929, 0.3832933204925164, 0.49134133900574406, 0.46257784165740656  …  0.4300685700778562, 0.4225030410914641, 0.3466337422049775, 0.46070918735311117, 0.3432205178743452, 0.36163179044132193, 0.37825255306450034, 0.3217916925310428, 0.31056618761162946, 0.39503680633729515], [-0.5999656021619387, -0.6146015448392855, -0.6639221869996357, -0.6263450870947155, -0.6017425751406914, -0.6510642675520095, -0.7537535488146562, -0.7674533531492425, -0.9245917754001761, -0.8325028502345858  …  -0.8144437550296094, -0.7644229202050836, -0.6750102508105494, -0.9080701306306247, -0.6520449219180114, -0.7627572132319321, -0.6636083329471533, -0.7280742116029115, -0.5831741905369595, -0.748289464165653], [0.8123656624789714, 0.7761432187184677, 0.5736783931504823, 0.6415796131373327, 0.46617421680935955, 0.7375655185907541, 0.6869645689659639, 0.7350564631814218, 1.0014662735023405, 0.9808022890624478  …  0.8714250396755868, 0.8917890008734555, 0.6837940186329624, 0.8979098837935711, 0.6933795081494217, 0.6548394052980498, 0.819141675065688, 0.5333492231187762, 0.6342452649006279, 0.800254816676544])
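The printed ShapleyResult contains four vectors. Checking the first entries suggests that the third and fourth vectors are the lower and upper bounds of a normal-approximation 95% confidence interval, effect ± 1.96 × standard error; this reading is inferred from the printed numbers, not taken from the GlobalSensitivity documentation.

```julia
# First entries of the printed ShapleyResult (uncorrelated case)
effect, se = 0.10620003015851633, 0.3602885879185995
lower_printed, upper_printed = -0.5999656021619387, 0.8123656624789714

# Assumed normal-approximation 95% interval: effect ± 1.96 * se
lower, upper = effect - 1.96 * se, effect + 1.96 * se
```

This interval spans zero, as do all of the intervals printed above, so with n_outer = 10 the individual estimates are dominated by Monte Carlo noise. This also explains the negative values: exact Shapley effects are non-negative, and the y-axis limits of (0.0, 0.2) in the plot below hide the negative estimates.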
barplot(
    1:54, shapley_effects.shapley_effects;
    color = :green,
    figure = (; size = (600, 400)),
    axis = (;
        xlabel = "parameters",
        xticklabelrotation = 1,
        xticks = (1:54, ["θ$i" for i in 1:54]),
        ylabel = "Shapley Indices",
        limits = (nothing, (0.0, 0.2))
    )
)
[Figure: bar plot of Shapley indices for parameters θ1 to θ54, uncorrelated inputs]

Now let's introduce correlation between the parameters, using a correlation coefficient of 0.09 between every pair of parameters.
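Not every constant off-diagonal value gives a valid correlation matrix: an equicorrelation matrix has eigenvalues 1 − ρ (with multiplicity d − 1) and 1 + (d − 1)ρ, so it is positive definite only when −1/(d − 1) < ρ < 1. The sketch below checks this for d = 54 and ρ = 0.09, then, since the marginals are standard normal, draws correlated samples through a Cholesky factor, which gives the same joint distribution as SklarDist(GaussianCopula(Corrmat), marginals). It uses only Julia standard libraries and is an illustration, not how Copulas.jl samples internally.

```julia
using LinearAlgebra, Random, Statistics

d, ρ = 54, 0.09
C = fill(ρ, d, d)
C[diagind(C)] .= 1.0                    # same matrix as Corrmat, in Float64

# Eigenvalues of the equicorrelation matrix: 1 - ρ and 1 + (d - 1)ρ
λ = eigvals(Symmetric(C))
@show extrema(λ)                         # ≈ (0.91, 5.77)

# Correlated standard-normal samples: x = L z with z ~ N(0, I) and C = L L'
L = cholesky(Symmetric(C)).L
rng = MersenneTwister(0)
X = L * randn(rng, d, 20_000)            # one sample per column
Chat = cor(X; dims = 2)                  # empirical d × d correlation matrix
offdiag_mean = mean(Chat[i, j] for i in 1:d for j in 1:d if i != j)
```

The average empirical off-diagonal correlation comes out close to 0.09, and the smallest eigenvalue of 0.91 shows that Corrmat is comfortably positive definite.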

Corrmat = fill(0.09f0, d, d)
for i in 1:d
    Corrmat[i, i] = 1.0f0
end

#since the marginals are standard normal, the covariance and correlation matrices are the same
copula = GaussianCopula(Corrmat)
input_distribution = SklarDist(copula, marginals)
shapley_effects = gsa(
    batched_loss_n_ode, Shapley(; n_perms = 100, n_var = 100, n_outer = 100),
    input_distribution, batch = true)
GlobalSensitivity.ShapleyResult{Vector{Float64}, Vector{Float64}}([0.038437136597802686, 0.06969236061152527, 0.020432404326990682, 0.009561685339617878, 0.026805833587766406, 0.013245625464923343, 0.024259842275680496, 0.02286993984851298, 0.011253892109549303, 0.04807885726317191  …  0.017089246280725254, -0.003522804398626468, -0.005101911405281731, 0.011595323930388632, 0.04690571706736815, 0.03516189606910672, 0.051292109817962526, 0.015209759023975407, 0.03680408090116813, 0.011451610836473097], [0.13193062625555094, 0.141302343677217, 0.14784159493953594, 0.13418884846489373, 0.127553868350345, 0.1407559204794725, 0.11330764437004431, 0.15524851652274577, 0.12252686170697907, 0.16207757266734588  …  0.12926930672097478, 0.11721527738373595, 0.1065521134470214, 0.11313400602464212, 0.1326359827316077, 0.1459708878263884, 0.13868769910828316, 0.12081434111358248, 0.14168153971096198, 0.12576451831378688], [-0.22014689086307715, -0.20726023299582003, -0.26933712175449975, -0.2534484576515738, -0.22319974837890977, -0.26263597867484273, -0.19782314068960638, -0.2814171525360687, -0.22889875683612967, -0.269593185164826  …  -0.2362785948923853, -0.23326474807074893, -0.21394405376144368, -0.2101473278779099, -0.21306080908658295, -0.25094104407061457, -0.22053578043427244, -0.22158634955864626, -0.24089173693231733, -0.2350468450585492], [0.2970211640586825, 0.3466449542188706, 0.3102019304084811, 0.27257182833080956, 0.27681141555444255, 0.28912722960468945, 0.24634282524096734, 0.32715703223309467, 0.2514065410552283, 0.3657508996911698  …  0.27045708745383584, 0.22621913927349602, 0.2037402309508802, 0.2333379757386872, 0.30687224322131923, 0.32126483620882795, 0.3231200000701975, 0.25200586760659704, 0.31449989873465356, 0.2579500667314954])
barplot(
    1:54, shapley_effects.shapley_effects;
    color = :green,
    figure = (; size = (600, 400)),
    axis = (;
        xlabel = "parameters",
        xticklabelrotation = 1,
        xticks = (1:54, ["θ$i" for i in 1:54]),
        ylabel = "Shapley Indices",
        limits = (nothing, (0.0, 0.2))
    )
)
[Figure: bar plot of Shapley indices for parameters θ1 to θ54, pairwise correlation 0.09]