pytorch / captum

Model interpretability and understanding for PyTorch
https://captum.ai
BSD 3-Clause "New" or "Revised" License
4.74k stars 477 forks source link

Issue with images input for Integrated Gradient for image captioning attribution score computation #1188

Open alimirgh75 opened 10 months ago

alimirgh75 commented 10 months ago

Hi all, I have an issue regarding inputs to attribute method of integrated gradient algorithm. I am using the GIT model for image captioning and defined the forward function to return one token_id of the caption at a time. The input of the model is the (processed_image , processed_caption) where "processed_caption" is the sequence of previously generated token_ids.

from transformers import AutoProcessor, AutoTokenizer
from transformers import AutoModelForCausalLM
import torch
from captum.attr import IntegratedGradients
from PIL import Image
import matplotlib.pyplot as plt
import torch.nn.functional as F

output_dir = "microsoft/git-base"
processor = AutoProcessor.from_pretrained(output_dir)
model = AutoModelForCausalLM.from_pretrained(output_dir)
tokenizer = AutoTokenizer.from_pretrained(output_dir)
model.eval()
class MultimodalModel(torch.nn.Module):
    def __init__(self, cap_model, processor):
        super().__init__()
        self.cap_model = cap_model
        self.processor = processor

    def forward(self, input_image,input_text):            
        generated_ids = self.cap_model.generate(pixel_values=input_image, input_ids=input_text, max_length=1)     

        return generated_ids[0, -1].item()
# Initialize the multimodal model
multimodal_model = MultimodalModel(model, processor)

# Initialize the Integrated Gradients attribution algorithm
ig = IntegratedGradients(multimodal_model)
img_path = test_df_images[0]
image = Image.open(img_path)
img = image.resize((224, 224), resample=Image.BILINEAR)
pixel_values = processor(images=img, return_tensors="pt").pixel_values

generated_caption = prediction_captions[0]
# Tokenize the target caption (correction)
caption_ids = processor(generated_caption, add_special_tokens=True, return_tensors='pt').input_ids
print(caption_ids)
# Iterate over the length of the original tensor
for i in range(caption_ids.shape[1]-1):
    # Token ID for [CLS]
    cls_token_id = 101
    # Create a subsequence up to the i-th element
    subsequence = caption_ids[:, :i + 1]
    #subsequence = subsequence.to(torch.long)
    target_subsequence = caption_ids[:, :i + 2]
    target = target_subsequence[0, -1].item()
    # Compute the attribution scores for the current word
    (img_attr,text_attr), _ = ig.attribute(inputs=(pixel_values , subsequence), baselines=(torch.zeros_like(pixel_values ),torch.zeros_like(subsequence)), target=target, return_convergence_delta=True)

    print(subsequence)

I am getting this error on ig.attribute

RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.FloatTensor instead (while checking arguments for embedding)
nischitpra commented 1 month ago

Hey @alimirgh75 , did you find a solution? I'm having the same issue

nischitpra commented 1 month ago

Not sure if its relevant but I got it working with LayerIntegratedGradients. Posting here if people come across the issue

Heres a basic example

import torch
from transformers import BertTokenizer, BertForSequenceClassification
from captum.attr import IntegratedGradients, LayerIntegratedGradients

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
model.eval()

text = "hello world"
encoded = tokenizer(text, return_tensors="pt")

def predict(inputs, attention_mask=None):
    return model(inputs, attention_mask=attention_mask).logits

lig = LayerIntegratedGradients(predict, model.bert.embeddings)
attributions_start, delta_start = lig.attribute(
    inputs=encoded['input_ids'],
    target=torch.tensor([0]),
    additional_forward_args=encoded['attention_mask'],
    return_convergence_delta=True
)
print(attributions_start, delta_start)