cdpierse / transformers-interpret

Model explainability that works seamlessly with 🤗 transformers. Explain your transformers model in just 2 lines of code.
Apache License 2.0
1.27k stars · 96 forks

MultiLabelSequenceClassificationExplainer potentially bugged. #107

Closed rowanvanhoof closed 1 year ago

rowanvanhoof commented 1 year ago

Before I describe the issue, I should state that data science is complicated and I could be misunderstanding your code. I see that the multi-label explainer calls a binomial explainer once for each of your labels. I recognize that some implementations of multi-label models simply train n binomial classifiers and weigh each of their predictions. If that was the thought process when creating this package, then my issue below is likely a misunderstanding of how the code should be used. In that case, I wasn't able to trace through the code and figure out whether there is any further processing of the n outputs from the softmax functions.

I have looked into the source code here, and it looks like the multi-label explainer is actually a multi-class explainer. Multi-label problems involve data where one instance can have 0, 1, or more labels (e.g. a photo of a dog can be labeled as a dog, an animal, and a mammal), while multi-class problems involve data where each instance has exactly one label, chosen from many possible labels. The primary difference is that multi-class labels are mutually exclusive, which requires a different activation function than multi-label.

The multi-label explainer uses a softmax function to process the logits from the classifier. This is correct for multi-class, because softmax creates a probability distribution over the logits, meaning that only one label can be highly confident (probability over 0.5). Multi-label problems require a different activation function, like the sigmoid, which converts logits into probabilities that do not necessarily sum to 1, so the model can be highly confident in several labels at once.
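To make the difference concrete, here is a minimal stdlib-only sketch (the real library uses `torch.softmax`/`torch.sigmoid`; the logit values below are invented for illustration):

```python
import math

def softmax(logits):
    # Exponentiate and normalize so outputs sum to 1 (mutually exclusive labels).
    exps = [math.exp(x) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sigmoid(logits):
    # Squash each logit independently; outputs need not sum to 1.
    return [1 / (1 + math.exp(-x)) for x in logits]

logits = [2.0, 1.5, -3.0]  # e.g. hypothetical scores for "dog", "mammal", "fish"

print(softmax(logits))  # at most one value can exceed 0.5
print(sigmoid(logits))  # several values can exceed 0.5 at once
```

With these logits, softmax yields roughly [0.62, 0.38, 0.004], while sigmoid yields roughly [0.88, 0.82, 0.05]: under sigmoid the model is confident in two labels simultaneously, which is exactly the multi-label behavior.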

Additionally, I noticed that the multi-label explainer makes a prediction once for each label; however, multi-label classifiers give the logits for all labels in one prediction. Calling predict with the same input 6 times in a row produces 6 identical sets of logits, so I am wondering if this is redundant. I am aware that some other explainers like eli5 have to do this to accurately compute the attributions for several classes, but I wasn't sure if the repetition here is for the same reason.

I modified the explainer to run 6 times using the sigmoid function built into PyTorch instead of the softmax function, and it works as expected. If this is a bug, I can submit a pull request with the changes I have made.

thomasgirault commented 1 year ago

Hi @rowanvanhoof, I think you are right. The code was designed for a multi-class problem and not for multi-label. I don't work on this repository, but I would be really interested in your code (pull request, fork?).

SimonLinnebjerg commented 1 year ago

Hi @thomasgirault, I also agree that the multi-label model should be implemented with sigmoid instead of softmax on the output. I forked the repo and made these changes. You can have a look and clone/fork it if you're interested.

I also made a minor change to the ordering of labels (as I think there's a bug with these too). I believe the ordering of the labels should follow the ids in the label2id dict. Without sorting based on the id, the labels are swapped in the output HTML file generated with the visualize() function.
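The ordering fix described above can be sketched like this (the `label2id` mapping here is invented for illustration; a real `transformers` config exposes `config.label2id`):

```python
# Hypothetical mapping: dict insertion order does not match the id order,
# which is the situation that swaps labels in the visualize() output.
label2id = {"animal": 1, "dog": 0, "mammal": 2}

# Sort label names by their id so they line up with logit positions.
ordered_labels = [name for name, _ in sorted(label2id.items(), key=lambda kv: kv[1])]

print(ordered_labels)  # ['dog', 'animal', 'mammal']
```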

Feel free to write me if you have questions or pointers :)

thomasgirault commented 1 year ago

Thank you @SimonLinnebjerg, I will try your code today.

SimonLinnebjerg commented 1 year ago

Cool, but beware: the changes I made will make every SequenceClassificationExplainer use sigmoid instead of softmax, meaning it only works for multi-label classification and not multi-class.

thomasgirault commented 1 year ago

Hi @SimonLinnebjerg, I tried your code, and your contribution with sigmoid is working for me. In order to handle both multi-label and multi-class, we could add a condition on problem_type such as:

```python
if self.model.config.problem_type == "multi_label_classification":
    self.pred_probs = torch.sigmoid(preds)[0][self.selected_index]
    return torch.sigmoid(preds)[:, self.selected_index]
self.pred_probs = torch.softmax(preds, dim=1)[0][self.selected_index]
return torch.softmax(preds, dim=1)[:, self.selected_index]
```

SimonLinnebjerg commented 1 year ago

@thomasgirault good suggestion. But problem_type is not required to be set in a transformers model (at least I don't set it), so I'm not sure if we should rely on it.
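One way to cope with that is a defensive check that falls back to softmax behaviour when the field is missing or None. A minimal sketch, assuming a stand-in `FakeConfig` object in place of a real `transformers` config (where `problem_type` defaults to None):

```python
class FakeConfig:
    problem_type = None  # transformers leaves this unset by default

def is_multi_label(config):
    # Treat only an explicit "multi_label_classification" as multi-label;
    # anything missing, None, or different falls back to softmax behaviour.
    return getattr(config, "problem_type", None) == "multi_label_classification"

print(is_multi_label(FakeConfig()))  # False
```

This keeps the problem_type condition as an opt-in rather than something every model config has to provide.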

thomasgirault commented 1 year ago

Also, the execution can be really slow even with a dozen labels. Maybe we could limit the explanations to labels having a score >= 0.5?
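A sketch of that idea: run the cheap sigmoid pass first, then only pay the per-label attribution cost for labels that clear the threshold. The logit values and label names below are hypothetical:

```python
import math

def sigmoid(x):
    return 1 / (1 + math.exp(-x))

# Hypothetical logits from a single forward pass over all labels.
logits = {"dog": 2.0, "animal": 1.5, "fish": -3.0}

# Only labels scoring >= 0.5 would go on to the expensive attribution step.
selected = [label for label, logit in logits.items() if sigmoid(logit) >= 0.5]

print(selected)  # ['dog', 'animal']
```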

SimonLinnebjerg commented 1 year ago

I think a proper solution would be to create a new class called MultiClassClassificationExplainer and pass the softmax function into the SequenceClassificationExplainer, and do the same for the current MultiLabelClassificationExplainer but with sigmoid passed in.

With regard to the execution speed, it's not really an issue for me as I don't have many labels.

For my small use case, I think I will just leave it as it is for now. At some point I might find the time to implement this properly and do a proper PR.

cdpierse commented 1 year ago

Hi @thomasgirault, @SimonLinnebjerg, and @rowanvanhoof, thanks very much for taking a look at all of this. It's been on my todo list for quite some time and I haven't been able to find the time, so I really appreciate you all working on it. From what I can gather, it seems I did make a mistake with the activation function for multi-label. I think part of the reason I let it pass was that it doesn't affect the attribution scores so much as the displayed probabilities. But it is definitely worth fixing.

@SimonLinnebjerg I took a look at your solution and it seems great. I think we don't even need to create a new class; to me the simplest solution is to make the MultiLabelClassificationExplainer implement its own _forward method rather than inheriting that of the SequenceClassificationExplainer. This, on top of the changes you've made to the label order, should do the trick.

With regard to the speed, it is an unfortunate part of this type of explanation and attribution method: it computes attribution for one class at a time, so those attributions need to be calculated class by class. If you have access to a GPU, or Google Colab with a GPU runtime, this greatly improves the speed.

Thanks again to both of you for all the work and discussion you are doing around this. Hugely appreciated.

SimonLinnebjerg commented 1 year ago

Made the simplest changes I was able to come up with in PR #120.

I was not able to implement only the _forward function as per @cdpierse's suggestion.

cdpierse commented 1 year ago

Hi @SimonLinnebjerg, I saw your PR, thanks for doing that. If possible, though, I would like to avoid adding an additional explainer and changing the sequence classification explainer. I have a branch that seems to be working for me, where I implement a custom _forward method and monkey patch it in place of the SequenceClassificationExplainer's _forward method.

What do you think of this? Monkey patching is a little hacky, but it does keep the changes minimal.
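The monkey-patching idea can be illustrated with a toy example. The class and method names below are invented stand-ins, not the library's actual classes, and the forward passes are plain-Python sigmoid/softmax rather than the real attribution machinery:

```python
import math

class BaseExplainer:
    def _forward(self, logits):
        # Parent behaviour: softmax over all logits (multi-class).
        exps = [math.exp(x) for x in logits]
        total = sum(exps)
        return [e / total for e in exps]

class MultiLabelExplainer(BaseExplainer):
    pass

def _sigmoid_forward(self, logits):
    # Replacement behaviour: independent sigmoids (multi-label).
    return [1 / (1 + math.exp(-x)) for x in logits]

# Patch only the subclass; BaseExplainer keeps its softmax behaviour.
MultiLabelExplainer._forward = _sigmoid_forward
```

The appeal is that the parent class is untouched, so existing multi-class users see no behaviour change; the cost is that the override lives outside the class body, which is the "hacky" part.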

Below is a screenshot of the output of my branch's changes:

[Screenshot: Screenshot 2022-12-16 at 14 54 25]

SimonLinnebjerg commented 1 year ago

@cdpierse I think that is great, and it seems to work exactly the way I would want it to.

I suppose then, for the multi-class classification case, you should use the SequenceClassificationExplainer class? If so, I would just like to point out that the SequenceClassificationExplainer's visualize only outputs an explanation for the highest-probability class, and not the lower-probability classes. Not that there is anything wrong with that; I suppose it is just a design choice.

Great stuff! :thumbsup:

cdpierse commented 1 year ago

Hi, finally got around to publishing a release for this to PyPI with version 0.10.0. A bit late, but better late than never 😬

SimonLinnebjerg commented 1 year ago

> Hi, finally got around to publishing a release for this to PyPI with version 0.10.0. A bit late, but better late than never 😬

Awesome!