pytorch / captum

Model interpretability and understanding for PyTorch
https://captum.ai
BSD 3-Clause "New" or "Revised" License

Issues with BERT-type model #794

Open nataliebarcickikas opened 2 years ago

nataliebarcickikas commented 2 years ago

I am working with the LayoutLMv2 model in huggingface (https://huggingface.co/transformers/model_doc/layoutlm.html). The forward pass works fine, but I get a dimensionality error related to the embeddings when I try to use the model with Captum for explainability. Note that LayoutLM (the first version of the model) gives no issues in the same context. I also realize that this model needs to be fine-tuned; this is just meant as a proof-of-concept usage.

Here is my code:


from PIL import Image, ImageDraw, ImageFont
from transformers import LayoutLMv2FeatureExtractor, LayoutLMv2TokenizerFast, LayoutLMv2Processor, LayoutLMv2ForSequenceClassification
from captum.attr._utils.input_layer_wrapper import ModelInputWrapper
from captum.attr import LayerIntegratedGradients, TokenReferenceBase
import torch
import torchvision
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

image_rgb = Image.open("IMAGE.jpg").convert("RGB")

processor = LayoutLMv2Processor.from_pretrained('microsoft/layoutlmv2-base-uncased')
model = LayoutLMv2ForSequenceClassification.from_pretrained('microsoft/layoutlmv2-base-uncased').to(device)
tokenizer = LayoutLMv2TokenizerFast.from_pretrained("microsoft/layoutlmv2-base-uncased")

# Keep the inputs on the same device as the model and the baselines generated below
encoding = processor(image_rgb, return_tensors="pt").to(device)

input_ids = encoding['input_ids']
token_type_ids = encoding['token_type_ids']
attention_mask = encoding['attention_mask']
bbox = encoding['bbox']

model_layered = ModelInputWrapper(model)

outputs = model_layered(**encoding)

# Predicted class, left on the model's device so it can be passed as `target`
pred, answer_idx = F.softmax(outputs.logits, dim=1).detach().max(dim=1)

def batch_predict(input_ids, image, bbox, attention_mask, token_type_ids):
    model_layered.eval()
    outputs = model_layered(input_ids=input_ids,
                            image=image,
                            bbox=bbox,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids)
    logits = outputs.logits
    probs = F.softmax(logits, dim=1)
    return probs

attr = LayerIntegratedGradients(batch_predict, 
                                [model_layered.module.layoutlmv2.embeddings.word_embeddings, 
                                 model_layered.module.layoutlmv2.embeddings.position_embeddings, 
                                 model_layered.module.layoutlmv2.embeddings.x_position_embeddings,
                                 model_layered.module.layoutlmv2.embeddings.y_position_embeddings,
                                 model_layered.module.layoutlmv2.embeddings.h_position_embeddings, 
                                 model_layered.module.layoutlmv2.embeddings.w_position_embeddings,
                                 model_layered.module.layoutlmv2.embeddings.token_type_embeddings,])

# Generate reference for tokens
token_reference = TokenReferenceBase(reference_token_idx=tokenizer.pad_token_id)
text_reference_indices = token_reference.generate_reference(len(encoding['input_ids'][0]), device=device).unsqueeze(0)
baselines = text_reference_indices

attributions = attr.attribute(inputs=encoding['input_ids'],
                              additional_forward_args=(encoding['image'],
                                                       encoding['bbox'],
                                                       encoding['attention_mask'],
                                                       encoding['token_type_ids']),
                              baselines=baselines,
                              target=answer_idx,
                              n_steps=5)

And the error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-1-00bdf4f97de3> in <module>()
     60                             baselines=baselines,
     61                             target=answer_idx,
---> 62                             n_steps=5)

/home/natbarkas/anaconda3/envs/explainability-v2/lib/python3.6/site-packages/captum/log/__init__.py in wrapper(*args, **kwargs)
     33             @wraps(func)
     34             def wrapper(*args, **kwargs):
---> 35                 return func(*args, **kwargs)
     36 
     37             return wrapper

/home/natbarkas/anaconda3/envs/explainability-v2/lib/python3.6/site-packages/captum/attr/_core/layer/layer_integrated_gradients.py in attribute(self, inputs, baselines, target, additional_forward_args, n_steps, method, internal_batch_size, return_convergence_delta, attribute_to_layer_input)
    496             method=method,
    497             internal_batch_size=internal_batch_size,
--> 498             return_convergence_delta=False,
    499         )
    500 

/home/natbarkas/anaconda3/envs/explainability-v2/lib/python3.6/site-packages/captum/attr/_core/integrated_gradients.py in attribute(self, inputs, baselines, target, additional_forward_args, n_steps, method, internal_batch_size, return_convergence_delta)
    290                 additional_forward_args=additional_forward_args,
    291                 n_steps=n_steps,
--> 292                 method=method,
    293             )
    294 

/home/natbarkas/anaconda3/envs/explainability-v2/lib/python3.6/site-packages/captum/attr/_core/integrated_gradients.py in _attribute(self, inputs, baselines, target, additional_forward_args, n_steps, method, step_sizes_and_alphas)
    353             inputs=scaled_features_tpl,
    354             target_ind=expanded_target,
--> 355             additional_forward_args=input_additional_args,
    356         )
    357 

/home/natbarkas/anaconda3/envs/explainability-v2/lib/python3.6/site-packages/captum/attr/_core/layer/layer_integrated_gradients.py in gradient_func(forward_fn, inputs, target_ind, additional_forward_args)
    464 
    465                     output = _run_forward(
--> 466                         self.forward_func, tuple(), target_ind, additional_forward_args
    467                     )
    468                 finally:

/home/natbarkas/anaconda3/envs/explainability-v2/lib/python3.6/site-packages/captum/_utils/common.py in _run_forward(forward_func, inputs, target, additional_forward_args)
    451         *(*inputs, *additional_forward_args)
    452         if additional_forward_args is not None
--> 453         else inputs
    454     )
    455     return _select_targets(output, target)

<ipython-input-1-00bdf4f97de3> in batch_predict(input_ids, image, bbox, attention_mask, token_type_ids)
     34                             bbox=bbox,
     35                             attention_mask=attention_mask,
---> 36                             token_type_ids=token_type_ids)
     37     logits = outputs.logits
     38     probs = F.softmax(logits, dim=1)

/home/natbarkas/anaconda3/envs/explainability-v2/lib/python3.6/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

/home/natbarkas/anaconda3/envs/explainability-v2/lib/python3.6/site-packages/captum/attr/_utils/input_layer_wrapper.py in forward(self, *args, **kwargs)
     74             kwargs[arg_name] = self.input_maps[arg_name](kwargs[arg_name])
     75 
---> 76         return self.module(*tuple(args), **kwargs)

/home/natbarkas/anaconda3/envs/explainability-v2/lib/python3.6/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

/home/natbarkas/anaconda3/envs/explainability-v2/lib/python3.6/site-packages/transformers/models/layoutlmv2/modeling_layoutlmv2.py in forward(self, input_ids, bbox, image, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict)
   1053             output_attentions=output_attentions,
   1054             output_hidden_states=output_hidden_states,
-> 1055             return_dict=return_dict,
   1056         )
   1057         if input_ids is not None:

/home/natbarkas/anaconda3/envs/explainability-v2/lib/python3.6/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

/home/natbarkas/anaconda3/envs/explainability-v2/lib/python3.6/site-packages/transformers/models/layoutlmv2/modeling_layoutlmv2.py in forward(self, input_ids, bbox, image, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, output_attentions, output_hidden_states, return_dict)
    893             token_type_ids=token_type_ids,
    894             position_ids=position_ids,
--> 895             inputs_embeds=inputs_embeds,
    896         )
    897 

/home/natbarkas/anaconda3/envs/explainability-v2/lib/python3.6/site-packages/transformers/models/layoutlmv2/modeling_layoutlmv2.py in _calc_text_embeddings(self, input_ids, bbox, position_ids, token_type_ids, inputs_embeds)
    754         token_type_embeddings = self.embeddings.token_type_embeddings(token_type_ids)
    755 
--> 756         embeddings = inputs_embeds + position_embeddings + spatial_position_embeddings + token_type_embeddings
    757         embeddings = self.embeddings.LayerNorm(embeddings)
    758         embeddings = self.embeddings.dropout(embeddings)

RuntimeError: The size of tensor a (44) must match the size of tensor b (49) at non-singleton dimension 1

These are the versions of the packages I am using:
transformers==4.11.2
captum==0.4.0
torch==1.7.0
torchvision==0.8.1

99warriors commented 2 years ago

Hi @nataliebarcickikas - it seems the error occurs during the forward call to batch_predict. What happens if you call batch_predict directly, without using LayerIntegratedGradients?

nataliebarcickikas commented 2 years ago

Directly calling batch_predict causes no issues:

batch_predict(encoding['input_ids'], 
              encoding['image'], 
              encoding['bbox'], 
              encoding['attention_mask'], 
              encoding['token_type_ids'])

Output: tensor([[0.4553, 0.5447]], grad_fn=<SoftmaxBackward>)

99warriors commented 2 years ago

Could you print the dimensions of all the embeddings summed on line 756 (in modeling_layoutlmv2.py), both for the direct call to batch_predict and for the call through attr?
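A sketch of one way to collect these shapes without editing modeling_layoutlmv2.py: register temporary forward hooks on the embedding submodules and record the output shape of every call. The helper `log_output_shapes` and the toy `TwoCalls` model are hypothetical; for the model in this issue, names such as `layoutlmv2.embeddings.position_embeddings` (as listed by `model.named_modules()`) would be passed instead. A module that appears more than once in the log is run more than once per forward pass.

```python
import torch
import torch.nn as nn


def log_output_shapes(model: nn.Module, names):
    """Register forward hooks that record (name, output shape) for every call.

    `names` are dotted submodule names as returned by model.named_modules().
    Assumes each hooked module returns a single tensor.
    Returns the list that gets filled in and the handles needed to remove the hooks.
    """
    records = []
    handles = []
    modules = dict(model.named_modules())
    for name in names:
        # Bind `name` as a default argument so each hook keeps its own label
        def hook(module, inputs, output, name=name):
            records.append((name, tuple(output.shape)))
        handles.append(modules[name].register_forward_hook(hook))
    return records, handles


class TwoCalls(nn.Module):
    """Hypothetical toy model that runs one embedding table on two inputs."""

    def __init__(self):
        super().__init__()
        self.position_embeddings = nn.Embedding(64, 8)

    def forward(self, text_pos, visual_pos):
        a = self.position_embeddings(text_pos)
        b = self.position_embeddings(visual_pos)
        return a.sum() + b.sum()


model = TwoCalls()
records, handles = log_output_shapes(model, ["position_embeddings"])
model(torch.arange(44).unsqueeze(0), torch.arange(49).unsqueeze(0))
for handle in handles:
    handle.remove()
print(records)
# [('position_embeddings', (1, 44, 8)), ('position_embeddings', (1, 49, 8))]
```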

nataliebarcickikas commented 2 years ago

Here are the dimensions of the summed embeddings, along with each of its four components:

In the forward call:

batch_predict(encoding['input_ids'], 
              encoding['image'], 
              encoding['bbox'], 
              encoding['attention_mask'], 
              encoding['token_type_ids'])
Input embeddings:  torch.Size([1, 44, 768])
Position embeddings:  torch.Size([1, 44, 768])
Spatial position embeddings:  torch.Size([1, 44, 768])
Token type embeddings:  torch.Size([1, 44, 768])
torch.Size([1, 44, 768])
tensor([[0.4942, 0.5058]], grad_fn=<SoftmaxBackward>)

In the attr call:

attributions = attr.attribute(inputs=encoding['input_ids'],
                              additional_forward_args=(encoding['image'],
                                                       encoding['bbox'],
                                                       encoding['attention_mask'],
                                                       encoding['token_type_ids']),
                              baselines=baselines,
                              target=answer_idx,
                              n_steps=1)
Input embeddings:  torch.Size([1, 44, 768])
Position embeddings:  torch.Size([1, 44, 768])
Spatial position embeddings:  torch.Size([1, 44, 768])
Token type embeddings:  torch.Size([1, 44, 768])
torch.Size([1, 44, 768])
Input embeddings:  torch.Size([1, 44, 768])
Position embeddings:  torch.Size([1, 44, 768])
Spatial position embeddings:  torch.Size([1, 44, 768])
Token type embeddings:  torch.Size([1, 44, 768])
torch.Size([1, 44, 768])
Input embeddings:  torch.Size([1, 44, 768])
Position embeddings:  torch.Size([1, 49, 768])
Spatial position embeddings:  torch.Size([1, 49, 768])
Token type embeddings:  torch.Size([1, 44, 768])
And then the runtime error occurs:
RuntimeError: The size of tensor a (44) must match the size of tensor b (49) at non-singleton dimension 1

chrisdoyleIE commented 2 years ago

(I left this comment originally on #904)

I have this exact same issue with a custom BERT-based model, and I've traced it back to the Captum hook being called at line 1072 of the source code for torch.nn.Module (i.e. during the forward call of your model). The hook that causes this issue is layer_integrated_gradients.layer_forward_hook. It appears that the cached value in scattered_inputs_dict is returned no matter what, because the hook is called at the start of the wrapper module's forward method and is not reset mid-call if weights are shared.

# num_current_tokens = 50, num_prev_tokens = 71
input_ids.shape   # torch.Size([1, 50])
self.word_embeddings(input_ids).shape  # torch.Size([1, 71, 768]) instead of torch.Size([1, 50, 768])

For context, my model shares weights and we call the same model twice within forward():

def _forward(...): 
    ....
    outputs_text = self.bert(input_ids=input_ids_text, attention_mask=attention_mask_text, **kwargs)
    outputs_context = self.bert(input_ids=input_ids_context, attention_mask=attention_mask_context, **kwargs)
    ...
    return outputs
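
A minimal, simplified imitation of the caching behaviour described above, using plain PyTorch forward hooks rather than Captum's actual code: a first pass stores a single output per module (so the second call overwrites the first), and a second pass substitutes that stored output on every call of the module. `SharedEncoder`, the hook functions, and the sizes 50/71 are hypothetical, chosen only to match the shapes in the comment.

```python
import torch
import torch.nn as nn


class SharedEncoder(nn.Module):
    """Hypothetical model that runs one embedding module on two inputs."""

    def __init__(self):
        super().__init__()
        self.word_embeddings = nn.Embedding(100, 4)
        self.position_embeddings = nn.Embedding(128, 4)

    def embed(self, ids):
        positions = torch.arange(ids.shape[1]).unsqueeze(0)
        return self.word_embeddings(ids) + self.position_embeddings(positions)

    def forward(self, ids_text, ids_context):
        return self.embed(ids_text).mean(dim=1) + self.embed(ids_context).mean(dim=1)


torch.manual_seed(0)
model = SharedEncoder()
ids_text = torch.randint(0, 100, (1, 50))
ids_context = torch.randint(0, 100, (1, 71))

# Pass 1: keep one slot per module; the second call (71 tokens) overwrites the first (50).
cache = {}

def save_hook(module, inputs, output):
    cache[module] = output

handle = model.word_embeddings.register_forward_hook(save_hook)
model(ids_text, ids_context)  # works: no output is replaced yet
handle.remove()
cached_shape = tuple(cache[model.word_embeddings].shape)

# Pass 2: returning a value from a forward hook replaces the module's output,
# so every call of word_embeddings now receives the cached 71-token tensor.
def replay_hook(module, inputs, output):
    return cache[module]

handle = model.word_embeddings.register_forward_hook(replay_hook)
try:
    model(ids_text, ids_context)
    failed = False
except RuntimeError as err:
    failed = True
    print("replayed", cached_shape, "->", err)
finally:
    handle.remove()
```

With only one cached tensor per module, the first call receives the output of the second call, which is the same kind of length mismatch as the 44 vs 49 error earlier in this thread.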
chrisdoyleIE commented 2 years ago

@99warriors do you guys have a time estimate? If not, I'm happy to fork and start from there.

99warriors commented 2 years ago

@chrisdoyleIE Thank you for investigating this. We have discussed how to fix this problem (perhaps by expanding scattered_inputs_dict to cache the results of multiple forward calls of the same module), but the discussion is still ongoing. Any suggestions you have would be most helpful and welcome!

For now, as a workaround, we recommend avoiding calling the same module multiple times within the same forward pass (instead, create copies of the module, which can all share weights). Adding a warning about this is an immediate task we can tackle.
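The workaround can be sketched in plain PyTorch: build a second module instance whose parameters are the very same Parameter objects as the original's, and use one instance per call site. `tied_copy` is a hypothetical helper, not part of Captum or transformers; it assumes the module has no internally tied parameters and no hooks attached yet, and it deep-copies (rather than shares) buffers.

```python
import copy

import torch
import torch.nn as nn


def tied_copy(module: nn.Module) -> nn.Module:
    """Return a new instance of `module` whose parameters are the original objects.

    Assumes no internally tied parameters and no hooks attached yet
    (deepcopy would copy those hooks too). Buffers are copied, not shared.
    """
    clone = copy.deepcopy(module)
    for name, param in module.named_parameters():
        *path, leaf = name.split(".")
        owner = clone
        for part in path:
            owner = getattr(owner, part)
        # Replace the clone's deep-copied Parameter with the original one
        setattr(owner, leaf, param)
    return clone


torch.manual_seed(0)
embeddings = nn.Sequential(nn.Embedding(100, 8), nn.LayerNorm(8))
context_embeddings = tied_copy(embeddings)

ids_text = torch.randint(0, 100, (1, 50))
ids_context = torch.randint(0, 100, (1, 71))

# Each module instance runs once per forward pass, yet both paths
# accumulate gradients into the same Parameter objects.
loss = embeddings(ids_text).sum() + context_embeddings(ids_context).sum()
loss.backward()

shares = all(a is b for a, b in zip(embeddings.parameters(), context_embeddings.parameters()))
same_output = torch.equal(embeddings(ids_text), context_embeddings(ids_text))
print(shares, same_output)
```

In a model like the one in the earlier comment, this might look like `self.bert_context = tied_copy(self.bert)` in `__init__`, with `self.bert_context` used for the second call, so that each call site has its own module instance as far as hooks are concerned. This has not been checked against a specific Captum version.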

chrisdoyleIE commented 2 years ago

That'll do, many thanks!

phiwi commented 2 years ago

Is there any update regarding this topic?

I have a similar issue: I call the embedding function in my model multiple times to split an input into several chunks (it's a hierarchical model), where I have to get [CLS]/[PAD]/[SEP] embeddings in between:

...
sep_embed = self.bert.embeddings(torch.tensor([[4]], dtype=torch.long, device=self.device))[0][0]
pad_embed = self.bert.embeddings(torch.tensor([[0]], dtype=torch.long, device=self.device))[0][0]
cls_embed = self.bert.embeddings(torch.tensor([[3]], dtype=torch.long, device=self.device))[0][0]
...

Then, of course, I get wrong shapes from scattered_inputs_dict (or saved_layer?).
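To check in advance whether a layer is reused within one forward pass (as with the repeated self.bert.embeddings calls above), one can count calls per submodule with temporary forward hooks. A minimal sketch: `modules_called_more_than_once` and the toy `Hierarchical` model are hypothetical, and the token ids 3 and 4 merely mirror the snippet above.

```python
from collections import Counter

import torch
import torch.nn as nn


def modules_called_more_than_once(model, *args, **kwargs):
    """Run one forward pass and return {name: call_count} for reused submodules."""
    counts = Counter()
    handles = []
    for name, module in model.named_modules():
        if name:  # skip the root module itself
            def hook(mod, inputs, output, name=name):
                counts[name] += 1
            handles.append(module.register_forward_hook(hook))
    try:
        with torch.no_grad():
            model(*args, **kwargs)
    finally:
        for handle in handles:
            handle.remove()
    return {name: n for name, n in counts.items() if n > 1}


class Hierarchical(nn.Module):
    """Hypothetical model: embeddings run on the input and again per special token."""

    def __init__(self):
        super().__init__()
        self.embeddings = nn.Embedding(10, 4)
        self.classifier = nn.Linear(4, 2)

    def forward(self, ids):
        tokens = self.embeddings(ids)
        sep = self.embeddings(torch.tensor([[4]]))[0][0]
        cls = self.embeddings(torch.tensor([[3]]))[0][0]
        return self.classifier(tokens.mean(dim=1) + sep + cls)


reused = modules_called_more_than_once(Hierarchical(), torch.randint(0, 10, (1, 6)))
print(reused)
# {'embeddings': 3}
```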