iancovert / fastshap

An amortized approach for calculating local Shapley value explanations
MIT License

Feature: SHAP Interaction FX #5

Status: Open. szvsw opened this issue 1 year ago.

szvsw commented 1 year ago

In *Consistent Individualized Feature Attribution for Tree Ensembles*, Lundberg et al. define SHAP interaction effects:

Given a set $N$ of $M$ features, for $i\neq j$,

$$
\phi_{i,j} = \sum_{S\subseteq N\setminus \{i,j\}} c(S)\,\nabla_{i,j}(S)
$$

where the contribution term is given by

$$
\nabla_{i,j}(S) = f_x(S\cup \{i,j\}) - f_x(S\cup\{i\}) - f_x(S\cup\{j\}) + f_x(S)
$$

and the weighting term is given by

$$
c(S)=\frac{|S|!\,(M-|S|-2)!}{2\,(M-1)!}
$$

For the diagonal $i=j$, set $\phi_{i,i}=\phi_i - \sum_{j\neq i}\phi_{i,j}$, where $\phi_i$ is the usual SHAP value.
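For small $M$ these definitions can be checked by brute force. The sketch below is an illustrative NumPy implementation, not FastSHAP code: the helper names and the toy game `v` are assumptions, `v(S)` plays the role of $f_x(S)$, and the cost is exponential in $M$, so it is only useful as a reference on toy problems. It also confirms that filling the diagonal as above makes each row sum to the corresponding SHAP value.

```python
from itertools import combinations
from math import factorial

import numpy as np


def subsets(items):
    """Yield every subset of items as a frozenset."""
    items = list(items)
    for r in range(len(items) + 1):
        for combo in combinations(items, r):
            yield frozenset(combo)


def shapley_values(v, M):
    """Exact Shapley values of a set function v over features 0..M-1."""
    phi = np.zeros(M)
    for i in range(M):
        others = [k for k in range(M) if k != i]
        for S in subsets(others):
            w = factorial(len(S)) * factorial(M - len(S) - 1) / factorial(M)
            phi[i] += w * (v(S | {i}) - v(S))
    return phi


def shap_interaction_values(v, M):
    """Exact SHAP interaction matrix, enumerating every S for each pair (i, j)."""
    Phi = np.zeros((M, M))
    for i, j in combinations(range(M), 2):
        others = [k for k in range(M) if k not in (i, j)]
        for S in subsets(others):
            c = factorial(len(S)) * factorial(M - len(S) - 2) / (2 * factorial(M - 1))
            nabla = v(S | {i, j}) - v(S | {i}) - v(S | {j}) + v(S)
            Phi[i, j] += c * nabla
        Phi[j, i] = Phi[i, j]  # phi_{i,j} = phi_{j,i}
    phi = shapley_values(v, M)
    # Diagonal: phi_{i,i} = phi_i - sum_{j != i} phi_{i,j} (the diagonal is still zero here)
    np.fill_diagonal(Phi, phi - Phi.sum(axis=1))
    return Phi, phi


def v(S):
    # Toy game: a pure pairwise effect between features 0 and 1, plus a main effect of feature 2
    return 2.0 * (0 in S and 1 in S) + 3.0 * (2 in S)


Phi, phi = shap_interaction_values(v, M=3)
assert np.allclose(phi, [1.0, 1.0, 3.0])
assert np.allclose(Phi, [[0, 1, 0], [1, 0, 0], [0, 0, 3]])
assert np.allclose(Phi.sum(axis=1), phi)  # rows sum to the SHAP values
```

The pairwise effect of size 2 is split evenly into $\phi_{0,1}=\phi_{1,0}=1$, while feature 2's main effect lands entirely on the diagonal.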

I'm interested in adapting the FastSHAP training loop to generate interaction effects, but I'm unsure how involved the process would be. There are two approaches I can imagine. In either one, the new explainer could internally generate just the upper-triangular portion of the matrix and then build the full symmetric matrix automatically.

  1. Use an existing FastSHAP model as the basis:
    1. Ensure that the rows and columns of the new model's predicted matrix sum to the correct SHAP values (taken from the existing FastSHAP model).
    2. Create a new loss term that enforces the $\nabla_{i,j}$ term for the off-diagonal entries by sampling random pairs $i\neq j$. I think these pairs can be sampled uniformly.
  2. Train the explainer to predict both SHAP values and SHAP interaction values simultaneously.
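Following the idea of generating only the upper triangle, one possible parameterization for approach 1 (a sketch under assumptions, not part of FastSHAP) is to predict only the strictly upper-triangular entries $\phi_{i,j}$, $i<j$, mirror them, and fill the diagonal from the existing FastSHAP values via $\phi_{i,i}=\phi_i-\sum_{j\neq i}\phi_{i,j}$. The function name `assemble_interaction_matrix` and the flat ordering of `upper` (that of `np.triu_indices(M, k=1)`) are hypothetical choices; in PyTorch, `torch.triu_indices(M, M, offset=1)` gives the corresponding index pairs.

```python
import numpy as np


def assemble_interaction_matrix(upper, shap_vals):
    """Build a symmetric (M, M) interaction matrix from upper-triangular predictions.

    upper:     predicted phi_{i,j} for i < j, in the order of np.triu_indices(M, k=1)
    shap_vals: SHAP values phi_i, e.g. from an existing FastSHAP model (length M)
    """
    M = len(shap_vals)
    rows, cols = np.triu_indices(M, k=1)
    if len(upper) != len(rows):
        raise ValueError(f"expected {len(rows)} upper-triangular entries, got {len(upper)}")
    Phi = np.zeros((M, M))
    Phi[rows, cols] = upper
    Phi[cols, rows] = upper  # mirror into the lower triangle
    # Diagonal from the definition phi_{i,i} = phi_i - sum_{j != i} phi_{i,j}
    np.fill_diagonal(Phi, shap_vals - Phi.sum(axis=1))
    return Phi


# Illustrative numbers: phi_{0,1} = 1, the other pairs 0, SHAP values (1, 1, 3)
Phi = assemble_interaction_matrix(np.array([1.0, 0.0, 0.0]), np.array([1.0, 1.0, 3.0]))
assert np.allclose(Phi, [[0, 1, 0], [1, 0, 0], [0, 0, 3]])
assert np.allclose(Phi.sum(axis=0), [1.0, 1.0, 3.0])
```

With this parameterization the sum, row, column, and symmetry constraints hold by construction, so only the interaction loss would need to be trained; the trade-off is that any error in the FastSHAP estimates flows directly into the diagonal.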

The tricky part is how to construct the interaction loss term(s). My intuition is that it looks something like the following pseudocode:

```python
# For simplicity in indexing, I'm writing this as if x is a single sample rather than a batch,
# and pred is a plain (M, M) SHAP interaction matrix rather than a batched tensor.
# interaction_explainer, fastshap, sample_ij, sample_excluding_ij and impute are placeholders.
import torch
import torch.nn.functional as F

pred = interaction_explainer(x)  # (M, M) predicted interaction matrix
shap_vals = fastshap(x)          # (M,) SHAP values from an existing FastSHAP model

# the pred matrix must sum to the same total as the SHAP values
shap_sum_loss = F.mse_loss(pred.sum(), shap_vals.sum())

# the columns and rows of the matrix must sum to the SHAP values
shap_col_loss = F.mse_loss(pred.sum(dim=0), shap_vals)
shap_row_loss = F.mse_loss(pred.sum(dim=1), shap_vals)

# symmetry loss
shap_sym_loss = F.mse_loss(pred, pred.T)

# interaction loss
a, b = sample_ij()             # a random pair of distinct features
s = sample_excluding_ij(a, b)  # (M,) 0/1 float mask for a single coalition S excluding a and b
s_a, s_b, s_ab = s.clone(), s.clone(), s.clone()
s_a[a] = 1.0                   # S ∪ {a}
s_b[b] = 1.0                   # S ∪ {b}
s_ab[[a, b]] = 1.0             # S ∪ {a, b}

null = impute(x, torch.zeros_like(s))  # f_x(∅)
baseline = impute(x, s)                # f_x(S)
baseline_with_i = impute(x, s_a)       # f_x(S ∪ {a})
baseline_with_j = impute(x, s_b)       # f_x(S ∪ {b})
complete = impute(x, s_ab)             # f_x(S ∪ {a, b})

# m @ pred @ m sums every cell (i, j) with both i and j in the coalition encoded by mask m
ifx_S = s @ pred @ s
ifx_S_i = s_a @ pred @ s_a
ifx_S_j = s_b @ pred @ s_b
ifx_S_ij = s_ab @ pred @ s_ab

loss_S = F.mse_loss(null + ifx_S, baseline)
loss_S_i = F.mse_loss(null + ifx_S_i, baseline_with_i)
loss_S_j = F.mse_loss(null + ifx_S_j, baseline_with_j)
loss_S_ij = F.mse_loss(null + ifx_S_ij, complete)
```
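A quick sanity check on these quadratic-form targets: the `ifx` terms sum every cell $(i,j)$ with both $i,j\in S$, i.e. $\mathbf{1}_S^\top \Phi\,\mathbf{1}_S$ for the interaction matrix $\Phi$. Assuming a model whose value function is purely pairwise, $f_x(S) = f_x(\emptyset) + \mathbf{1}_S^\top A\,\mathbf{1}_S$ with symmetric $A$ (an illustrative special case, not a claim about real models), the definitions above give

$$
\begin{aligned}
\nabla_{i,j}(S) &= 2A_{ij} \quad \text{for every } S\subseteq N\setminus\{i,j\},\\
\sum_{S\subseteq N\setminus\{i,j\}} c(S) &= \sum_{s=0}^{M-2}\binom{M-2}{s}\frac{s!\,(M-s-2)!}{2\,(M-1)!} = \frac{M-1}{2\,(M-1)} = \frac{1}{2},\\
\phi_{i,j} &= \tfrac{1}{2}\cdot 2A_{ij} = A_{ij} \quad (i\neq j),\\
\phi_i &= A_{ii} + \sum_{j\neq i} A_{ij} \quad\Longrightarrow\quad \phi_{i,i} = A_{ii}.
\end{aligned}
$$

So $\Phi = A$, every row sums to $\phi_i$, and $f_x(\emptyset) + \mathbf{1}_S^\top\Phi\,\mathbf{1}_S = f_x(S)$ for every $S$: all of the loss terms vanish at the true interaction values. Once $f_x$ contains third- or higher-order interactions, no symmetric matrix reproduces $f_x(S)$ for every coalition, so in general the minimizer of these losses depends on how $S$ is sampled and need not coincide with $\phi_{i,j}$.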

There may, of course, be complexities I'm missing, e.g., in how $i$ and $j$ are sampled, how $S$ is sampled, etc.

I just wanted to get your thoughts on how complex this would be to implement from a mathematical perspective: does it need a full paper's worth of work to establish the right sampling schemes, loss terms, etc., or should it be straightforward to figure out?

iancovert commented 1 year ago

I haven't thought about this before and am not an expert on interaction values, but the first equation above makes me think you could train a FastSHAP-like interaction explainer by 1) predicting all the interaction values, and 2) penalizing them via a squared error loss for the deviation from sampled $\nabla_{i,j}(S)$ terms. If you penalize using the squared error loss to samples from a distribution, the model is encouraged to predict that distribution's mean, so the theory may be as clean as that.
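One way to pin down "that distribution" (a sketch, not something from the FastSHAP codebase): the weights satisfy $\sum_{S\subseteq N\setminus\{i,j\}} c(S) = \tfrac12$, so $p(S) = 2c(S)$ is a probability distribution under which $|S|$ is uniform on $\{0,\dots,M-2\}$ and $S$ is uniform given its size. Then $\phi_{i,j} = \mathbb{E}_{S\sim p}\big[\tfrac12\nabla_{i,j}(S)\big]$, so a squared-error loss against $\tfrac12\nabla_{i,j}(S)$ with $S\sim p$ is minimized in expectation at $\phi_{i,j}$. The helpers `sample_coalition`, `nabla`, `exact_interaction` and the toy game `v` are hypothetical names used only for illustration, with `v(S)` standing in for $f_x(S)$.

```python
from itertools import combinations
from math import factorial

import numpy as np


def sample_coalition(M, i, j, rng):
    """Draw a coalition S excluding i and j with probability p(S) = 2 c(S).

    Under p, |S| is uniform on {0, ..., M-2} and S is uniform given its size.
    """
    others = np.array([k for k in range(M) if k not in (i, j)])
    size = rng.integers(0, M - 1)  # high is exclusive, so size is in 0..M-2
    return frozenset(rng.choice(others, size=size, replace=False).tolist())


def nabla(v, S, i, j):
    """Contribution term nabla_{i,j}(S) of a set function v."""
    return v(S | {i, j}) - v(S | {i}) - v(S | {j}) + v(S)


def exact_interaction(v, M, i, j):
    """phi_{i,j} by enumerating every S that excludes i and j (small M only)."""
    others = [k for k in range(M) if k not in (i, j)]
    total = 0.0
    for r in range(len(others) + 1):
        c = factorial(r) * factorial(M - r - 2) / (2 * factorial(M - 1))
        for S in combinations(others, r):
            total += c * nabla(v, frozenset(S), i, j)
    return total


def v(S):
    # Toy game (illustrative): a pairwise effect of {0, 1}, a three-way effect of {0, 1, 2},
    # and a main effect of feature 3
    return 1.0 * (0 in S and 1 in S) + 2.0 * (0 in S and 1 in S and 2 in S) + 0.5 * (3 in S)


rng = np.random.default_rng(0)
M, i, j = 4, 0, 1
# Regression target: half the sampled contribution term, so its mean is phi_{i,j}
targets = [0.5 * nabla(v, sample_coalition(M, i, j, rng), i, j) for _ in range(20_000)]
print(np.mean(targets), exact_interaction(v, M, i, j))  # both approximately 1.0 here
```

Each sampled target costs four evaluations of the (imputed) model, at $S$, $S\cup\{i\}$, $S\cup\{j\}$ and $S\cup\{i,j\}$, which is part of why generating these targets efficiently matters.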

I didn't read the code above too closely; is that similar to what you implemented?

That's not to say it's straightforward; I could see this being the subject of its own paper. I suspect it would be difficult to predict $d^2$ interaction terms rather than $d$ SHAP values, important to generate training targets on the fly as efficiently as possible, and potentially tricky to validate the quality of the learned explainer.