iancovert / fastshap

An amortized approach for calculating local Shapley value explanations
MIT License

How to decide loss_fn when we build regression models? #7

nonameisagoodname closed this issue 1 month ago

nonameisagoodname commented 7 months ago

I have observed that the loss function KLDivLoss() is used to train the surrogate model, whereas the TensorFlow implementation uses categorical_crossentropy. I am now wondering which loss function to use when applying FastSHAP to regression models. Additionally, are there any other significant changes we should consider?

iancovert commented 7 months ago

That's a good question. We never tested FastSHAP for regression models, but I can explain how I would do it. Recall that the two steps of our approach are 1) training a surrogate model and 2) training the explainer model. Below I'll describe how each step needs to be modified.

  1. Surrogate model. For classification models, our goal is to make the surrogate classifier output $\mathbb{E}[f(x) \mid x_S]$, which can be achieved by penalizing the KL divergence between the original model's predictions on the full input and the surrogate's predictions on randomly masked inputs. (I'm more familiar with the PyTorch implementation, but the TensorFlow one may use a cross entropy loss with soft labels; this is the same as the KL divergence up to a constant.) For a regression model, we can simply swap in an MSE loss. This approach to training the surrogate is discussed in a paper that came before ours, and it's the right choice for regression tasks because it encourages the surrogate to output $\mathbb{E}[f(x) \mid x_S]$ when the predictions are real-valued ($f(x) \in \mathbb{R}$) rather than probabilities ($f(x) \in [0, 1]$).

  2. Explainer model. For this step, the explainer model is trained with the same weighted least squares loss. One small difference is that instead of outputting estimated attributions for each class, there is just one attribution per feature, so there is no summation in the loss across classes.
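For the surrogate step (step 1), here is a minimal plain-Python sketch of the two losses. The function names are hypothetical and are not the fastshap or PyTorch/TensorFlow APIs. It also checks the claim that soft-label cross entropy differs from the KL divergence only by the entropy of the target, which does not depend on the surrogate, and that squared error is minimized by the mean.

```python
import math


def kl_divergence(p, q):
    """KL(p || q) for discrete distributions given as lists of probabilities."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def soft_cross_entropy(p, q):
    """Cross entropy -sum_y p_y log q_y with soft (probability) labels p."""
    return -sum(pi * math.log(qi) for pi, qi in zip(p, q) if pi > 0)


def entropy(p):
    return -sum(pi * math.log(pi) for pi in p if pi > 0)


def surrogate_loss_classification(model_probs, surrogate_probs):
    # Target: original model's prediction on the full input x.
    # Prediction: surrogate's output on the masked input x_S.
    return kl_divergence(model_probs, surrogate_probs)


def surrogate_loss_regression(model_pred, surrogate_pred):
    # Real-valued outputs: replace the KL divergence with squared error.
    return (model_pred - surrogate_pred) ** 2


# Soft-label cross entropy = KL divergence + entropy of the target. The entropy
# term is constant w.r.t. the surrogate, so both losses give the same gradients.
p, q = [0.7, 0.2, 0.1], [0.5, 0.3, 0.2]
gap = soft_cross_entropy(p, q) - kl_divergence(p, q)  # equals entropy(p)

# For a fixed x_S, the constant prediction c minimizing the average squared
# error over model outputs f(x) is their mean, i.e. an estimate of E[f(x) | x_S].
outputs = [1.0, 2.0, 6.0]
mean = sum(outputs) / len(outputs)


def avg_loss(c):
    return sum(surrogate_loss_regression(y, c) for y in outputs) / len(outputs)
```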
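For the explainer step (step 2), a minimal sketch of the single-output weighted least squares loss for one input $x$, in plain Python. Assumptions: `value_fn(s)` is a hypothetical stand-in for the regression surrogate evaluated on $x$ with the features outside $s$ masked, and `phi` is a fixed attribution vector rather than the output of a trained explainer network. Subsets are sampled from the Shapley kernel distribution (as in KernelSHAP), and `additive_normalize` sketches the additive efficiency adjustment described in the FastSHAP paper; none of these names come from the fastshap code.

```python
import random


def sample_shapley_subset(d, rng):
    """Draw a mask s in {0, 1}^d with 0 < sum(s) < d from the Shapley kernel.

    The subset size k is drawn with probability proportional to
    (d - 1) / (k * (d - k)); a subset of that size is then chosen uniformly.
    """
    sizes = list(range(1, d))
    size_weights = [(d - 1) / (k * (d - k)) for k in sizes]
    k = rng.choices(sizes, weights=size_weights)[0]
    chosen = set(rng.sample(range(d), k))
    return [1 if i in chosen else 0 for i in range(d)]


def additive_normalize(phi, v_full, v_empty):
    """Shift all attributions equally so they sum to v(1) - v(0)."""
    shift = (v_full - v_empty - sum(phi)) / len(phi)
    return [p + shift for p in phi]


def explainer_loss_regression(value_fn, phi, d, num_samples, rng):
    """Monte Carlo estimate of the weighted least squares loss for one input x.

    phi holds one attribution per feature, so there is no sum over classes.
    """
    v_empty = value_fn([0] * d)
    total = 0.0
    for _ in range(num_samples):
        s = sample_shapley_subset(d, rng)
        approx = v_empty + sum(si * pi for si, pi in zip(s, phi))
        total += (value_fn(s) - approx) ** 2
    return total / num_samples


# Usage: a toy additive "surrogate" whose exact attributions are its coefficients.
coefs = [2.0, -1.0, 0.5]


def toy_value_fn(s):
    return 3.0 + sum(si * ci for si, ci in zip(s, coefs))


exact_loss = explainer_loss_regression(toy_value_fn, coefs, 3, 200, random.Random(0))
# Uniform attributions that satisfy efficiency but are otherwise wrong.
phi = additive_normalize([0.0, 0.0, 0.0], toy_value_fn([1, 1, 1]), toy_value_fn([0, 0, 0]))
naive_loss = explainer_loss_regression(toy_value_fn, phi, 3, 200, random.Random(0))
```

In a real setup the explainer would be a network mapping $x$ to `phi`, and this loss would be averaged over a batch of inputs before taking a gradient step.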

Let me know if that makes sense. If you end up writing code to try this out, I'm happy to take a look and see if it matches up with what I described here.

agorji commented 1 month ago

Hi @iancovert

Thank you for sharing your insights on applying FastSHAP to regression models. Your suggested approach of training a regression surrogate model and using it as the value function makes sense. However, one part of the algorithm remains ambiguous to me in the regression setting. In the explainer training algorithm, the value function conditioned on each possible label (category) is computed and added to the loss for each sampled input $x \sim p(x)$:

[image: explainer training algorithm from the FastSHAP paper]

I am still trying to figure out where this is done in fastshap.py to understand the implications when the surrogate model is modified as you suggested. However, from a fundamental perspective, shouldn't this part of the algorithm also be adapted for regression cases? I can imagine ideas such as replacing the iteration over labels with sampling from the label distribution, but I reckon this might be a heuristic that requires further analysis.
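For reference, here is the explainer objective from the FastSHAP paper as we understand it (notation may differ slightly from the algorithm in the paper). It takes an expectation over inputs, classes $y$, and subsets $\mathbf{s} \in \{0, 1\}^d$ drawn from the Shapley kernel distribution $p_{\mathrm{Sh}}$:

```math
\mathcal{L}(\theta) = \mathbb{E}_{p(\mathbf{x})}\, \mathbb{E}_{\mathrm{Unif}(y)}\, \mathbb{E}_{p_{\mathrm{Sh}}(\mathbf{s})} \Big[ \big( v_{\mathbf{x}, y}(\mathbf{s}) - v_{\mathbf{x}, y}(\mathbf{0}) - \mathbf{s}^\top \phi_{\mathrm{fast}}(\mathbf{x}, y; \theta) \big)^2 \Big],
\qquad
p_{\mathrm{Sh}}(\mathbf{s}) \propto \frac{d - 1}{\binom{d}{\mathbf{1}^\top \mathbf{s}}\, \mathbf{1}^\top \mathbf{s}\, (d - \mathbf{1}^\top \mathbf{s})} \quad \text{for } 0 < \mathbf{1}^\top \mathbf{s} < d.
```

With a single real-valued output there is only one $y$, so the expectation over classes drops out and each sampled subset contributes one squared error.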

agorji commented 1 month ago

Hi again,

I think I realized what I was missing after rereading the paper! $v_{x, y}(\mathbf{s})$ is simply the value function for the $y$-th output class, not a value function conditioned on both the model's input and its output. When a surrogate model serves as the value function, feeding it $\mathbf{x}_s$ produces the whole output vector, so the iteration over classes happens implicitly in vectorized form. Therefore, for a standard single-output regression surrogate, it should all work and nothing needs to change!
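To illustrate this point, a small plain-Python sketch (hypothetical names, not the fastshap code): one surrogate call returns $v_{x,y}(\mathbf{s})$ for every output $y$ at once, the loss sums the per-output squared errors, and with a single output the sum has exactly one term. The toy surrogate masks features by zeroing them purely for illustration.

```python
def surrogate_values(x, s):
    """Toy stand-in for a trained surrogate (hypothetical): one value per output.

    Masked features (s[i] == 0) are simply zeroed here for illustration. It has
    two outputs, like a two-class model; a single-output regression surrogate
    would return a one-element list instead.
    """
    xs = [xi * si for xi, si in zip(x, s)]
    return [sum(xs), xs[0] - xs[-1]]


def fastshap_loss(v_s, v_empty, s, phi):
    """Sum over outputs y of (v_y(s) - v_y(0) - sum_i s_i * phi[i][y]) ** 2.

    v_s, v_empty: every v_{x,y}(s) and v_{x,y}(0), each from ONE surrogate call,
    so no separate forward pass per class y is needed.
    phi: d x K nested list; phi[i][y] is feature i's attribution for output y.
    """
    total = 0.0
    for y, (vs_y, v0_y) in enumerate(zip(v_s, v_empty)):
        approx = v0_y + sum(si * row[y] for si, row in zip(s, phi))
        total += (vs_y - approx) ** 2
    return total


# Two outputs: these attributions reproduce both outputs exactly.
x = [1.0, 2.0, 3.0]
s = [1, 0, 1]
phi = [[1.0, 1.0], [2.0, 0.0], [3.0, -3.0]]
two_output_loss = fastshap_loss(surrogate_values(x, s), surrogate_values(x, [0, 0, 0]), s, phi)

# One output (regression): K = 1, so the "sum over classes" is a single term.
regression_loss = fastshap_loss([4.0], [0.0], s, [[1.0], [2.0], [2.0]])  # (4 - 3) ** 2
```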

iancovert commented 1 month ago

Hi, yes that’s exactly right! Training FastSHAP for classifiers basically means calculating the squared error loss for each class and then summing, so for regression models with a single output dimension, you just calculate the loss once and no summation is necessary.