I've recently reported an issue that I think might be relevant to yours: https://github.com/pytorch/captum/issues/427.
@heytitle you are absolutely correct. I saw your issue but did not connect the dots, so thank you for helping me.
The issue was that XLNet is not a `batch_first` setup.
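To illustrate what I mean, here is a minimal sketch (the model name and the hook variable are illustrative, not from the original code) showing that the output of XLNetModel's embedding layer comes back sequence-first, not batch-first:

```python
import torch
from transformers import XLNetModel

model = XLNetModel.from_pretrained("xlnet-base-cased")
model.eval()

# Capture the shape coming out of the embedding layer with a forward hook.
captured = {}
model.word_embedding.register_forward_hook(
    lambda module, inp, out: captured.update(shape=tuple(out.shape))
)

input_ids = torch.randint(0, model.config.vocab_size, (2, 20))  # (batch_size, len)
with torch.no_grad():
    model(input_ids)

# XLNet transposes input_ids internally, so the embedding output is
# (len, batch_size, embeddings) = (20, 2, 768), not (batch_size, len, embeddings).
print(captured["shape"])
```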
For anyone else facing this problem, I decided to create a dummy `nn.Module`
that permutes the results from my embedding layer (len, batch_size, embeddings) to put the batch dimension first (batch_size, len, embeddings). This module goes right after the embeddings in the forward method. A second module then permutes it back to (len, batch_size, embeddings). I set up LayerIntegratedGradients on the first module, which allows the `batch_first` handling to work correctly; a sketch follows below.
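A minimal sketch of that workaround (module and attribute names here are my own; adapt them to wherever your copy of the XLNet forward builds its embeddings):

```python
import torch.nn as nn

class BatchFirst(nn.Module):
    """Dummy module: (len, batch_size, emb) -> (batch_size, len, emb)."""
    def forward(self, x):
        return x.permute(1, 0, 2)

class LenFirst(nn.Module):
    """Dummy module: (batch_size, len, emb) -> (len, batch_size, emb)."""
    def forward(self, x):
        return x.permute(1, 0, 2)

# In the (copied/modified) XLNet forward, immediately after the embedding lookup:
#
#     word_emb = self.word_embedding(input_ids)  # (len, batch_size, emb)
#     word_emb = self.batch_first(word_emb)      # (batch_size, len, emb)
#     word_emb = self.len_first(word_emb)        # back to (len, batch_size, emb)
#
# with self.batch_first = BatchFirst() and self.len_first = LenFirst() registered
# in __init__. LayerIntegratedGradients is then attached to the first module:
#
#     lig = LayerIntegratedGradients(forward_func, model.batch_first)
#
# so captum's hook sees a batch-first tensor and the expansion over n_steps works.
```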
Thanks again to @heytitle for pointing this out.
I am trying to use HuggingFace's XLNet `word_embedding` layer in LayerIntegratedGradients to calculate attributions. The method I use is similar to the BERT Question Answering tutorial, but I am instead trying to use the embedding layer of XLNet. I get an incredibly long (and misleading) stack trace with the error that einsum has the incorrect dimensions (see the minimum working example below). I looked further into the issue and altered the XLNet script to print the dimensions after every internal forward method. The issue comes when the gradients are batched to `n_steps` and passed to the model. The `inputs_layer` and `baselines_layer` function calls both work fine (tensors sized `[2, 20]`). During the `gradient_func`, when the tensor of size `[100, 20]`, i.e. `[n_steps*btc_size, num_ids]`, is passed to the `word_embedding` layer, it returns as a tensor of size `[1000, 2, 768]`. The expected return shape (from the embeddings) for a tensor sized `[100, 20]` is `[20, 100, 768]`. I believe the issue is possibly in the hook that is created for this layer. Are there any workarounds you are aware of to make this work? Any advice would be greatly appreciated.

**Minimum working example**
**Error stack trace**
Using `captum==0.2.0`, `transformers==3.0.2`, and `torch==1.5.1`.