pytorch / captum

Model interpretability and understanding for PyTorch
https://captum.ai
BSD 3-Clause "New" or "Revised" License

Understand the captum output for bert #853

Open SaidaSaad opened 2 years ago

SaidaSaad commented 2 years ago

Hello

I am new to Captum and also to BERT. I have a BERT model that takes as input a sequence of letters (max length 2170) and outputs a sequence of the same length, predicting a label for each letter. There are only three labels [0, 1, 2], e.g. Input: aaaaddc ----> output: 0101220. I would like to get the attribution of each letter in the sequence toward predicting a specific class. Is it possible to get the attribution of every letter in the sequence for the correct prediction at all positions, or does it only give the attribution of every letter toward the prediction at a single position (one letter)? Can it work for a batch of samples? I tried something and it runs, but I do not understand what I get or what it means. Is there anything else I can do with Captum to learn more?

```python
import torch
from captum.attr import LayerIntegratedGradients
from captum.attr import visualization as viz

# `model`, `tokenizer`, `predict` and `squad_pos_forward_func` are defined
# elsewhere (they appear to follow the Captum BERT SQuAD tutorial).
text = ["A D C C C A A A"]
inputs = tokenizer(text, return_tensors="pt", padding=True)
input_ids = inputs["input_ids"]
token_type_ids = None
position_ids = None
attention_mask = inputs["attention_mask"]

start_scores = predict(input_ids,
                       token_type_ids=token_type_ids,
                       position_ids=position_ids,
                       attention_mask=attention_mask)

# Attribute with respect to the output of the embedding layer. The trailing 7 is
# passed to squad_pos_forward_func as an extra positional argument after attention_mask.
lig = LayerIntegratedGradients(squad_pos_forward_func, model.bert.embeddings)
attributions_start, delta_start = lig.attribute(
    inputs=input_ids,
    additional_forward_args=(token_type_ids, position_ids, attention_mask, 7),
    return_convergence_delta=True)

def summarize_attributions(attributions):
    # Sum over the embedding dimension -> one score per token, then L2-normalize.
    attributions = attributions.sum(dim=-1).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    return attributions

attributions_start_sum = summarize_attributions(attributions_start)

all_tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
ground_truth_start_ind = 0
start_position_vis = viz.VisualizationDataRecord(
    attributions_start_sum,                            # word_attributions
    torch.max(torch.softmax(start_scores[0], dim=0)),  # pred_prob
    torch.argmax(start_scores),                        # pred_class
    torch.argmax(start_scores),                        # true_class
    str(ground_truth_start_ind),                       # attr_class
    attributions_start_sum.sum(),                      # attr_score
    all_tokens,                                        # raw_input_ids
    delta_start)                                       # convergence_score

html = viz.visualize_text([start_position_vis])

with open("data.html", "w") as file:
    file.write(html.data)
```

I do not understand what the image shows. Is it the attribution of every letter toward the prediction for the letter at position 7, whose correct class is 0 (`ground_truth_start_ind = 0`)? Is that right?
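The `summarize_attributions` helper calls `squeeze(0)`, so it assumes a batch of one example. A minimal sketch of a batch-aware variant, assuming layer attributions of shape (batch, seq_len, hidden); it is written with NumPy so it runs standalone, and `summarize_attributions_batch` is a hypothetical name (in PyTorch the same steps are `attributions.sum(dim=-1)` and `torch.norm(..., dim=-1, keepdim=True)`):

```python
import numpy as np

def summarize_attributions_batch(attributions: np.ndarray) -> np.ndarray:
    """Collapse (batch, seq_len, hidden) layer attributions to (batch, seq_len).

    Hypothetical helper: sums over the hidden dimension like
    summarize_attributions, but L2-normalizes each example separately
    instead of calling squeeze(0), so it also works for batch > 1.
    """
    per_token = attributions.sum(axis=-1)                      # (batch, seq_len)
    norms = np.linalg.norm(per_token, axis=-1, keepdims=True)  # (batch, 1)
    return per_token / np.where(norms == 0, 1.0, norms)        # avoid divide-by-zero

# Toy input: 2 examples, 4 tokens, 3 hidden units (made-up numbers).
attrs = np.arange(24, dtype=float).reshape(2, 4, 3) - 12.0
scores = summarize_attributions_batch(attrs)
print(scores.shape)                     # (2, 4)
print(np.linalg.norm(scores, axis=-1))  # 1.0 per row, up to floating-point rounding
```

Normalizing per example keeps scores comparable within one sequence; comparing magnitudes across sequences would need the unnormalized sums instead.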


NarineK commented 2 years ago

@SaidaSaad, is index 7 the output class index that you want to attribute to the inputs?
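Integrated gradients explains one scalar output per call, so for a per-letter classifier each call answers "which input letters pushed the class-c score at position p". The sketch below implements integrated gradients from scratch on a hypothetical toy model (`forward`, `MIX` and `W` are illustrative, not Captum or the model in this issue) to show that covering every position means one attribution per position, giving a seq_len by seq_len matrix, and that each attribution sums to F(x) - F(baseline) up to a small gap (the convergence delta). In Captum itself, per its documentation of the `target` argument, a forward function returning (batch, seq_len, num_labels) logits can be given `target=(position, label)`, or a list of such tuples with one per example in the batch.

```python
import numpy as np

rng = np.random.default_rng(0)
SEQ_LEN, HIDDEN, NUM_LABELS = 5, 4, 3

# Hypothetical toy "token classifier": output position p mixes all input token
# embeddings through MIX[p], projects to label scores with W, then applies tanh.
MIX = rng.normal(scale=0.3, size=(SEQ_LEN, SEQ_LEN))
W = rng.normal(scale=0.5, size=(HIDDEN, NUM_LABELS))

def forward(x):
    """x: (seq_len, hidden) embeddings -> (seq_len, num_labels) scores."""
    return np.tanh(MIX @ x @ W)

def grad_of_target(x, position, label):
    """Analytic gradient of forward(x)[position, label] w.r.t. x, shape (seq_len, hidden)."""
    s = (MIX @ x @ W)[position, label]
    return (1.0 - np.tanh(s) ** 2) * np.outer(MIX[position], W[:, label])

def integrated_gradients(x, baseline, position, label, n_steps=100):
    """Midpoint Riemann approximation of IG for ONE scalar output."""
    alphas = (np.arange(n_steps) + 0.5) / n_steps
    grads = [grad_of_target(baseline + a * (x - baseline), position, label) for a in alphas]
    return (x - baseline) * np.mean(grads, axis=0)

x = rng.normal(size=(SEQ_LEN, HIDDEN))
baseline = np.zeros_like(x)  # analogous to an all-zeros embedding baseline

# One call explains one (position, label) output; loop to cover every position.
label = 0
attr_matrix = np.zeros((SEQ_LEN, SEQ_LEN))  # rows: explained position, cols: input token
deltas = []
for p in range(SEQ_LEN):
    attr = integrated_gradients(x, baseline, p, label)
    attr_matrix[p] = attr.sum(axis=-1)  # per-token score, as in summarize_attributions
    # Completeness: attributions should sum to F(x) - F(baseline); the remaining
    # gap plays the role of Captum's convergence delta.
    deltas.append(attr.sum() - (forward(x)[p, label] - forward(baseline)[p, label]))

print(attr_matrix.round(3))
print(max(abs(d) for d in deltas))  # small
```

Summing the target scores over all positions into one scalar would also work with a single call, but it merges every position's explanation into one vector per token.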

`visualize_text` is a simple auxiliary function, and you don't have to use it if you have your own way of visualizing the attributions. It is one example visualization and may not be useful for every application: it simply color-codes the attribution score of each input token, which in your case is a letter of the input. I don't understand what `ground_truth_start_ind` is in your case. The function will display whatever value you set, but that field is meant to be the ground-truth label.
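If `visualize_text` does not suit the task, the per-token scores can be rendered in any other way. A minimal plain-text sketch; `render_attributions` is a hypothetical helper and the scores are made up:

```python
def render_attributions(tokens, scores, width=10):
    """Return one text line per token: token, signed score, and a bar.

    Hypothetical alternative to viz.visualize_text: positive scores get '+'
    bars, negative scores get '-' bars, scaled by the largest |score|.
    """
    max_abs = max((abs(s) for s in scores), default=0.0) or 1.0
    lines = []
    for tok, s in zip(tokens, scores):
        bar = ("+" if s >= 0 else "-") * round(width * abs(s) / max_abs)
        lines.append(f"{tok:>6} {s:+.3f} {bar}")
    return lines

# Made-up scores for illustration only.
for line in render_attributions(["[CLS]", "a", "d", "c", "[SEP]"], [0.0, 0.5, -1.0, 0.3, 0.0]):
    print(line)
```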

https://github.com/pytorch/captum/blob/master/captum/attr/_utils/visualization.py#L450

NarineK commented 1 year ago

@SaidaSaad, is this still an issue, or can we close it?