huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Decoder and cross-attention shape is different when obtained by model.generate() and model() #33296

Closed (cgr71ii closed this issue 2 months ago)

cgr71ii commented 2 months ago

Who can help?

@gante

Reproduction

Hi!

If you set trigger_error to True, you will see that the shape of the decoder attentions (and of the cross-attentions) differs depending on whether the translation is obtained with model.generate() or with model(). I don't know whether this is a bug or expected behavior. I have checked that the attention values are the same once both outputs are arranged in the same structure (there are small precision differences, though, which I think come from model.generate() computing the attentions differently than model()).

import torch
import transformers

trigger_error = True # Change THIS
pretrained_model = "facebook/nllb-200-distilled-600M"
device = "cuda" if torch.cuda.is_available() else "cpu" 
source_lang = "eng_Latn"
target_lang = "spa_Latn"
tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained_model, src_lang=source_lang, tgt_lang=target_lang)
model = transformers.AutoModelForSeq2SeqLM.from_pretrained(pretrained_model).to(device)
source_text = ["Hello!", "Hello again!!!!!!"]
inputs = tokenizer(source_text, return_tensors="pt", add_special_tokens=True, max_length=1024, truncation=True, padding=True).to(device)
target_lang_id = tokenizer.convert_tokens_to_ids(target_lang)
translated_tokens = model.generate(**inputs, forced_bos_token_id=target_lang_id, max_new_tokens=200, num_return_sequences=1, num_beams=1, return_dict_in_generate=True, output_attentions=True)
translated_tokens_model = model(**inputs, decoder_input_ids=translated_tokens.sequences[:,:-1], output_attentions=True)

# Checks
num_hidden_layers = model.config.num_hidden_layers
num_attention_heads = model.config.num_attention_heads
batch_size = len(source_text)

# Encoder
assert len(translated_tokens.encoder_attentions) == num_hidden_layers # 12
assert len(translated_tokens_model.encoder_attentions) == num_hidden_layers # 12

for l in range(num_hidden_layers):
  assert translated_tokens.encoder_attentions[l].shape == translated_tokens_model.encoder_attentions[l].shape
  assert (translated_tokens.encoder_attentions[l] == translated_tokens_model.encoder_attentions[l]).all().cpu().item()

#####

def transformer_attention_to_common_structure(attention_ttg, attention_ttm):
  ## Transform attention from model.generate() to common structure with model()
  decoded_tokens = attention_ttm[0].shape[-2:]
  _decoder_attention = torch.zeros(num_hidden_layers, batch_size, num_attention_heads, *decoded_tokens, device=device)

  for _decoded_tokens, t in enumerate(attention_ttg, 1):
    # One generation step: a tuple with one attention tensor per layer
    t = torch.stack(t, 0) # (num_hidden_layers, batch_size, num_attention_heads, 1, key_length)
    t = t.squeeze(-2)
    # Causal mask: key positions this step cannot attend to stay zero
    _decoder_attention[:,:,:, _decoded_tokens - 1,:t.shape[-1]] = t

  for l in range(num_hidden_layers):
    assert _decoder_attention.shape == (len(attention_ttm), *attention_ttm[l].shape)
#    assert (_decoder_attention[l] == attention_ttm[l]).all().cpu().item() # Exact equality fails due to precision differences (even with device="cpu"); does model.generate() compute the attentions differently from model()?
    assert torch.isclose(_decoder_attention[l], attention_ttm[l]).all().cpu().item()

# Decoder
assert len(translated_tokens.decoder_attentions) == num_hidden_layers if trigger_error else True
assert len(translated_tokens_model.decoder_attentions) == num_hidden_layers # 12

transformer_attention_to_common_structure(translated_tokens.decoder_attentions, translated_tokens_model.decoder_attentions)

# Cross
assert len(translated_tokens.cross_attentions) == num_hidden_layers if trigger_error else True
assert len(translated_tokens_model.cross_attentions) == num_hidden_layers # 12

transformer_attention_to_common_structure(translated_tokens.cross_attentions, translated_tokens_model.cross_attentions)
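One plausible explanation for the small precision differences noted above is that the two code paths compute the same quantity in a different order. A minimal sketch, assuming generate() decodes one token at a time with a key/value cache while model() processes the whole decoder input in a single pass: the toy single-head attention below (random tensors, arbitrary sizes, not the NLLB code) yields the same causally masked matrix both ways up to floating-point error, which is why a tolerance-based check such as torch.isclose or torch.allclose is the appropriate comparison rather than ==.

```python
import torch

torch.manual_seed(0)

# Toy single-head causal self-attention; sizes are arbitrary assumptions.
T, d = 5, 8
q = torch.randn(T, d)
k = torch.randn(T, d)
scale = d ** -0.5

# Full pass (like model()): all positions at once, future keys masked out.
scores = (q @ k.T) * scale
mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
full = scores.masked_fill(mask, float("-inf")).softmax(dim=-1)

# Incremental pass (like cached decoding): one query row per step,
# attending only to the keys produced so far.
incremental = torch.zeros(T, T)
for t in range(T):
    row = (q[t : t + 1] @ k[: t + 1].T) * scale  # shape (1, t + 1)
    incremental[t, : t + 1] = row.softmax(dim=-1)[0]

print(torch.equal(full, incremental))     # bitwise equality is not guaranteed
print(torch.allclose(full, incremental))  # equal within tolerance
```

The upper triangle of both matrices is zero, matching the zero padding that transformer_attention_to_common_structure leaves for positions a step cannot attend to.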

Expected behavior

I would expect the decoder and cross-attention outputs to have the same format regardless of whether I use model.generate() or model(). Specifically, I would expect the format returned by model(), where the decoder attention is one tensor per layer with shape (batch_size, num_attention_heads, generated_tokens - 1, generated_tokens - 1).

gante commented 2 months ago

Hi @cgr71ii 👋 Thank you for opening this issue 🤗

As shown in our documentation, the output of generate is different from the output of forward.

Namely, generate's attention output is a tuple with one item per generation step, where each item is the attention output of that step's forward pass. In your example, if you replace e.g. translated_tokens.decoder_attentions with translated_tokens.decoder_attentions[0], you'll obtain the results you were expecting :)
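To make the nesting concrete, here is a minimal sketch with dummy tensors (no model download). It assumes the generate() structure described in the transformers documentation: one entry per generation step, each holding one tensor per decoder layer of shape (batch_size, num_heads, query_len, key_len), where query_len is 1 when the key/value cache is used. The helper to_forward_style is hypothetical, not a transformers API; it performs the same regrouping as transformer_attention_to_common_structure in the report, producing one tensor per layer like model() returns.

```python
import torch

# Hypothetical sizes standing in for the real model/config values.
num_layers, batch, heads, steps = 2, 3, 4, 5

# Dummy generate()-style output: one item per generation step, each a tuple
# with one tensor per layer of shape (batch, heads, 1, key_len). For decoder
# self-attention with a growing cache, key_len increases by one per step.
generate_style = tuple(
    tuple(torch.rand(batch, heads, 1, step + 1) for _ in range(num_layers))
    for step in range(steps)
)

def to_forward_style(step_attentions):
    """Regroup per-step attentions into one (batch, heads, steps, key_len)
    tensor per layer, zero-padding keys a step cannot attend to."""
    num_steps = len(step_attentions)
    n_layers = len(step_attentions[0])
    key_len = step_attentions[-1][0].shape[-1]
    per_layer = []
    for layer in range(n_layers):
        ref = step_attentions[0][layer]
        out = ref.new_zeros(ref.shape[0], ref.shape[1], num_steps, key_len)
        for step, layers in enumerate(step_attentions):
            row = layers[layer][..., -1, :]  # query row of the newest token
            out[:, :, step, : row.shape[-1]] = row
        per_layer.append(out)
    return tuple(per_layer)

forward_style = to_forward_style(generate_style)
print(len(generate_style), len(generate_style[0]))  # 5 2 -> steps, layers
print(len(forward_style), forward_style[0].shape)   # 2 torch.Size([3, 4, 5, 5])
```

Taking the last query row rather than squeezing keeps the helper working even if a step returns more than one query row. For cross-attention, key_len is the source length at every step, so the same function yields (batch, heads, steps, source_len) per layer.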

cgr71ii commented 2 months ago

Oh, ok! Thank you! :)