pytorch / captum

Model interpretability and understanding for PyTorch
https://captum.ai
BSD 3-Clause "New" or "Revised" License

All green/all red word importance visualisation #520

Closed robinvanschaik closed 3 years ago

robinvanschaik commented 3 years ago

Hi all,

I have fine-tuned a 9-class text classifier with Flair, using multilingual sentence embeddings. Its performance is good, but I would like to add model transparency, both for debugging and for explaining predictions to other stakeholders.

I have managed to create a wrapper for my Flair model and reworked the forward method so that I can use Captum's LayerIntegratedGradients, similar to the BERT tutorial provided by Captum. However, I recognize I may have Frankenstein'd this.

In my forward method I return the logits, as suggested in https://github.com/pytorch/captum/issues/355#issuecomment-619610044, and calculate the probability outside the forward function.
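A small illustration of why attributing the logits rather than the probability can help: when the softmax is saturated, the top probability barely moves as its own logit changes, because that derivative is p * (1 - p). The sketch below is plain Python with made-up logits; it only demonstrates the saturation effect and is not part of the Flair or Captum code.

```python
import math

def softmax(logits):
    # Subtract the max logit for numerical stability
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

logits = [10.0, 0.0]          # made-up, strongly confident prediction
probs = softmax(logits)
conf = max(probs)
idx = probs.index(conf)

# d(p_top)/d(logit_top) = p * (1 - p): tiny when p is close to 1,
# whereas d(logit_top)/d(logit_top) is always 1.
sensitivity = conf * (1 - conf)
print(idx, conf, sensitivity)
```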

I am getting close to an output, but the word importance visualisation is throwing me off: in my case it shows either all-green or all-red word importances. I have also noticed that the attribution scores can change immensely when I increase the number of steps from 50 to 200 to 7000.
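For context on the n_steps sensitivity: Integrated Gradients is approximated by a Riemann sum over n_steps points on the straight path from the baseline to the input, and the convergence delta Captum returns measures how far the summed attributions are from F(input) - F(baseline) (the completeness property). On a well-behaved model the attributions should settle down as n_steps grows, so large swings between 200 and 7000 steps hint at a problem other than the step count. The toy below assumes a hand-written function f with an analytic gradient and uses the left Riemann rule (one of the methods Captum offers); here the delta shrinks exactly as -7/n.

```python
def f(x0, x1):
    # Toy scalar "model": f(x) = x0^2 + 3*x0*x1
    return x0 ** 2 + 3 * x0 * x1

def grad_f(x0, x1):
    # Analytic gradient of f
    return (2 * x0 + 3 * x1, 3 * x0)

def integrated_gradients(x, baseline, n_steps):
    """Left Riemann approximation of IG along the straight-line path."""
    totals = [0.0, 0.0]
    for k in range(n_steps):
        alpha = k / n_steps  # left endpoints: 0, 1/n, ..., (n-1)/n
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        g = grad_f(*point)
        for i in range(2):
            totals[i] += g[i]
    # (input - baseline) * average gradient along the path
    return [(xi - b) * t / n_steps for xi, b, t in zip(x, baseline, totals)]

def convergence_delta(x, baseline, n_steps):
    attr = integrated_gradients(x, baseline, n_steps)
    # Completeness: the attributions should sum to f(x) - f(baseline)
    return sum(attr) - (f(*x) - f(*baseline))

x, baseline = (1.0, 2.0), (0.0, 0.0)
for n in (10, 100, 1000):
    print(n, integrated_gradients(x, baseline, n), convergence_delta(x, baseline, n))
```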

[Screenshot 2020-11-10 at 19 12 53]

To aggregate the attributions per word, I use the following function:

def summarize_attributions(attributions):
    "Helper function for calculating word attributions."
    attributions = attributions.sum(dim=-1).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    return attributions
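What summarize_attributions does, on a toy array: LayerIntegratedGradients on the embedding layer returns one attribution per embedding dimension, with shape (1, seq_len, hidden_size) for a single sentence. Summing the last axis gives one score per token, and dividing by the L2 norm rescales the scores without changing their signs. Below is a NumPy mirror of the function; the name summarize_attributions_np and the random input are illustrative only.

```python
import numpy as np

def summarize_attributions_np(attributions):
    # attributions: (1, seq_len, hidden_size) -> one score per token: (seq_len,)
    scores = attributions.sum(axis=-1).squeeze(0)
    # L2-normalise; the sign of each token's score is preserved
    return scores / np.linalg.norm(scores)

rng = np.random.default_rng(0)
attr = rng.normal(size=(1, 5, 8))   # toy: 5 tokens, 8 embedding dims
scores = summarize_attributions_np(attr)
print(scores.shape)  # (5,)
```

The result has one entry per token, so whatever is passed as the tokens to display should also have one entry per token.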

But I am unsure whether I have missed something in calculating the attributions, or made some other obvious mistake.

For transparency's sake, here is the code I have so far.


from flair.models import TextClassifier
from flair.data import Sentence
import torch
import torch.nn as nn

from transformers import AutoTokenizer

from captum.attr import (
    InterpretableEmbeddingBase,
    IntegratedGradients,
    LayerIntegratedGradients,
    TokenReferenceBase,
    configure_interpretable_embedding_layer,
    remove_interpretable_embedding_layer
)

from captum.attr import visualization as viz

# Load our pre-trained text-classifier.
model_path = "./best-model.pt"

# Flair class model
flair_model = TextClassifier.load(model_path)
# Actual Roberta model. 
model = flair_model.document_embeddings.model

class FlairModelWrapper(nn.Module):

    def __init__(self, text_model, layers: str = "-1"):
        super(FlairModelWrapper, self).__init__()
        # Use the model passed in, not the global variable.
        self.flair_model = text_model
        # Shorthand for the actual PyTorch model.
        self.model = text_model.document_embeddings.model

        # Split the name to automatically grab the right tokenizer.
        self.model_name = text_model.document_embeddings.get_names()[0].split('transformer-document-')[-1]
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

        self.label_dictionary = self.flair_model.label_dictionary

        if torch.cuda.is_available():
            self.device = torch.device("cuda:0")
        else:
            self.device = torch.device("cpu")

        self.num_classes = len(self.flair_model.label_dictionary)
        self.initial_cls_token: bool = False

        if layers == 'all':
            # send mini-token through to check how many layers the model has
            hidden_states = self.model(torch.tensor([1], device=self.device).unsqueeze(0))[-1]
            self.layer_indexes = [int(x) for x in range(len(hidden_states))]
        else:
            self.layer_indexes = [int(x) for x in layers.split(",")]

    def forward(self, input_ids):
        # Run the input embeddings through all the layers.
        hidden_states = self.model(input_ids=input_ids)[-1]

        # BERT has an initial CLS token,
        # meaning that the first token contains the classification.
        # Otherwise, the last token is used.
        index_of_CLS_token = 0 if self.initial_cls_token else input_ids.shape[1] - 1

        # Keep the batch dimension: LayerIntegratedGradients calls forward with
        # a batch of n_steps interpolated inputs, not a single example.
        cls_embeddings_all_layers = \
            [hidden_states[layer][:, index_of_CLS_token] for layer in self.layer_indexes]

        # Shape: [#examples, hidden_size * len(layer_indexes)]
        output_embeddings = torch.cat(cls_embeddings_all_layers, dim=-1)

        # https://github.com/pytorch/captum/issues/355#issuecomment-619610044
        # Sigmoid might lead to very small probability scores and it's better to attribute the logits to the inputs.
        # We call the decoder that is used by Flair (nn.Linear).
        label_scores = self.flair_model.decoder(output_embeddings)

        # Captum expects [#examples, #classes] as dimensions in order to use target indices.
        return label_scores

# Initialise the wrapper and Captum's LayerIntegratedGradients to calculate attributions.
flair_model_wrapper = FlairModelWrapper(flair_model)

# Shorthand for embeddings layer in wrapper.
wrapper_embeddings = flair_model_wrapper.flair_model.document_embeddings.model.embeddings
lig = LayerIntegratedGradients(flair_model_wrapper, wrapper_embeddings)

# Empty list to append the results.
vis_data_records_ig = []

def interpret_sentence(flair_model_wrapper, sentence, label=1):
    """
    We can visualise the attributions made by making use of Pytorch Captum.
    Inputs:
    flair_model_wrapper: the wrapper around the Flair model to enable Captum.
    sentence: the sentence we want to check.
    label: the ground truth class-id of the sentence.
    """

    # To maintain consistency with Flair, we apply the same tokenization steps.
    # Needs to be in a list for the visualiser.
    tokenized_sentence = [sentence.to_tokenized_string()]

    # This calculates the token input IDs tensor for the model.
    input_ids = flair_model_wrapper.tokenizer.encode(tokenized_sentence[0],
                                                     add_special_tokens=False,
                                                     max_length=flair_model_wrapper.tokenizer.model_max_length,
                                                     truncation=True,
                                                     return_tensors="pt")

    # The input IDs are passed to the embedding layer of the model.

    # It is better to return the logits for Captum.
    # https://github.com/pytorch/captum/issues/355#issuecomment-619610044
    # Thus we calculate the softmax afterwards.
    # For now, I take the first (and only) example and run one sentence at a time.
    softmax = torch.nn.functional.softmax(flair_model_wrapper(input_ids)[0], dim=0)
    # Return the confidence and the class ID of the top predicted class.
    conf, idx = torch.max(softmax, 0)

    attributions_ig, delta = lig.attribute(input_ids,
                                           n_steps=7000,
                                           #method = "riemann_trapezoid",
                                           return_convergence_delta=True,
                                           target=label)

    print('pred: ', idx.item(), '(', '%.2f' % conf.item(), ')', ', delta: ', abs(delta))

    add_attributions_to_visualizer(word_attributions=attributions_ig,
                                   tokens=tokenized_sentence,
                                   pred_prob=conf.item(),
                                   pred_ind=idx.item(),
                                   true_class=label,
                                   delta=delta,
                                   vis_data_records=vis_data_records_ig)

def summarize_attributions(attributions):
    "Helper function for calculating word attributions."
    attributions = attributions.sum(dim=-1).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    return attributions

def add_attributions_to_visualizer(word_attributions, tokens, pred_prob, pred_ind, true_class, delta, vis_data_records):
    attributions = summarize_attributions(word_attributions)

    # Store a couple of samples in an array for visualization purposes.
    vis_data_records.append(
        viz.VisualizationDataRecord(word_attributions=attributions,
                                    pred_prob=pred_prob,
                                    pred_class=pred_ind,
                                    true_class=true_class,
                                    attr_class=true_class,
                                    attr_score=attributions.sum(),
                                    raw_input=tokens,
                                    convergence_score=delta))

# Look up the label name for class index 0 (displayed when run in a notebook).
flair_model_wrapper.label_dictionary.get_item_for_index(0)

sentence_to_classify = Sentence("Hallo wereld. Waar is mijn order? Ik heb tot nu toe nog niks ontvangen?")

interpret_sentence(flair_model_wrapper, sentence_to_classify, label=0)
viz.visualize_text(vis_data_records_ig)
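One thing worth checking with a wrapper like this: LayerIntegratedGradients evaluates the forward function on a batch of interpolated inputs (the input is expanded n_steps times, optionally chunked by internal_batch_size), so a forward that reads only example [0] and reshapes to (1, num_classes) silently drops most of the path. Below is a quick shape check using hypothetical NumPy toy forwards (toy_forward_first_only, toy_forward_batched, and random decoder weights W); it is a sketch of the check, not the real Flair model.

```python
import numpy as np

NUM_CLASSES = 9
rng = np.random.default_rng(0)
W = rng.normal(size=(16, NUM_CLASSES))  # toy decoder weights (hypothetical)

def toy_forward_first_only(embeddings):
    # Reads only example [0] and reshapes to (1, num_classes)
    last_token = embeddings[0, -1]            # (hidden,)
    return (last_token @ W).reshape(1, NUM_CLASSES)

def toy_forward_batched(embeddings):
    # Keeps the batch dimension: (batch, seq_len, hidden) -> (batch, num_classes)
    last_token = embeddings[:, -1]            # (batch, hidden)
    return last_token @ W

def check_forward(forward_fn, n_steps=50, seq_len=18, hidden=16):
    # Simulate the batch of n_steps interpolated inputs Captum would pass in
    batch = rng.normal(size=(n_steps, seq_len, hidden))
    out = forward_fn(batch)
    return out.shape == (n_steps, NUM_CLASSES)

print(check_forward(toy_forward_first_only))  # False
print(check_forward(toy_forward_batched))     # True
```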
robinvanschaik commented 3 years ago

Hah. I finally found my mistake. The sanity checks in this colab helped me.

sentence = Sentence("Hallo wereld. Waar is mijn order? Ik heb tot nu toe nog niks ontvangen?")

tokenized_sentence = [sentence.to_tokenized_string()]

# Returns 
# ['Hallo wereld . Waar is mijn order ? Ik heb tot nu toe nog niks ontvangen ?']

I was passing this list into the raw_input parameter of VisualizationDataRecord.

However, it is just a list containing one element (the whole sentence), not one entry per token. Passing the actual tokens instead fixed the visualisation:

all_tokens = flair_model_wrapper.tokenizer.convert_ids_to_tokens(input_ids[0])
# Returns the following:
['▁Hallo',
 '▁wereld',
 '▁',
 '.',
 '▁Waar',
 '▁is',
 '▁mijn',
 '▁order',
 '▁?',
 '▁Ik',
 '▁heb',
 '▁tot',
 '▁nu',
 '▁toe',
 '▁nog',
 '▁niks',
 '▁ontvangen',
 '▁?']
[Screenshot 2020-11-12 at 13 17 39]
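Why a single-element raw_input gives an all-green or all-red row: as far as I can tell from Captum's visualization code, each entry of raw_input is paired with the attribution score at the same position, and surplus scores are ignored. A one-string list therefore gets coloured entirely by the first token's score (positive shows as green, negative as red). The toy below mimics that pairing in plain Python; pair_tokens_with_scores and the scores are hypothetical, not Captum's implementation.

```python
def pair_tokens_with_scores(tokens, scores):
    # Hypothetical stand-in for the visualiser's pairing: zip() stops at the
    # shorter sequence, so surplus scores are silently dropped.
    assert len(tokens) <= len(scores)
    return [(tok, "green" if s > 0 else "red" if s < 0 else "white")
            for tok, s in zip(tokens, scores)]

scores = [0.41, -0.12, 0.05, -0.30]       # made-up, one score per subword token
whole_sentence = ["Hallo wereld . Waar"]   # one-element list, like tokenized_sentence
subword_tokens = ["▁Hallo", "▁wereld", "▁", "."]

print(pair_tokens_with_scores(whole_sentence, scores))
# [('Hallo wereld . Waar', 'green')]
print(pair_tokens_with_scores(subword_tokens, scores))
```

With one token per score, each token gets its own colour instead of the whole sentence sharing the first score's sign.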