Open szvsw opened 1 year ago
I haven't thought about this before and am not an expert on interaction values, but the first equation above makes me think you could train a FastSHAP-like interaction explainer by 1) predicting all the interaction values, and 2) penalizing them with a squared-error loss for their deviation from sampled $\nabla_{i,j}(S)$ terms. Minimizing squared error against samples from a distribution encourages the model to predict that distribution's mean, so the theory may be as clean as that.
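As a quick sanity check on that mean-prediction claim, here's a toy NumPy demo (not FastSHAP code; the distribution and grid are made up for illustration). The samples stand in for sampled $\nabla_{i,j}(S)$ targets, and the scalar prediction `c` stands in for the explainer's output:

```python
import numpy as np

# Minimizing a squared-error loss against samples from a distribution
# recovers that distribution's mean.
rng = np.random.default_rng(0)
samples = rng.normal(loc=2.5, scale=1.0, size=10_000)

# Closed form: the c minimizing mean((c - samples)**2) is samples.mean().
c_opt = samples.mean()

# Verify numerically by scanning a coarse grid of candidate predictions.
grid = np.linspace(0.0, 5.0, 501)
losses = ((grid[:, None] - samples[None, :]) ** 2).mean(axis=1)
c_grid = grid[np.argmin(losses)]

assert abs(c_opt - c_grid) < 0.02  # grid minimizer sits at the sample mean
```

So as long as the sampled targets are unbiased (under whatever sampling scheme makes them unbiased), the squared-error minimizer is the quantity you want.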
I didn't read the above code too closely; is that similar to what you implemented?
That's not to say it's straightforward; I could see this being the subject of its own paper. I suspect it would be difficult to predict $d^2$ interaction terms rather than $d$ SHAP values, important to generate training targets on the fly as efficiently as possible, and potentially tricky to validate the quality of the learned explainer.
In *Consistent Individualized Feature Attribution for Tree Ensembles*, Lundberg et al. define SHAP interaction effects:
Given $M$ features with index set $\mathcal{M}$, for $i \neq j$,

$$\Phi_{i,j} = \sum_{S \subseteq \mathcal{M} \setminus \{i,j\}} w(S)\,\nabla_{i,j}(S),$$

where the contribution term is given by

$$\nabla_{i,j}(S) = f_x(S \cup \{i,j\}) - f_x(S \cup \{i\}) - f_x(S \cup \{j\}) + f_x(S),$$

and the weighting term is given by

$$w(S) = \frac{|S|!\,(M - |S| - 2)!}{2\,(M-1)!}.$$
For the diagonal $i = j$, set $\Phi_{i,i} = \phi_i - \sum_{j \neq i} \Phi_{i,j}$, where $\phi_i$ is the usual SHAP value.
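To make the definitions concrete, here is a brute-force check on a tiny toy cooperative game (the value function `f` and $M = 3$ are hypothetical, purely for illustration; this is not FastSHAP code). It computes the off-diagonal interaction values exactly, fills the diagonal from the SHAP values, and verifies that the matrix sums to $f(\mathcal{M}) - f(\varnothing)$:

```python
import math
from itertools import chain, combinations

M = 3  # toy game with 3 features

def f(S):
    # Hypothetical value function: main effects plus one pairwise interaction.
    S = set(S)
    v = 0.0
    if 0 in S: v += 1.0
    if 1 in S: v += 2.0
    if 0 in S and 1 in S: v += 4.0  # interaction between features 0 and 1
    if 2 in S: v += 3.0
    return v

def subsets(items):
    return chain.from_iterable(combinations(items, r) for r in range(len(items) + 1))

def shap_value(i):
    # Standard SHAP value phi_i via the exact Shapley sum.
    others = [k for k in range(M) if k != i]
    total = 0.0
    for S in subsets(others):
        w = math.factorial(len(S)) * math.factorial(M - len(S) - 1) / math.factorial(M)
        total += w * (f(set(S) | {i}) - f(S))
    return total

def interaction(i, j):
    # Off-diagonal Phi_{i,j}: weighted sum of nabla_{i,j}(S) over S not containing i, j.
    others = [k for k in range(M) if k not in (i, j)]
    total = 0.0
    for S in subsets(others):
        w = math.factorial(len(S)) * math.factorial(M - len(S) - 2) / (2 * math.factorial(M - 1))
        nabla = f(set(S) | {i, j}) - f(set(S) | {i}) - f(set(S) | {j}) + f(S)
        total += w * nabla
    return total

Phi = [[0.0] * M for _ in range(M)]
for i in range(M):
    for j in range(M):
        if i != j:
            Phi[i][j] = interaction(i, j)
for i in range(M):
    Phi[i][i] = shap_value(i) - sum(Phi[i][j] for j in range(M) if j != i)

# The matrix is symmetric, each row sums to phi_i, and the whole matrix
# sums to f(full coalition) - f(empty coalition).
total = sum(Phi[i][j] for i in range(M) for j in range(M))
assert abs(total - (f({0, 1, 2}) - f(set()))) < 1e-9
```

In this toy game the interaction term of size 4 between features 0 and 1 gets split evenly, so $\Phi_{0,1} = \Phi_{1,0} = 2$, and the diagonal entries recover the main effects.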
I'm interested in adapting the FastSHAP loop to generate interaction effects, but I'm unsure how involved the process would be. There are a few approaches I could imagine. In any of them, the new explainer could internally generate just the upper-triangular portion of the matrix and then automatically build the fully symmetric matrix from it.
One approach would be to penalize the predicted off-diagonal entries for their deviation from a sampled $\nabla_{i,j}(S)$ term, drawing random pairs $i \neq j$ (uniformly, I think). The tricky part is obviously how to construct the interaction loss term(s). My intuition is that it is something in the ballpark of the following pseudocode:
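Something like this, roughly (pseudocode with placeholder names like `v`, `explainer`, `sample_pair`; untested, and the sampling details are exactly the part I'm unsure about):

```
# Rough sketch of an interaction-explainer training step:
for x in data_loader:
    i, j = sample_pair(d)                  # uniform over pairs i != j
    S = sample_coalition(exclude={i, j})   # S likely needs to be drawn from the
                                           # interaction weighting kernel w(S),
                                           # not uniformly, for the squared-error
                                           # minimizer to match Phi_{i,j}

    # Monte Carlo target for the off-diagonal entry:
    nabla = v(x, S | {i, j}) - v(x, S | {i}) - v(x, S | {j}) + v(x, S)

    Phi_hat = explainer(x)                 # predicted upper triangle -> symmetric d x d
    loss = (Phi_hat[i, j] - nabla / 2) ** 2  # /2 if S ~ p(S) = 2 * w(S), since the
                                             # weights w(S) sum to 1/2

    loss.backward()
    optimizer.step()
```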
There might of course be some complexities that I'm totally missing in terms of how $i$ and $j$ are sampled, how $S$ is sampled, and so on.
Just wanted to get your thoughts on the mathematical complexity of implementing this: does it need a full paper's worth of work to prove out the right sampling schemes, loss term(s), and so on, or is it something that should be trivial to figure out?