pytorch / captum

Model interpretability and understanding for PyTorch
https://captum.ai
BSD 3-Clause "New" or "Revised" License

LayerIntegratedGradients hook changes output dimensions for XLNet word_embeddings layer #434

Closed: kh8fb closed this issue 4 years ago

kh8fb commented 4 years ago

I am trying to use the word_embedding layer of HuggingFace's XLNet with LayerIntegratedGradients to calculate attributions. My approach is similar to the BERT Question Answering tutorial, except that I target XLNet's embedding layer. I get an extremely long (and misleading) stack trace ending in an error that the operands of einsum have mismatched dimensions (see the minimum working example below). To investigate, I modified the XLNet script to print the tensor dimensions after every internal forward method.
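A less invasive way to get the same per-layer shape trace, without editing the XLNet source, is to register a forward hook on every submodule. This is only a sketch under stated assumptions: `register_shape_hooks` is a hypothetical helper, non-tensor outputs (such as the tuples some XLNet submodules return) are skipped, and a tiny `nn.Sequential` stands in for the real model.

```python
import torch
import torch.nn as nn

def register_shape_hooks(model, log):
    """Hypothetical helper: record (module name, output shape) after each submodule's forward."""
    handles = []
    for name, module in model.named_modules():
        if name == "":  # skip the root module itself
            continue
        # Bind `name` as a default argument so each hook keeps its own module name
        def hook(mod, inputs, output, name=name):
            if isinstance(output, torch.Tensor):  # tuples/dicts are ignored in this sketch
                log.append((name, tuple(output.shape)))
        handles.append(module.register_forward_hook(hook))
    return handles

# Usage on a toy stand-in model (with XLNet you would pass `model` from the example below)
model = nn.Sequential(nn.Embedding(100, 8), nn.Linear(8, 4))
log = []
handles = register_shape_hooks(model, log)
model(torch.randint(0, 100, (2, 20)))
for h in handles:
    h.remove()  # detach the hooks once the trace has been collected
for name, shape in log:
    print(name, shape)
```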

The issue appears when the inputs are expanded by n_steps and passed through the model to compute gradients. The inputs_layer and baselines_layer calls both work fine (input tensors of size [2, 20]). During gradient_func, however, when the tensor of size [100, 20] (that is, [n_steps * batch_size, seq_len]) is passed to the model, the word_embedding layer returns a tensor of size [1000, 2, 768]. For an input of size [100, 20], the expected output shape of the embeddings is [20, 100, 768]. I suspect the issue lies in the hook that is registered on this layer. Are there any workarounds you are aware of to make this work? Any advice would be greatly appreciated.
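Both shapes can be reproduced with plain arithmetic under two assumptions: that XLNet transposes input_ids to (seq_len, batch) before calling word_embedding (the resolution later in this thread points to this), and that Captum builds the scaled layer inputs by stacking n_steps interpolated copies along dimension 0, which it treats as the batch dimension. The helper names below are hypothetical.

```python
def seq_first_embedding_shape(batch, seq_len, hidden):
    # Assumed XLNet behavior: ids become (seq_len, batch) before the embedding lookup
    return (seq_len, batch, hidden)

def scaled_layer_shape(layer_shape, n_steps):
    # Assumed Captum behavior: n_steps copies concatenated along dim 0
    return (layer_shape[0] * n_steps, *layer_shape[1:])

batch, seq_len, hidden, n_steps = 2, 20, 768, 50

layer_out = seq_first_embedding_shape(batch, seq_len, hidden)
hooked = scaled_layer_shape(layer_out, n_steps)
expected = seq_first_embedding_shape(batch * n_steps, seq_len, hidden)

print(layer_out)  # (20, 2, 768)
print(hooked)     # (1000, 2, 768): what the hook substitutes
print(expected)   # (20, 100, 768): what the rest of XLNet expects
```

Because dimension 0 of the embedding output is the sequence axis rather than the batch axis, the expansion multiplies 20 by 50 instead of 2 by 50, and the mismatch only surfaces later inside the attention einsum.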

Minimum working example

from captum.attr import LayerIntegratedGradients
from transformers import XLNetForSequenceClassification, XLNetTokenizer
import torch

model = XLNetForSequenceClassification.from_pretrained("xlnet-base-cased")
tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")

sentence = "This is most definitely not a good movie and worse than most of their other moves."                                                                                               
sentence2 = "But this is definitely a fantastic movie and way better than most of their animated movies." 

features = tokenizer([sentence, sentence2], return_tensors='pt', padding=True, truncation=True, max_length=512)

input_ids = features["input_ids"] # Size: [2, 20]
token_type_ids = features["token_type_ids"] # Size: [2, 20]
attention_mask = features["attention_mask"] # Size: [2, 20]
baseline_ids = torch.zeros(input_ids.shape, dtype=torch.int64) # Size [2, 20]

def sequence_forward_func(inputs, model, tok_type_ids, att_mask):
    """Passes forward the inputs and relevant keyword arguments; returns the logits."""
    outputs = model(inputs, token_type_ids=tok_type_ids, attention_mask=att_mask)
    return outputs[0]  # transformers 3.x returns a tuple; Captum needs the logits tensor

lig = LayerIntegratedGradients(sequence_forward_func, model.transformer.word_embedding)

attrs = lig.attribute(inputs=input_ids,
                      baselines=baseline_ids,
                      additional_forward_args=(model, token_type_ids, attention_mask),
                      n_steps=50,
                      target=0,
                      return_convergence_delta=False)

Error stack trace

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-5-901481a09904> in <module>
      1 lig = LayerIntegratedGradients(sequence_forward_func, model.transformer.word_embedding)
      2 
----> 3 attrs = lig.attribute(inputs=input_ids,
      4                                      baselines=baseline_ids,
      5                                      additional_forward_args=(model, token_type_ids, attention_mask),

~/.conda/envs/torchtext-scripts/lib/python3.8/site-packages/captum/attr/_core/layer/layer_integrated_gradients.py in attribute(self, inputs, baselines, target, additional_forward_args, n_steps, method, internal_batch_size, return_convergence_delta, attribute_to_layer_input)
    350             else inps
    351         )
--> 352         attributions = self.ig.attribute(
    353             inputs_layer,
    354             baselines=baselines_layer,

~/.conda/envs/torchtext-scripts/lib/python3.8/site-packages/captum/attr/_core/integrated_gradients.py in attribute(self, inputs, baselines, target, additional_forward_args, n_steps, method, internal_batch_size, return_convergence_delta)
    276 
    277         # grads: dim -> (bsz * #steps x inputs[0].shape[1:], ...)
--> 278         grads = _batched_operator(
    279             self.gradient_func,
    280             scaled_features_tpl,

~/.conda/envs/torchtext-scripts/lib/python3.8/site-packages/captum/attr/_utils/batching.py in _batched_operator(operator, inputs, additional_forward_args, target_ind, internal_batch_size, **kwargs)
    154     of the results of each batch.
    155     """
--> 156     all_outputs = [
    157         operator(
    158             inputs=input,

~/.conda/envs/torchtext-scripts/lib/python3.8/site-packages/captum/attr/_utils/batching.py in <listcomp>(.0)
    155     """
    156     all_outputs = [
--> 157         operator(
    158             inputs=input,
    159             additional_forward_args=additional,

~/.conda/envs/torchtext-scripts/lib/python3.8/site-packages/captum/attr/_core/layer/layer_integrated_gradients.py in gradient_func(forward_fn, inputs, target_ind, additional_forward_args)
    331                     hook = self.layer.register_forward_hook(layer_forward_hook)
    332 
--> 333                 output = _run_forward(
    334                     self.forward_func, tuple(), target_ind, additional_forward_args
    335                 )

~/.conda/envs/torchtext-scripts/lib/python3.8/site-packages/captum/attr/_utils/common.py in _run_forward(forward_func, inputs, target, additional_forward_args)
    498     additional_forward_args = _format_additional_forward_args(additional_forward_args)
    499 
--> 500     output = forward_func(
    501         *(*inputs, *additional_forward_args)
    502         if additional_forward_args is not None

<ipython-input-3-94a740a2c519> in sequence_forward_func(inputs, model, tok_type_ids, att_mask)
      1 def sequence_forward_func(inputs, model, tok_type_ids, att_mask):
      2     """Passes forward the inputs and relevant keyword arguments."""
----> 3     outputs = model(inputs, token_type_ids=tok_type_ids, attention_mask=att_mask)
      4     return outputs

~/.conda/envs/torchtext-scripts/lib/python3.8/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

~/.conda/envs/torchtext-scripts/lib/python3.8/site-packages/transformers/modeling_xlnet.py in forward(self, input_ids, attention_mask, mems, perm_mask, target_mapping, token_type_ids, input_mask, head_mask, inputs_embeds, use_cache, labels, output_attentions, output_hidden_states)
   1200             heads.
   1201         """
-> 1202         transformer_outputs = self.transformer(
   1203             input_ids,
   1204             attention_mask=attention_mask,

~/.conda/envs/torchtext-scripts/lib/python3.8/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

~/.conda/envs/torchtext-scripts/lib/python3.8/site-packages/transformers/modeling_xlnet.py in forward(self, input_ids, attention_mask, mems, perm_mask, target_mapping, token_type_ids, input_mask, head_mask, inputs_embeds, use_cache, output_attentions, output_hidden_states)
    930                 hidden_states.append((output_h, output_g) if output_g is not None else output_h)
    931 
--> 932             outputs = layer_module(
    933                 output_h,
    934                 output_g,

~/.conda/envs/torchtext-scripts/lib/python3.8/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

~/.conda/envs/torchtext-scripts/lib/python3.8/site-packages/transformers/modeling_xlnet.py in forward(self, output_h, output_g, attn_mask_h, attn_mask_g, r, seg_mat, mems, target_mapping, head_mask, output_attentions)
    495         output_attentions=False,
    496     ):
--> 497         outputs = self.rel_attn(
    498             output_h,
    499             output_g,

~/.conda/envs/torchtext-scripts/lib/python3.8/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

~/.conda/envs/torchtext-scripts/lib/python3.8/site-packages/transformers/modeling_xlnet.py in forward(self, h, g, attn_mask_h, attn_mask_g, r, seg_mat, mems, target_mapping, head_mask, output_attentions)
    428 
    429             # core attention ops
--> 430             attn_vec = self.rel_attn_core(
    431                 q_head_h,
    432                 k_head_h,

~/.conda/envs/torchtext-scripts/lib/python3.8/site-packages/transformers/modeling_xlnet.py in rel_attn_core(self, q_head, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask, head_mask, output_attentions)
    270 
    271         # position based attention score
--> 272         bd = torch.einsum("ibnd,jbnd->bnij", q_head + self.r_r_bias, k_head_r)
    273         bd = self.rel_shift_bnij(bd, klen=ac.shape[3])
    274 

~/.conda/envs/torchtext-scripts/lib/python3.8/site-packages/torch/functional.py in einsum(equation, *operands)
    293         print(operand.shape)
    294     print(equation)
--> 295     return _VF.einsum(equation, operands)
    296 
    297 

RuntimeError: size of dimension does not match previous size, operand 1, dim 1

Using captum==0.2.0, transformers==3.0.2, and torch==1.5.1.

p16i commented 4 years ago

I recently reported an issue that I think might be relevant to yours: https://github.com/pytorch/captum/issues/427.

kh8fb commented 4 years ago

@heytitle, you are absolutely correct. I saw your issue but did not connect the dots, so thank you for helping me.

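The issue was that XLNet is not batch-first. Its transformer transposes input_ids to (len, batch_size) before the word_embedding layer, so that layer's output is (len, batch_size, embeddings), while LayerIntegratedGradients expands the layer output along the first dimension as if it were the batch dimension.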

For anyone else facing this problem: I created a dummy nn.Module that permutes the output of my embedding layer from (len, batch_size, embeddings) to batch-first (batch_size, len, embeddings). This module goes right after the embeddings in the forward method. A second module then permutes the tensor back to (len, batch_size, embeddings). Attaching LayerIntegratedGradients to the first module makes its batch-first assumption hold, so the attributions are computed correctly.
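The permute-and-restore workaround can be sketched on a toy model. Everything here is an assumption for illustration: `Permute`, `ToySeqFirstModel`, and `expand_like_captum` are hypothetical names, the toy model only mimics XLNet's (seq_len, batch) embedding layout, and the hook imitates Captum's substitution with plain copies rather than interpolated values (and without expanding input_ids, which the toy model does not need).

```python
import torch
import torch.nn as nn

class Permute(nn.Module):
    """Swap dims 0 and 1 of a 3-D tensor: (a, b, c) -> (b, a, c)."""
    def forward(self, x):
        return x.permute(1, 0, 2)

class ToySeqFirstModel(nn.Module):
    """Hypothetical stand-in for XLNet: embeds ids laid out as (seq_len, batch)."""
    def __init__(self, vocab_size=100, hidden=16, num_labels=2):
        super().__init__()
        self.word_embedding = nn.Embedding(vocab_size, hidden)
        self.to_batch_first = Permute()  # the layer to attribute to
        self.to_seq_first = Permute()    # restores the layout the rest of the model expects
        self.classifier = nn.Linear(hidden, num_labels)

    def forward(self, input_ids):                # input_ids: (batch, seq_len)
        ids = input_ids.transpose(0, 1)          # (seq_len, batch)
        emb = self.word_embedding(ids)           # (seq_len, batch, hidden)
        emb = self.to_batch_first(emb)           # (batch, seq_len, hidden)
        emb = self.to_seq_first(emb)             # (seq_len, batch, hidden)
        return self.classifier(emb.mean(dim=0))  # (batch, num_labels)

def expand_like_captum(module, inputs, output, n_steps=50):
    # Replace the layer output with n_steps copies stacked along dim 0,
    # the dimension Captum treats as the batch dimension
    return torch.cat([output] * n_steps, dim=0)

model = ToySeqFirstModel()
ids = torch.randint(0, 100, (2, 20))

handle = model.to_batch_first.register_forward_hook(expand_like_captum)
print(model(ids).shape)  # torch.Size([100, 2]): 2 examples x 50 steps
handle.remove()

handle = model.word_embedding.register_forward_hook(expand_like_captum)
print(model(ids).shape)  # torch.Size([2, 2]): the expansion landed on the sequence axis
handle.remove()
```

With the real model, the two permute modules would have to be wired into the XLNet forward method right after the embedding lookup, as described above, and the first one passed to LayerIntegratedGradients in place of model.transformer.word_embedding.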

Thanks again to @heytitle for pointing this out.