inseq-team / inseq

Interpretability for sequence generation models 🐛 🔍
https://inseq.org
Apache License 2.0

Can't run inseq with batch_size>1 for methods that use gradients (e.g. integrated gradients) #272

Closed · issam9 closed this issue 5 months ago

issam9 commented 5 months ago

Running the inseq attribute method over a dataset to get attributions with integrated gradients results in the following error:

ValueError: Attention mask should be of size (1, 1, 1, 624), but is torch.Size([16, 1, 1, 39])

Using a batch size of 1 works fine, but it's very slow, so I was hoping I could use a larger batch size. In the error above I'm using a batch size of 16.
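
A minimal sketch of my setup (the model name and texts here are placeholders standing in for my actual model and dataset):

import inseq

# Placeholder model: my actual model and dataset are different.
model = inseq.load_model("Helsinki-NLP/opus-mt-en-it", "integrated_gradients")

# Dummy stand-ins for the dataset rows being attributed.
texts = ["first example sentence", "second example sentence"]

# batch_size=16 is the setting that fails; batch_size=1 works but is very slow.
out = model.attribute(texts, batch_size=16)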

I see in the attribution_model.py code (lines 433-457) that the batch_size is forced to 1 in some cases, but I couldn't figure out where exactly the shape of the input is changed.

I think that even if the attribution itself can only be done with a batch size of 1, it would still be better for the previous steps to run with a larger batch size. Maybe that's the reasoning behind setting batch_size to 1 right before running attribute instead of earlier?

gsarti commented 5 months ago

Hi @issam9, thanks for reporting this. Could you provide the exact code you used to produce the error you mention? With the attribute-dataset CLI command I have no problem attributing a batch with either decoder-only (GPT-2) or encoder-decoder models using integrated gradients.

Here's an example with GPT-2:

inseq attribute-dataset \
  --model_name_or_path gpt2 \
  --attribution_method integrated_gradients \
  --dataset_name inseq/dummy_enit \
  --input_text_field en \
  --dataset_split "train[:20]" \
  --viz_path attributions.html \
  --batch_size 8 \
  --hide \
  --generation_kwargs '{"max_new_tokens": 20}'

In general, some methods are indeed constrained to a batch size of 1 at the moment, but batched attribution should work for most Captum-based methods, including saliency and integrated gradients.
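
For reference, a rough Python API equivalent of the CLI call above would be the following sketch (placeholder prompts; generation_args is assumed here to mirror the CLI's --generation_kwargs):

import inseq

# GPT-2 with integrated gradients, as in the CLI example above.
model = inseq.load_model("gpt2", "integrated_gradients")

# Placeholder prompts standing in for the dataset rows.
out = model.attribute(
    ["The capital of France is", "The quick brown fox jumps over"],
    batch_size=8,  # mirrors --batch_size
    generation_args={"max_new_tokens": 20},  # assumed equivalent of --generation_kwargs
)

out.show()  # displays the attribution visualization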

issam9 commented 5 months ago

Thank you for your reply! I'm sorry, the issue was caused by a change on my side: I had updated the code to work on a list of words instead of a sentence and set is_split_into_words to True in the tokenizer. I solved it by setting is_split_into_words to False when as_targets is True.
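
For anyone hitting the same error: the shape mismatch comes from how Hugging Face tokenizers interpret a list of strings. A minimal sketch of the two behaviors (the words below are made up for illustration):

from transformers import AutoTokenizer

# add_prefix_space=True is required by GPT-2's fast tokenizer for pre-split inputs.
tok = AutoTokenizer.from_pretrained("gpt2", add_prefix_space=True)

words = ["the", "quick", "brown", "fox"]  # illustrative pre-split input

# is_split_into_words=True: the list is read as the words of ONE sequence,
# so input_ids / attention_mask correspond to a single row.
one_sequence = tok(words, is_split_into_words=True)

# Default (is_split_into_words=False): the same list is read as a BATCH of
# four independent sequences, so the encoding gains a leading batch dimension
# of 4 and downstream shapes no longer line up.
batch_of_four = tok(words)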