pytorch / captum

Model interpretability and understanding for PyTorch
https://captum.ai
BSD 3-Clause "New" or "Revised" License

How to improve interpretability for multi-label classification BERT models? #1027

Status: Open · mgh1 opened this issue 1 year ago

mgh1 commented 1 year ago

❓ Questions and Help

Our Use-Case

Thanks to the authors and contributors for such an amazing project! Our use case is to use Captum to help visualize word importance for our BERT-based multi-label classification model.

Our Problem

We can only make this work well for single-label classification models. For our multi-label model the results are poor: irrelevant words are highlighted as important, which does not happen with the single-label model.

Our Model

Our BERT model is fine-tuned on over a million records with 125 classes, and every record has at least one class label.
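For context, here is one common way such labels are encoded, assuming (this is not stated in the issue) that the multi-label model uses independent per-class sigmoids trained with BCEWithLogitsLoss, unlike a single-label model's softmax over competing classes. It is a sketch only; the label sets and logits are placeholders, not data from this issue.

```python
import torch
import torch.nn as nn

NUM_CLASSES = 125  # number of classes stated in the issue


def to_multi_hot(label_sets, num_classes=NUM_CLASSES):
    """Turn per-record lists of class indices into a float multi-hot matrix."""
    targets = torch.zeros(len(label_sets), num_classes)
    for row, labels in enumerate(label_sets):
        # Every record carries at least one class, as described in the issue.
        assert len(labels) > 0, "each record needs at least one label"
        targets[row, list(labels)] = 1.0
    return targets


# Hypothetical label sets for three records.
label_sets = [[3], [0, 17, 42], [124, 5]]
targets = to_multi_hot(label_sets)

# Assumption: one independent sigmoid per class, trained with BCEWithLogitsLoss.
logits = torch.randn(len(label_sets), NUM_CLASSES)  # stand-in for the BERT classifier output
loss = nn.BCEWithLogitsLoss()(logits, targets)
probs = torch.sigmoid(logits)  # each class scored on its own; rows need not sum to 1
```

Because each class score is independent here, an attribution for one class does not have to "compete" with the other classes the way a softmax output does.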

Our Understanding

To check our understanding of how to use Captum for multi-label classification: we need to call attribute on each class individually, passing that class's index as the target argument.

Our Evaluation and Findings

To evaluate this, we trained both a multi-label and a single-label BERT classification model and used Captum to compare their visualizations for the same class.

We found that, for the same class, the multi-label model's visualization is significantly worse than the single-label model's. For example, the multi-label model highlights irrelevant words, even punctuation, whereas the single-label model's visualization concentrates on the key words and gives much more intuitive highlighting.
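One hedged way to make this comparison less subjective than eyeballing colors is to score the two attribution vectors for the same class against each other. compare_token_attributions is a hypothetical helper, and the scores below are invented purely to exercise it; they are not results from this issue.

```python
import torch
import torch.nn.functional as F


def compare_token_attributions(reference, candidate, tokens, k=3):
    """Hypothetical helper: compare two per-token attribution vectors for one class.

    Returns the cosine similarity, the tokens both vectors rank in their top k,
    and the tokens only the candidate ranks in its top k.
    """
    similarity = F.cosine_similarity(reference, candidate, dim=0).item()
    top_ref = {tokens[i] for i in reference.topk(k).indices.tolist()}
    top_cand = {tokens[i] for i in candidate.topk(k).indices.tolist()}
    return similarity, top_ref & top_cand, top_cand - top_ref


# Invented scores for illustration only (not results from this issue).
tokens = ["the", "invoice", "was", "late", ",", "again", "."]
single_label = torch.tensor([0.02, 0.70, 0.10, 0.60, 0.01, 0.30, 0.02])
multi_label = torch.tensor([0.05, 0.40, 0.10, 0.20, 0.50, 0.10, 0.45])
similarity, shared, multi_only = compare_token_attributions(single_label, multi_label, tokens)
```

With these made-up numbers, multi_only contains only punctuation, which is the kind of pattern described above; aggregating the helper over many records per class would give a rougher but more systematic picture than a single example.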

Help Needed

Is it possible to get visualizations for the multi-label model that are as good as those we get for the single-label model? Any tips or guidance would be much appreciated.

NarineK commented 1 year ago

@mgh1, as you mentioned, with Captum you can call attribute for each target task. Do you have a Colab notebook you can share so we can take a look? The visualization should be meaningful for both single-task and multi-task use cases.