pytorch / captum

Model interpretability and understanding for PyTorch
https://captum.ai
BSD 3-Clause "New" or "Revised" License

Help with BERT IntegratedGradients #1112

Open marc-gav opened 1 year ago

marc-gav commented 1 year ago

I have been trying to use Captum to extract attributions from BERT. I have read the documentation and the examples and watched this conference talk: https://www.youtube.com/watch?v=0QLrRyLndFI. But I still can't seem to understand how to use IntegratedGradients.

I have written some reproducible code that breaks at the line attribution = ig.attribute(input_indices, target=[101]):

import transformers
import torch
from captum.attr import (
    IntegratedGradients,
    visualization,
    configure_interpretable_embedding_layer,
    remove_interpretable_embedding_layer,
)

transformers.logging.set_verbosity_error()

bert_tokenizer = transformers.BertTokenizerFast.from_pretrained(
    "bert-base-uncased"
)
bert_model = transformers.BertModel.from_pretrained("bert-base-uncased")
bert_model.eval()
bert_model.zero_grad()

text = "The cat sat on the mat."
input_indices = bert_tokenizer.encode(text, return_tensors="pt")
print(input_indices)
interpretable_emb = configure_interpretable_embedding_layer(
    bert_model, "embeddings.word_embeddings"
)

input_emb = interpretable_emb.indices_to_embeddings(input_indices)
ig = IntegratedGradients(bert_model)
attribution = ig.attribute(input_indices, target=[101])  # <- this line raises the error
print(attribution)
remove_interpretable_embedding_layer(bert_model, interpretable_emb)

Thanks in advance for your help. Once I understand how to work with this, I would be more than happy to add some examples for BERT to the documentation with the different attribution scores!

aobo-y commented 1 year ago

@marc-gav input_indices contains word token IDs, which are integers, so they have no gradients. Bert maps them into embeddings internally, and that lookup operation is not differentiable. You cannot use IntegratedGradients when your input has no gradients. You may want to check other algorithms or apply LayerIntegratedGradients directly to the embedding layer.
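For example, a minimal sketch of the LayerIntegratedGradients approach might look like this (the forward function below, which just scores the pooled output, is only an illustrative stand-in for a real task head):

import torch
import transformers
from captum.attr import LayerIntegratedGradients

tokenizer = transformers.BertTokenizerFast.from_pretrained("bert-base-uncased")
model = transformers.BertModel.from_pretrained("bert-base-uncased")
model.eval()

def forward_func(input_ids):
    # Captum needs one scalar per example; a real task would return a class logit.
    return model(input_ids).pooler_output.norm(dim=-1)

input_ids = tokenizer.encode("The cat sat on the mat.", return_tensors="pt")
baselines = torch.full_like(input_ids, tokenizer.pad_token_id)

# Attribute through the (non-differentiable) embedding lookup by targeting the layer.
lig = LayerIntegratedGradients(forward_func, model.embeddings.word_embeddings)
attributions = lig.attribute(input_ids, baselines=baselines)
token_scores = attributions.sum(dim=-1).squeeze(0)  # one score per token
print(token_scores)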

By the way, we have a tutorial about Bert: https://captum.ai/tutorials/Bert_SQUAD_Interpret. And actually Bert does not matter here; the workflow is the same as for any other text model: https://captum.ai/tutorials/IMDB_TorchText_Interpret

marc-gav commented 1 year ago

Thank you for your time and your response! I recently realized that I made a mistake in the following line: attribution = ig.attribute(input_indices, target=[101]).

Although I successfully obtained the embeddings for my input_indices using input_emb = interpretable_emb.indices_to_embeddings(input_indices), I was not utilizing these embeddings in the ig.attribute function.

Here's my updated code:

import transformers
import torch
from captum.attr import (
    IntegratedGradients,
    visualization,
    configure_interpretable_embedding_layer,
    remove_interpretable_embedding_layer,
)

transformers.logging.set_verbosity_error()

bert_tokenizer = transformers.BertTokenizerFast.from_pretrained(
    "bert-base-uncased"
)
bert_model = transformers.BertModel.from_pretrained("bert-base-uncased")
bert_model.eval()
bert_model.zero_grad()

text = "The cat sat on the mat."
input_indices = bert_tokenizer.encode(text, return_tensors="pt")
interpretable_emb = configure_interpretable_embedding_layer(
    bert_model, "embeddings"
)

input_emb = interpretable_emb.indices_to_embeddings(input_indices)
ig = IntegratedGradients(bert_model)
attribution = ig.attribute(input_emb, target=101)
print(attribution)
remove_interpretable_embedding_layer(bert_model, interpretable_emb)

I am attempting to use configure_interpretable_embedding_layer by replicating the example provided in the source code (link below), but with Bert instead. However, I am encountering errors with Bert due to the expected shapes of some of the tensors. Could you tell me if you notice any issues with my code?

https://github.com/pytorch/captum/blob/b8eff98aaf0b17ff4d57a339cec5d3fba250e006/captum/attr/_models/base.py#L160-L182

marc-gav commented 1 year ago

After some digging, the issue I am facing is that by using the wrapper InterpretableEmbeddingBase, my input is a 3-dimensional tensor instead of the 2-dimensional input required for normal Bert execution.

In Bert, there is a step batch_size, seq_length = input_shape where the input needs to be 2-dimensional. This makes sense, as the input is always [batch_size, 512] (512 being the maximum number of input tokens; it can be smaller, but for clarity I've written 512). However, feeding an embedding tensor of shape [batch_size, 512, 768] causes Bert's logic to break.
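A quick sketch of the shape mismatch (on a freshly loaded bert-base-uncased, before any interpretable-embedding wrapping; the exact sequence length depends on the sentence):

import transformers

tokenizer = transformers.BertTokenizerFast.from_pretrained("bert-base-uncased")
model = transformers.BertModel.from_pretrained("bert-base-uncased")

input_ids = tokenizer.encode("The cat sat on the mat.", return_tensors="pt")
print(input_ids.shape)  # e.g. torch.Size([1, 9]): the 2-D [batch_size, seq_length] Bert unpacks
embeds = model.embeddings.word_embeddings(input_ids)
print(embeds.shape)     # e.g. torch.Size([1, 9, 768]): feeding this 3-D tensor where Bert
                        # expects input_ids breaks the batch_size, seq_length unpacking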

As you suggested, I have decided to use LayerIntegratedGradients. Furthermore, I will use BertForMaskedLM instead of the regular Bert model: I aim to perform a classification task targeting a specific token prediction to compute my gradient. I am still curious how someone would use captum.attr.configure_interpretable_embedding_layer with BertForMaskedLM, because I haven't managed to do it.
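For reference, here is a rough sketch of what I am attempting with LayerIntegratedGradients and BertForMaskedLM (the target token and the forward function are my own illustrative choices):

import torch
import transformers
from captum.attr import LayerIntegratedGradients

tokenizer = transformers.BertTokenizerFast.from_pretrained("bert-base-uncased")
model = transformers.BertForMaskedLM.from_pretrained("bert-base-uncased")
model.eval()

text = "The cat sat on the [MASK]."
input_ids = tokenizer.encode(text, return_tensors="pt")
mask_pos = (input_ids == tokenizer.mask_token_id).nonzero()[0, 1].item()
target_id = tokenizer.convert_tokens_to_ids("mat")  # the prediction we want to explain

def forward_func(input_ids):
    logits = model(input_ids).logits        # [batch, seq_len, vocab_size]
    return logits[:, mask_pos, target_id]   # scalar per example: score of "mat" at [MASK]

baselines = torch.full_like(input_ids, tokenizer.pad_token_id)
lig = LayerIntegratedGradients(forward_func, model.bert.embeddings.word_embeddings)
attributions = lig.attribute(input_ids, baselines=baselines)
print(attributions.sum(dim=-1).squeeze(0))  # one attribution score per input token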

Omar-Emam-99 commented 11 months ago

You need to update the kwargs of the predict forward function: instead of input_ids, use inputs_embeds:


def predict_forward_func(inputs_embeds, token_type_ids=None,
                         position_ids=None, attention_mask=None):
    """Forward function passed to the IntegratedGradients constructor."""
    # Pass the precomputed embeddings via inputs_embeds instead of input_ids.
    return model(inputs_embeds=inputs_embeds,
                 token_type_ids=token_type_ids,
                 position_ids=position_ids,
                 attention_mask=attention_mask)[0]
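For completeness, a minimal sketch of wiring this forward function into IntegratedGradients (the model, tokenizer, and target choice below are assumptions for illustration):

import transformers
from captum.attr import IntegratedGradients

tokenizer = transformers.BertTokenizerFast.from_pretrained("bert-base-uncased")
model = transformers.BertModel.from_pretrained("bert-base-uncased")
model.eval()

input_ids = tokenizer.encode("The cat sat on the mat.", return_tensors="pt")
# Compute the word embeddings outside the model so they can carry gradients.
inputs_embeds = model.embeddings.word_embeddings(input_ids)

ig = IntegratedGradients(predict_forward_func)
# The forward function returns the last hidden state [batch, seq_len, hidden],
# so the target tuple selects one (token, hidden-unit) coordinate of that output.
attributions = ig.attribute(inputs_embeds, target=(0, 0))
print(attributions.shape)  # same shape as inputs_embeds: [1, seq_len, 768]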