cdpierse / transformers-interpret

Model explainability that works seamlessly with 🤗 transformers. Explain your transformers model in just 2 lines of code.
Apache License 2.0

MultiClass True Labels: Bug Fix #148

Open elemets opened 3 months ago

elemets commented 3 months ago

When visualising the explainability results of a multiclass model, I can't work out whether there is a way to display the true labels correctly. Setting the true_class argument of the explainer's .visualize() method applies that single value to every row's True Label instead of setting each row individually. [Screenshot 2024-06-11 at 10 03 33 AM]

I assume this behaviour exists because the method was designed for binary classification. Inside .visualize() it is easy to add multiclass support: true_class needs to be indexed with i in the multiclass case. I edited the method as follows to make it work:

    def visualize(self, html_filepath: str = None, true_class: str = None):
        """
        Visualizes word attributions. If in a notebook table will be displayed inline.

        Otherwise pass a valid path to `html_filepath` and the visualization will be saved
        as a html file.

        If the true class is known for the text that can be passed to `true_class`

        """
        tokens = [token.replace("Ġ", "") for token in self.decode(self.input_ids)]

        score_viz = [
            self.attributions[i].visualize_attributions(  # type: ignore
                self.pred_probs_list[i],
                "",  # including a predicted class name does not make sense for this explainer
                (
                    "n/a" if not true_class else true_class[i]
                ),  # no true class name for this explainer by default
                self.labels[i],
                tokens,
            )
            for i in range(len(self.attributions))
        ]

These are the results I now get: [Screenshot 2024-06-11 at 10 17 24 AM]

I hope this is useful. If this solution seems fine, we can integrate it. Note that when passing a NumPy array to true_class, it needs to be wrapped in list(); otherwise the "not true_class" check throws an error about the truth value of NumPy arrays.
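One possible way to avoid the list() wrapping is to compare against None explicitly instead of relying on truthiness. The snippet below is only a sketch under that assumption: true_label_at is a hypothetical helper, not part of transformers-interpret, and ArrayLike is a stand-in that mimics how a multi-element NumPy array raises ValueError when its truth value is tested, so the example runs without NumPy installed.

```python
from typing import Optional, Sequence


def true_label_at(true_class: Optional[Sequence[str]], i: int) -> str:
    """Return the true label for output index i, or "n/a" if none was given.

    Hypothetical helper for illustration; not part of transformers-interpret.
    """
    # Compare against None explicitly: `not true_class` calls bool() on the
    # container, which raises ValueError for multi-element NumPy arrays.
    if true_class is None:
        return "n/a"
    return str(true_class[i])


class ArrayLike:
    """Stand-in that mimics a NumPy array's ambiguous truth value."""

    def __init__(self, items):
        self._items = list(items)

    def __getitem__(self, i):
        return self._items[i]

    def __bool__(self):
        raise ValueError("The truth value of an array is ambiguous")


labels = ArrayLike(["0", "1", "0"])

# One true label per output class, selected by index.
per_label = [true_label_at(labels, i) for i in range(3)]

# No true labels supplied: every row falls back to "n/a".
no_labels = [true_label_at(None, i) for i in range(3)]
```

Unlike the truthiness check, this version does not treat an empty list as "no labels", so an empty true_class would raise an IndexError rather than fall back to "n/a".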