pytorch / captum

Model interpretability and understanding for PyTorch
https://captum.ai
BSD 3-Clause "New" or "Revised" License

Captum for BERT #150

Closed felicitywang closed 4 years ago

felicitywang commented 5 years ago

Hi, thanks for the great work. The LSTM tutorial looks very nice. Are there any suggestions on how to use Captum for Transformer-based / BERT-like pre-trained contextualized word embeddings? If I want to see the attribution of each token in the word embedding layer, do I also need the FFN layer used for fine-tuning on downstream tasks in order to get the gradients? The current code is implemented with torch/text; I would really appreciate it if you could give some hints on how to integrate it with BERT models (e.g. huggingface/transformers).

Thank you.

NarineK commented 5 years ago

@felicitywang, thank you for the question. This is something that has high priority on the list. Yes, it will work in combination with downstream tasks. I have to look closer into this, but we will need to compute the gradients of any output that we choose with respect to those pre-trained embedding vectors.

I'll hopefully have a tutorial out for this soon.

We have another unmerged tutorial on seq2seq: https://github.com/pytorch/captum/pull/100/ You might find this helpful too.

vfdev-5 commented 5 years ago

Hi, waiting for the official solution to interpret models from transformers, here is a way to run interpretation as in the example:

import torch
import torch.nn as nn

from transformers import BertTokenizer
from transformers import BertForSequenceClassification, BertConfig

from captum.attr import IntegratedGradients
from captum.attr import InterpretableEmbeddingBase, TokenReferenceBase
from captum.attr import visualization
from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')

# We need to split the forward pass into two parts:
# 1) embeddings computation
# 2) classification

def compute_bert_outputs(model_bert, embedding_output, attention_mask=None, head_mask=None):
    if attention_mask is None:
        attention_mask = torch.ones(embedding_output.shape[0], embedding_output.shape[1]).to(embedding_output)

    extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

    extended_attention_mask = extended_attention_mask.to(dtype=next(model_bert.parameters()).dtype) # fp16 compatibility
    extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

    if head_mask is not None:
        if head_mask.dim() == 1:
            head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
            head_mask = head_mask.expand(model_bert.config.num_hidden_layers, -1, -1, -1, -1)
        elif head_mask.dim() == 2:
            head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
        head_mask = head_mask.to(dtype=next(model_bert.parameters()).dtype) # switch to float if needed + fp16 compatibility
    else:
        head_mask = [None] * model_bert.config.num_hidden_layers

    encoder_outputs = model_bert.encoder(embedding_output,
                                         extended_attention_mask,
                                         head_mask=head_mask)
    sequence_output = encoder_outputs[0]
    pooled_output = model_bert.pooler(sequence_output)
    outputs = (sequence_output, pooled_output,) + encoder_outputs[1:]  # add hidden_states and attentions if they are here
    return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)    

class BertModelWrapper(nn.Module):

    def __init__(self, model):
        super(BertModelWrapper, self).__init__()
        self.model = model

    def forward(self, embeddings):        
        outputs = compute_bert_outputs(self.model.bert, embeddings)
        pooled_output = outputs[1]
        pooled_output = self.model.dropout(pooled_output)
        logits = self.model.classifier(pooled_output)
        return torch.softmax(logits, dim=1)[:, 1].unsqueeze(1)

bert_model_wrapper = BertModelWrapper(model)
ig = IntegratedGradients(bert_model_wrapper)

# accumulate a couple of samples in this array for visualization purposes
vis_data_records_ig = []

def interpret_sentence(model_wrapper, sentence, label=1):

    model_wrapper.eval()
    model_wrapper.zero_grad()

    input_ids = torch.tensor([tokenizer.encode(sentence, add_special_tokens=True)])
    input_embedding = model_wrapper.model.bert.embeddings(input_ids)

    # predict
    pred = model_wrapper(input_embedding).item()
    pred_ind = round(pred)

    # compute attributions and approximation delta using integrated gradients
    attributions_ig, delta = ig.attribute(input_embedding, n_steps=500, return_convergence_delta=True)

    print('pred: ', pred_ind, '(', '%.2f' % pred, ')', ', delta: ', abs(delta))

    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].numpy().tolist())    
    add_attributions_to_visualizer(attributions_ig, tokens, pred, pred_ind, label, delta, vis_data_records_ig)

def add_attributions_to_visualizer(attributions, tokens, pred, pred_ind, label, delta, vis_data_records):
    attributions = attributions.sum(dim=2).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    attributions = attributions.detach().numpy()

    # storing couple samples in an array for visualization purposes
    vis_data_records.append(visualization.VisualizationDataRecord(
                            attributions,
                            pred,
                            pred_ind,
                            label,
                            "label",
                            attributions.sum(),       
                            tokens[:len(attributions)],
                            delta))    

interpret_sentence(bert_model_wrapper, sentence="text to classify", label=0)
visualization.visualize_text(vis_data_records_ig)


@NarineK it would be helpful if you could comment on whether the code correctly associates tokens with attributions. Thanks :)

HTH

NarineK commented 5 years ago

Looks great, @vfdev-5 ! Thank you!

  1. I do not see the embedding vector for the reference. It looks like it is currently using zero values for all baselines/references, but I think that choosing a reference token such as the padding token or some other token would be a better baseline. For reference, you can look at our IMDB tutorial.

  2. Another point is the target class that is attributed to the input embeddings. Is it always 1 in your case? It can be passed to the attribute method through the target argument.

  3. When we initialize VisualizationDataRecord and pass tokens[:len(attributions)], we can simply pass tokens. The sizes should match. I think we can fix it in our original IMDB tutorial as well.

mralexis1 commented 5 years ago

Thanks @vfdev-5 for the example and @NarineK for the feedback. A quick question on the first point in @NarineK's post: in the tutorial, the reference uses the padding token. However, what would the reference token be in BERT-related models? Would you mind providing some instructions on how we could construct a baseline reference using BERT? Thanks!

vfdev-5 commented 5 years ago

@mralexisw a reference can be probably defined as

input_ids = tokenizer.encode(sentence, add_special_tokens=True)
t_input_ids = torch.tensor([input_ids, ]).to(device)
t_ref_input_ids = t_input_ids.clone()
t_ref_input_ids[0, 1:-1] = 0

which should be something like

[CLS][PAD]...[PAD][SEP]

NarineK commented 5 years ago

Thank you @vfdev-5! Yes, that's right. @mralexisw, we can choose reference/baseline tokens, for example PAD, and create a reference/baseline sentence. For example, if my sentence is: Have a wonderful day then the reference/baseline could be: PAD PAD PAD PAD. We can then encode/numericalize it with the tokenizer, as mentioned above, and compute the embeddings of the resulting token ids, which can then be passed to the attribute method as baselines=reference_embeddings, ...
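The reference sentence described above can be sketched without any library. This is a minimal illustration, assuming the BERT special-token ids that show up in the tensors printed later in this thread (101 for [CLS], 102 for [SEP], 0 for [PAD]); the helper name `make_reference_ids` is hypothetical, and in practice the ids would be read from the tokenizer (e.g. `tokenizer.cls_token_id`, `tokenizer.sep_token_id`, `tokenizer.pad_token_id`) rather than hard-coded.

```python
def make_reference_ids(input_ids, cls_id=101, sep_id=102, pad_id=0):
    """Replace every ordinary token with pad_id, keeping [CLS] and [SEP] in place."""
    special = {cls_id, sep_id}
    return [tok if tok in special else pad_id for tok in input_ids]


# Token ids for a short encoded sentence: [CLS] w1 w2 w3 [SEP]
input_ids = [101, 1184, 1110, 1696, 102]
ref_ids = make_reference_ids(input_ids)
print(ref_ids)  # [101, 0, 0, 0, 102]
```

The reference ids are not the baseline yet: they still have to go through the same embedding layer as the input so that the baseline has the same shape as the input embeddings.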

felicitywang commented 5 years ago

Thank you @NarineK for the quick reply and thank you @vfdev-5 for the reference code. (I've switched for now to AI2's AllenNLP Interpret, which already has support for BERT, in case anyone's looking for a quicker solution to this problem.)

NarineK commented 5 years ago

Nice! Sorry for the delay. I'll have clean tutorials by next week at the latest.

mralexis1 commented 5 years ago

@NarineK Thanks! If there is no reference embedding, would that be equivalent to a vanilla gradient?

NarineK commented 5 years ago

@mralexisw, it will still compute the integral of gradients along the path from 0 to the given input, but the attribution might be a little off. It is known that the attribution depends on the choice of baseline, and the more carefully we choose it, the better the results we get.

In the case of saliency, it is taking the gradient for given input point. It won't be the same but you can easily compare by calling: Saliency(model).attribute(input)
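A dependency-free toy can make the difference concrete. The sketch below is not Captum's implementation: it approximates integrated gradients with a midpoint Riemann sum on a hand-written function `f` (an assumption chosen purely for illustration) and compares the result with the plain gradient at the input, which is what saliency computes.

```python
def f(x):
    """Toy model: a scalar output of two input features."""
    return x[0] ** 2 + 3.0 * x[1]


def grad_f(x):
    """Analytic gradient of f."""
    return [2.0 * x[0], 3.0]


def integrated_gradients(x, baseline, grad_fn, n_steps=50):
    """Midpoint Riemann-sum approximation of integrated gradients."""
    sums = [0.0] * len(x)
    for k in range(n_steps):
        alpha = (k + 0.5) / n_steps
        # Point on the straight path from the baseline to the input
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        sums = [s + g for s, g in zip(sums, grad_fn(point))]
    # Average gradient along the path, scaled by (input - baseline)
    return [(xi - b) * s / n_steps for xi, b, s in zip(x, baseline, sums)]


x = [3.0, 1.0]
zero_baseline = [0.0, 0.0]

ig_attr = integrated_gradients(x, zero_baseline, grad_f)
saliency = grad_f(x)  # plain gradient at the input point
# Convergence delta: attributions should sum to f(x) - f(baseline)
delta = sum(ig_attr) - (f(x) - f(zero_baseline))
# A different baseline gives different attributions for the same input
other_attr = integrated_gradients(x, [1.0, 1.0], grad_f)

print(ig_attr, saliency, delta, other_attr)
```

With the zero baseline the attributions come out close to [9, 3] and sum to f(x) - f(baseline) = 12, while the plain gradient is [6, 3]. Switching the baseline to [1, 1] changes the attributions to roughly [8, 0], which is the baseline dependence mentioned above.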

Also, in the example above are the weights for the self.model.classifier learnt and loaded properly ? Fine tuning tasks are the ones that tune those weights, aren't they ?

mralexis1 commented 5 years ago

@NarineK Thanks for pointing out the Saliency class! Btw, it would be helpful to have a tutorial on the best way to use captum for BERT/transformer-based models. If time permits, it is also super helpful to have sections on 1) how to extract the raw scores 2) how to use insights (w/ text data only) to do the visualization interactively 👍

@vfdev-5 should have a better idea about self.model.classifier, but I bet these are the weights that were fine-tuned.

NarineK commented 4 years ago

@vfdev-5, @mralexisw, @felicitywang, we've published a tutorial here: https://github.com/pytorch/captum/blob/master/tutorials/Bert_SQUAD_Interpret.ipynb Let me know what you think!

To be more flexible when working with multiple sub-embedding layers and to be able to interpret all of them simultaneously, I still pre-compute the embedding layers here, similar to previous tutorials. We'll ultimately also have a version that doesn't require that and allows attributing to BertEmbedding.

jchoi92 commented 4 years ago

@NarineK Thanks for the tutorial! Very easy to follow. I was able to easily replicate the process for a classification task using BERT (BertForSequenceClassification) with a few minor changes.

One small issue I ran into is with the forward function of InterpretableEmbeddingBase (currently using install from master). I believe it's because BertModel passes in all arguments as keyword arguments for embeddings (see here) whereas InterpretableEmbeddingBase's forward function expects one positional argument. I'm sure there's a cleaner solution, but for now I had to change the function as below to get it working.

def forward(self, *inputs, **kwargs):
    """
     The forward function of a wrapper embedding layer that takes and returns
     embedding layer. It allows embeddings to be created outside of the model
     and passes them seamlessly to the preceding layers of the model.

     Args:

        input (tensor): Embedding tensor generated using the `self.embedding`
                layer using `other_inputs` and `kwargs` if necessary.
        *other_inputs (Any, optional): A sequence of additional inputs that the
                forward function takes. Since forward functions can take any
                type and number of arguments, this will ensure that we can
                execute the forward pass using interpretable embedding layer
        **kwargs (Any, optional): Similar to `other_inputs` we want to make sure
                that our forward pass supports arbitrary number and type of
                key-value arguments

     Returns:

        tensor:
        Returns output tensor which is the same as input tensor.
        It passes embedding tensors to lower layers without any
        modifications.
    """
    return kwargs["inputs_embeds"]
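A slightly more defensive variant of this workaround could accept the embedding either positionally or by keyword. The sketch below is only an assumption about how that might look, not the fix that landed in Captum; the function name `passthrough_forward` and the `embedding_key` parameter are hypothetical.

```python
def passthrough_forward(*inputs, embedding_key="inputs_embeds", **kwargs):
    """Return the precomputed embedding unchanged, however it was passed in."""
    if inputs:
        # Positional call: the first argument is taken to be the embedding
        return inputs[0]
    if embedding_key in kwargs:
        # Keyword call, as in the workaround above
        return kwargs[embedding_key]
    raise ValueError(
        "expected the embedding as a positional argument or as "
        f"keyword argument {embedding_key!r}"
    )


# Plain strings stand in for tensors here
print(passthrough_forward("embedding"))
print(passthrough_forward(inputs_embeds="embedding", position_ids=None))
```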

Separately, an issue I keep running into is GPU memory usage. I'm on a single Nvidia Tesla V100 (16GB) which has no problems finetuning the model (using a maximum sequence length of 128 and batch size of 32) and similarly has no issues with inference.

For integrated gradients, once I pass in a larger text sample (e.g., 30 - 40 tokens), I immediately run into memory issues. The same happens if I run it on shorter text multiple times. Do you have any pointers to what's driving this increased memory usage and ideas on how to optimize? I'm hoping to run it on the entire training set, which seems infeasible right now.

NarineK commented 4 years ago

@jchoi92 thank you very much for the feedback! Yeah, I understand what you mean. I installed transformers through pip install and in that package they don't pass input embeddings as positional args. Let me see what I can do to cover that case as well.

With respect to memory issues, some things you can do are:

  1. Reduce the number of integral approximation steps. The default is n_steps=50.
  2. Set internal_batch_size to a smaller number. In order to approximate the integral, we expand the input batch size from input_batch_size to n_steps * input_batch_size. This is the main reason why we hit memory issues. To solve that problem, we can provide internal_batch_size='<a small number>', which chunks that large tensor into smaller pieces, runs the forward/backward passes on each piece, and ultimately aggregates the results.
  3. If you have multiple GPUs, you can also wrap your model with DataParallel. In that case we will distribute the computations across multiple GPUs.
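The idea behind point 2 can be illustrated with a small dependency-free sketch. It is not Captum's code: the helper names are hypothetical, and the inner loop merely stands in for one batched forward/backward pass. It shows that processing the n_steps interpolated points in chunks gives the same accumulated gradients while capping how many points are held in a batch at once.

```python
def interpolation_points(x, baseline, n_steps):
    """All n_steps points on the straight path from baseline to x (midpoint rule)."""
    return [
        [b + (k + 0.5) / n_steps * (xi - b) for xi, b in zip(x, baseline)]
        for k in range(n_steps)
    ]


def chunked(items, size):
    """Yield consecutive slices of at most `size` items."""
    for start in range(0, len(items), size):
        yield items[start:start + size]


def summed_gradients(points, grad_fn, internal_batch_size):
    """Accumulate gradients chunk by chunk; also report the largest chunk seen."""
    sums = [0.0] * len(points[0])
    largest = 0
    for batch in chunked(points, internal_batch_size):
        largest = max(largest, len(batch))
        for point in batch:  # stands in for one batched forward/backward pass
            sums = [s + g for s, g in zip(sums, grad_fn(point))]
    return sums, largest


def grad_fn(p):
    """Gradient of the toy function p[0] ** 2 + 3 * p[1]."""
    return [2.0 * p[0], 3.0]


points = interpolation_points([3.0, 1.0], [0.0, 0.0], n_steps=50)
full, full_peak = summed_gradients(points, grad_fn, internal_batch_size=50)
small, small_peak = summed_gradients(points, grad_fn, internal_batch_size=8)

print(full == small, full_peak, small_peak)
```

Here the unchunked run handles all 50 points in one batch, while the chunked run never holds more than 8, at the cost of more passes.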

Let me know if this helps.

Thank you!

NarineK commented 4 years ago

@jchoi92, this PR: https://github.com/pytorch/captum/pull/211 should fix the issue with named arguments for the embedding layers. Let me know if you run into any issues.

armheb commented 4 years ago

Hi, while trying to reproduce the SQuAD BERT tutorial, I get an error when running the following:

input_embeddings, ref_input_embeddings = construct_whole_bert_embeddings(input_ids, ref_input_ids, \
                                         token_type_ids=token_type_ids, ref_token_type_ids=ref_token_type_ids, \
                                         position_ids=position_ids, ref_position_ids=ref_position_ids)

start_scores, end_scores = predict(input_embeddings, \
                                   token_type_ids=token_type_ids, \
                                   position_ids=position_ids, \
                                   attention_mask=attention_mask)

Here is the error:

The size of tensor a (90) must match the size of tensor b (36) at non-singleton dimension 1

NarineK commented 4 years ago

Thank you for trying out the tutorial, @armheb ! Yeah, I think that they made some changes in Bert and there were some inconsistencies. I have a new PR (#222) open where I made some updates. Do you mind trying the version in #222 PR ?

Also, how did you fine-tune the Bert model? Did you use the bert-base-uncased or bert-base-cased model? I believe that I originally fine-tuned with bert-base-cased; however, they have since made some changes and it's not fine-tuning with bert-base-cased anymore. If you fine-tune with bert-base-uncased you might see slightly different results. I intend to update the notebooks for the uncased version.

armheb commented 4 years ago

Thanks for your quick response, I used bert-large-uncased-whole-word-masking-finetuned-squad which they have finetuned on squad. I'll try the new PR and let you know the results.

armheb commented 4 years ago

Thank you so much, it's fixed now.

mralexis1 commented 4 years ago

Thanks, @NarineK !

mralexis1 commented 4 years ago

On a separate note, I found that captum is super (GPU) memory-intensive. I was not able to run the code using a 12GB mem GPU (the exact same code works on CPU and 32GB mem GPU). It would be super helpful to specify the minimum GPU memory required for large models like BERT/ResNet/etc.

NarineK commented 4 years ago

Thank you for the feedback @mralexisw ! It depends on what algorithms and what parametrization you use. The tutorial that we have on Bert, runs on CPU under 2 - 3 mins. In general IG can be memory intensive depending on the integral approximation steps.

mralexis1 commented 4 years ago

@NarineK Another couple of quick questions for the BERT visualization part: 1) How should we interpret Attribution Score? Does it have any specific meaning? I didn't find any references in the paper. 2) What's the difference between Target Label and Attribution Label?

NarineK commented 4 years ago

  1. The magnitude of the attribution score shows the strength, or level of importance, of a feature for the particular class that we attribute to (aka the target). If it is positive, the feature contributes positively to that class (i.e. it pulls towards the class we are attributing to). If it is negative, the feature contributes negatively to that class (i.e. it pulls away from the target class; it is probably pulling towards another class, but there are no guarantees). If it is zero, the feature doesn't contribute to the selected target class.

  2. That's a good point. I think the naming is a little confusing. What I wanted was to have a Prediction Label and an Attribution (aka Target) Label.

e.g. if we predict that something on the image is a dog with a high probability then dog is the predicted class. Now, we can attribute our output to dog, cat or anything else that we want to and that is the attribution aka target class / label. I'll fix the header. Thank you!

Does it make sense ?
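Both points can be illustrated with a tiny linear classifier, for which integrated gradients with an all-zero baseline reduces exactly to weight times input. The weights and input below are made-up numbers for illustration only, and `attribute` here is a hypothetical helper, not Captum's method.

```python
# Made-up weights of a two-class linear model over two features
W = [
    [1.0, -0.5],  # class 0
    [-1.0, 2.0],  # class 1
]
x = [2.0, 1.0]


def logits(x):
    """One score per class."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def attribute(x, target):
    """Exact integrated gradients for a linear model with a zero baseline."""
    return [w * xi for w, xi in zip(W[target], x)]


scores = logits(x)
# The predicted label is fixed by the model output
predicted = max(range(len(scores)), key=scores.__getitem__)
# The target (attribution) label is our choice and may differ from it
attr_for_predicted = attribute(x, target=predicted)
attr_for_other = attribute(x, target=1)

print(predicted, attr_for_predicted, attr_for_other)
```

For the predicted class 0, the first feature gets a positive score (it pulls towards class 0) and the second a negative one (it pulls away). Attributing the same input to class 1 flips the picture, even though the prediction itself has not changed.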

mralexis1 commented 4 years ago

That makes sense, @NarineK . A bit further for point 1, why should we use Frobenius Norm for attribution scores?

For point 2, a renaming/docstring would be super helpful.

One more thing I noticed for the tutorial: I added print(inputs, inputs.shape) in func squad_pos_forward_func. The print-out is:

tensor([[ 101, 1184, 1110, 1696, 1106, 1366,  136,  102, 1122, 1110, 1696, 1106,
         1366, 1106, 1511,  117, 9712, 9447, 1105, 1619, 3612, 1104, 1155, 7553,
          119,  102]]) torch.Size([1, 26])
tensor([[101,   0,   0,   0,   0,   0,   0, 102,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 102]]) torch.Size([1, 26])
tensor([[ 101, 1184, 1110,  ..., 7553,  119,  102],
        [ 101, 1184, 1110,  ..., 7553,  119,  102],
        [ 101, 1184, 1110,  ..., 7553,  119,  102],
        ...,
        [ 101, 1184, 1110,  ..., 7553,  119,  102],
        [ 101, 1184, 1110,  ..., 7553,  119,  102],
        [ 101, 1184, 1110,  ..., 7553,  119,  102]]) torch.Size([50, 26])
tensor([[101,   0,   0,   0,   0,   0,   0, 102,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 102]]) torch.Size([1, 26])
tensor([[ 101, 1184, 1110, 1696, 1106, 1366,  136,  102, 1122, 1110, 1696, 1106,
         1366, 1106, 1511,  117, 9712, 9447, 1105, 1619, 3612, 1104, 1155, 7553,
          119,  102]]) torch.Size([1, 26])
tensor([[ 101, 1184, 1110, 1696, 1106, 1366,  136,  102, 1122, 1110, 1696, 1106,
         1366, 1106, 1511,  117, 9712, 9447, 1105, 1619, 3612, 1104, 1155, 7553,
          119,  102]]) torch.Size([1, 26])
tensor([[101,   0,   0,   0,   0,   0,   0, 102,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 102]]) torch.Size([1, 26])
tensor([[ 101, 1184, 1110,  ..., 7553,  119,  102],
        [ 101, 1184, 1110,  ..., 7553,  119,  102],
        [ 101, 1184, 1110,  ..., 7553,  119,  102],
        ...,
        [ 101, 1184, 1110,  ..., 7553,  119,  102],
        [ 101, 1184, 1110,  ..., 7553,  119,  102],
        [ 101, 1184, 1110,  ..., 7553,  119,  102]]) torch.Size([50, 26])

It makes sense for the first two. However, do you have any idea why at some point the shape is [50, 26]? Upon checking the 50 arrays seem to be identical.

Thanks again for your help!

NarineK commented 4 years ago

@mralexisw, that's a good point! 1) The Frobenius norm is just an example; you can choose another normalization method. In fact, it is challenging to find a good normalization method for the attributions. You can normalize at the example/instance level or at the dataset or batch level, and use the L1 or L-infinity norm or something else. 2) In this case I assume you're using LayerIntegratedGradients?
This is a trick that we did for LayerIntegratedGradients that helps us avoid monkey patching. That [50 x 26] matrix is passed as an argument; however, we override the output of the embedding layer in the hook: https://github.com/pytorch/captum/blob/master/captum/attr/_core/layer/layer_integrated_gradients.py#L274 If you use monkey patching with configure_interpretable_embedding_layer, you won't see it and the output will be clearer to you. LayerIntegratedGradients does some tricks under the hood. Let me know if this makes sense.
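To make the normalization choices in point 1 concrete, here is a small library-free sketch. It follows the same two steps as the example code earlier in this thread (sum over the embedding dimension, then divide by a norm), but the function names and the attribution values are made up for illustration.

```python
import math


def summarize(attributions):
    """Collapse [tokens][embedding_dim] attributions to one score per token."""
    return [sum(row) for row in attributions]


def normalize(scores, norm="l2"):
    """Scale per-token scores by the chosen norm of the whole example."""
    if norm == "l2":  # for a vector this is the Frobenius norm
        scale = math.sqrt(sum(s * s for s in scores))
    elif norm == "l1":
        scale = sum(abs(s) for s in scores)
    elif norm == "linf":
        scale = max(abs(s) for s in scores)
    else:
        raise ValueError(f"unknown norm: {norm}")
    # Leave all-zero attributions untouched instead of dividing by zero
    return [s / scale for s in scores] if scale else list(scores)


# Made-up attributions for three tokens with a two-dimensional embedding
token_attr = [[0.5, 2.5], [-2.0, -2.0], [0.0, 0.0]]
scores = summarize(token_attr)

for norm in ("l2", "l1", "linf"):
    print(norm, normalize(scores, norm))
```

All three options keep the sign and the relative ordering of the tokens within one example; they differ only in scale, which matters when examples are compared side by side.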

NarineK commented 4 years ago

Do you, guys, still want to keep this open or can we close it ?

vfdev-5 commented 4 years ago

@NarineK It's OK for me to close the issue.

NarineK commented 4 years ago

Awesome! Thank you! Closing for now! Feel free to open a new issue if you have more questions.

efatmae commented 4 years ago

Can you please provide a full example of how to use Captum with BERT sentence classification? I've tried to run the code provided here. It runs without errors, but the results don't make sense to me.

p16i commented 4 years ago

I'm also interested in doing that. Would you mind sharing what your result looks like?

efatmae commented 4 years ago

@heytitle, these are the results: Screenshot from 2020-07-29 09-33-03

The comment "This is bullshit, i hate you" is labelled as "bullying", but the attribution label shows the comment rather than the label, and the label is included in the colored word importance. I don't understand why.

p16i commented 4 years ago

If I'm not mistaken, the arguments for VisualizationDataRecord are (in order)

  1. attribution score
  2. predicted prob
  3. true label; ⚠️for this argument, you specify torch.argmax, which doesn't look like the true label.
  4. predicted label
  5. attribution label; ⚠️ in your example, you specify the text for this argument; this is why you have the comment there. Note that this setting is just for visualization; ideally, the value should come directly from where you set up the IG method.
  6. total attribution score
  7. delta (error from IG method due to number of interpolation steps)

Regarding label being included in the colored word importance, what do you mean by this? Are you referring to [CLS] or [SEP]?
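One way to avoid mixing up positions is to name each field explicitly. The sketch below is a hypothetical stand-in, not Captum's class. Its field order follows the VisualizationDataRecord call in the code posted earlier in this thread (attributions, predicted probability, predicted class, true label, attribution label, total attribution, tokens, delta), which passes the predicted class before the true label. The field names and all values are illustrative assumptions, and the order is worth checking against the installed Captum version.

```python
from typing import List, NamedTuple


class RecordSketch(NamedTuple):
    """Hypothetical stand-in mirroring the call order used earlier in this thread."""
    word_attributions: List[float]
    pred_prob: float
    pred_class: int
    true_class: int
    attr_class: str
    attr_score: float
    tokens: List[str]
    convergence_delta: float


# Made-up values, for illustration only
tokens = ["[CLS]", "text", "to", "classify", "[SEP]"]
attrs = [0.0, 0.6, -0.2, 0.4, 0.0]

record = RecordSketch(
    word_attributions=attrs,
    pred_prob=0.91,
    pred_class=1,
    true_class=0,
    attr_class="1",          # the class the attributions were computed for
    attr_score=sum(attrs),   # total attribution over all tokens
    tokens=tokens,
    convergence_delta=0.01,
)

# One attribution per token, so the two lists must line up
assert len(record.tokens) == len(record.word_attributions)
print(record.pred_class, record.true_class, record.attr_class)
```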

efatmae commented 4 years ago

Thanks @heytitle for the clarification. The comment "This is bullshit, i hate you" is labelled as "bullying" and I set text_comment, label = "This is bullshit, I hate you", "bullying"

But when I do the visualization, I find the word "bullying" between separators: [sep] bullying [sep]

I don't understand why.

p16i commented 4 years ago

I see. I didn't pay attention to [SEP] bullying [SEP]. I thought they're part of the comment.

Could it be that all_tokens contains those three tokens, which, from what you just described, seem to be the true label of the text? I guess we have to check how all_tokens is constructed.

efatmae commented 4 years ago

I shared the jupyter notebook on google Colab https://colab.research.google.com/drive/1wC6Z5eCs4SnZo6RFlTGUvYlIcuQa82WK?usp=sharing

Thanks

p16i commented 4 years ago

I don't see any content in the notebook. Is there anything I should do in order to see the content?

efatmae commented 4 years ago

There is something wrong with my Google Drive. Can you try this link: https://colab.research.google.com/drive/1gzOOKplSCAVTXagUwfv68ivwebYjKJFC?usp=sharing

efatmae commented 4 years ago

@vfdev-5 I tried your example, but when I tried the baselines you suggested above:

input_ids = tokenizer.encode(sentence, add_special_tokens=True)
t_input_ids = torch.tensor([input_ids, ]).to(device)
t_ref_input_ids = t_input_ids.clone()

attributions_ig, delta = ig.attribute(input_embedding, baselines=t_ref_input_ids, n_steps=500,return_convergence_delta=True)

I got this error

"RuntimeError: The size of tensor a (768) must match the size of tensor b (6) at non-singleton dimension 2"
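The sizes in this message (768 against 6 at dimension 2) are consistent with the baseline being a tensor of token ids of shape [1, 6] while the input is an embedding tensor of shape [1, 6, 768]. The library-free sketch below illustrates that mismatch with nested lists; `embed` is a hypothetical lookup, `HIDDEN = 4` stands in for the hidden size of 768, and the token ids are made up.

```python
HIDDEN = 4  # tiny stand-in for the hidden size of 768 in the error message


def embed(ids):
    """Hypothetical embedding lookup: one HIDDEN-sized vector per token id."""
    return [[float(tok % 7)] * HIDDEN for tok in ids]


def shape(batch):
    """Shape of a nested list, e.g. (batch, tokens) or (batch, tokens, hidden)."""
    dims = []
    while isinstance(batch, list):
        dims.append(len(batch))
        batch = batch[0]
    return tuple(dims)


input_ids = [[101, 2023, 2003, 1037, 3231, 102]]  # made-up ids, six tokens
ref_ids = [[101, 0, 0, 0, 0, 102]]                # [CLS] [PAD] ... [PAD] [SEP]

input_embedding = [embed(row) for row in input_ids]
wrong_baseline = ref_ids                          # token ids, not embeddings
right_baseline = [embed(row) for row in ref_ids]  # same shape as the input

print(shape(input_embedding), shape(wrong_baseline), shape(right_baseline))
# (1, 6, 4) (1, 6) (1, 6, 4)
```

In terms of the code earlier in this thread, that would probably mean passing the reference ids through the same embedding layer first, e.g. baselines=model_wrapper.model.bert.embeddings(t_ref_input_ids), and also setting the inner reference ids to the padding id as suggested above.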

efatmae commented 4 years ago

Hi all, I found this Colab notebook https://colab.research.google.com/drive/1pgAbzUF2SzF0BdFtGpJbZPWUOhFxT2NZ#scrollTo=X-nyyq_tbUDa

with a BERT binary classification example. I found it in this thread: https://github.com/pytorch/captum/issues/311#issuecomment-612460705

I have not tried it yet but it looks promising.

RylanSchaeffer commented 3 years ago

@efatmae did you find a solution? I too am having the error: The expanded size of the tensor (768) must match the existing size (53) at non-singleton dimension 2. Target sizes: [50, 53, 768]. Tensor sizes: [1, 53]