huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

cross attention mask is always zeros in mllama #34280

Open · xgal opened this issue 1 month ago

xgal commented 1 month ago

System Info

transformers==4.45.2

When preparing the cross_attention_mask in the _prepare_cross_attention_mask function, the cross_attn_mask comes out with shape [batch, text_tokens, image_tokens], with -inf for text tokens before the <|image|> token and 0's afterwards. Then this line https://github.com/huggingface/transformers/blob/c1c7e89620b43f0924148e267a6da5d38450ce1f/src/transformers/models/mllama/modeling_mllama.py#L64 builds full_text_row_masked_out_mask (later used around the MLP), which is 0 for the tokens before the <|image|> token and 1 afterwards. We then multiply the cross_attention_mask by this new mask, so effectively the cross_attention_mask ends up all 0's (actually -0 and 0, but that is of course the same).
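To illustrate the effect with a standalone toy example (the shapes and values here are made up; only the finfo-min fill and the row-mask computation mirror what the modeling code does):

```python
import torch

dtype = torch.float32
min_val = torch.finfo(dtype).min  # the "-inf" fill value used by the modeling code

# toy additive mask for 4 text tokens x 3 image tokens:
# the first two text tokens come before <|image|>, so their rows are fully masked
cross_attention_mask = torch.tensor(
    [
        [min_val, min_val, min_val],
        [min_val, min_val, min_val],
        [0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0],
    ]
)

# 1 for rows that still have at least one unmasked position, 0 for fully masked rows
full_text_row_masked_out_mask = (
    (cross_attention_mask != min_val).any(dim=-1).type_as(cross_attention_mask)[..., None]
)

cross_attention_mask *= full_text_row_masked_out_mask
print(cross_attention_mask)  # every entry is now 0 (or -0): the additive mask no longer masks anything
```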

I thought the line cross_attention_mask *= full_text_row_masked_out_mask should be removed: https://github.com/huggingface/transformers/blob/c1c7e89620b43f0924148e267a6da5d38450ce1f/src/transformers/models/mllama/modeling_mllama.py#L67

Then, when we get to the forward pass and reach the cross-attention calculation, the mask is never taken into account before the softmax. There is still some use for the masking that comes from the full_text_row_masked_out_mask I mentioned before: after the MLP we multiply by this vector, so we actually zero all hidden_states for the tokens before <|image|>. Still, there is some effect coming from the hidden_states added in the residual before the MLP. 🤷‍♂️
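Schematically, what I mean by the forward pass is roughly this (a simplified sketch with the sublayers stubbed out, not the actual modeling code; class and attribute names here are illustrative):

```python
import torch
from torch import nn


class ToyCrossAttentionDecoderLayer(nn.Module):
    """Simplified sketch of the data flow; sublayers are stand-ins, not the real modules."""

    def __init__(self, hidden_size: int = 8):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(hidden_size)
        self.post_attention_layernorm = nn.LayerNorm(hidden_size)
        self.cross_attn = nn.Identity()  # stand-in for the real cross-attention
        self.mlp = nn.Identity()         # stand-in for the real MLP
        self.attn_gate = nn.Parameter(torch.zeros(1))
        self.mlp_gate = nn.Parameter(torch.zeros(1))

    def forward(self, hidden_states, full_text_row_masked_out_mask):
        # hidden_states: [batch, seq, hidden]; full_text_row_masked_out_mask: [batch, seq, 1]
        residual = hidden_states
        hidden_states = self.cross_attn(self.input_layernorm(hidden_states))
        # residual added *before* the MLP: tokens before <|image|> still carry their
        # original hidden states through this path
        hidden_states = residual + self.attn_gate.tanh() * hidden_states

        residual = hidden_states
        hidden_states = self.mlp(self.post_attention_layernorm(hidden_states))
        # only here is full_text_row_masked_out_mask applied: it zeroes the MLP output
        # for rows whose cross-attention mask was fully masked
        hidden_states = full_text_row_masked_out_mask * hidden_states
        hidden_states = residual + self.mlp_gate.tanh() * hidden_states
        return hidden_states
```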

Who can help?

No response

Information

Tasks

Reproduction

Generate with the mllama conditional-generation model using a prompt in which the <|image|> token is not the first token; you will see that the cross attention mask becomes all zeros after _prepare_cross_attention_mask is called.
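For example, something along these lines (a sketch based on the model card usage; the checkpoint name is the one from the model card, the image path is a placeholder, and any image works):

```python
import torch
from PIL import Image
from transformers import AutoProcessor, MllamaForConditionalGeneration

model_id = "meta-llama/Llama-3.2-11B-Vision"  # checkpoint from the model card example
model = MllamaForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)
processor = AutoProcessor.from_pretrained(model_id)

# placeholder path: the rabbit image from the model card example (any image works)
image = Image.open("rabbit.jpg")

# <|image|> is *not* the first token of the prompt
prompt = "If I had to write a haiku <|image|> for this one"
inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device)

output = model.generate(**inputs, max_new_tokens=30)
print(processor.decode(output[0]))

# printing the cross_attention_mask returned by _prepare_cross_attention_mask
# (e.g. with a breakpoint inside modeling_mllama.py) shows it is all zeros
```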

Expected behavior

The cross attention mask should be at least partially used, i.e. it should not collapse to all zeros.

ArthurZucker commented 1 month ago

Would be nice if you could provide a repro 😉 but cc @zucchini-nlp anyways, as I think it should be fairly easy to reproduce! 🤗

zucchini-nlp commented 1 month ago

Hey @xgal !

If I understand you correctly, you mean that masking all tokens before the special image token is not correct. In fact we mask tokens according to the original Meta implementation + paper, where each image should attend only to the text chunk in which it is referred to. The heuristic is that the special image token acts as an indicator of which image each text chunk refers to.
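As a toy illustration of that heuristic (a simplified standalone sketch, not the actual processing code):

```python
from typing import List

IMAGE_TOKEN = "<|image|>"


def toy_cross_attention_mask(tokens: List[str], num_images: int) -> List[List[int]]:
    """Toy version of the heuristic: text between the i-th and (i+1)-th <|image|>
    tokens attends to image i; the last image attends to all remaining text."""
    image_positions = [i for i, tok in enumerate(tokens) if tok == IMAGE_TOKEN]
    mask = [[0] * num_images for _ in tokens]
    for img_idx, start in enumerate(image_positions):
        end = image_positions[img_idx + 1] if img_idx + 1 < len(image_positions) else len(tokens)
        for t in range(start, end):
            mask[t][img_idx] = 1
    return mask


tokens = ["If", "I", "had", "to", "write", "a", "haiku", "<|image|>", "for", "this", "one"]
for row in toy_cross_attention_mask(tokens, num_images=1):
    print(row)
# rows before <|image|> are all zeros -> those text tokens attend to no image
```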

Indeed it looks a bit weird that in the single-image case the mask is not all 1s, but I believe that is the intended behavior, since Meta did the same thing in their implementation of Mllama inference.

xgal commented 1 month ago

👋 @zucchini-nlp I actually agree that the text tokens before the image token should be masked out, but after cross_attention_mask *= full_text_row_masked_out_mask you get the whole cross_attn_mask back as 0's from the _prepare_cross_attention_mask call. Example: take the prompt from the model card example but shift the image token, with the same rabbit image as in the example and the same generate settings: prompt = "If I had to write a haiku <|image|> for this one"

[screenshot: the cross attention mask printed during generation — every entry is 0 or -0]

Notice the whole mask matrix is 0's (-0 and 0). Now, if you don't do cross_attention_mask *= full_text_row_masked_out_mask https://github.com/huggingface/transformers/blob/c1c7e89620b43f0924148e267a6da5d38450ce1f/src/transformers/models/mllama/modeling_mllama.py#L67

you'll get

[screenshot: the cross attention mask without that multiplication — tokens before <|image|> keep the large negative values]

which IIUC makes more sense, because when you add the -inf to the attn_weights in the attention forward pass and then do the softmax, the masked positions get 0's.
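For example, with toy numbers (just illustrating the softmax behavior, not the model code):

```python
import torch

attn_weights = torch.tensor([1.0, 2.0, 3.0])
min_val = torch.finfo(attn_weights.dtype).min

# with the large-negative bias, the masked position gets ~0 probability
print(torch.softmax(attn_weights + torch.tensor([min_val, 0.0, 0.0]), dim=-1))
# tensor([0.0000, 0.2689, 0.7311])

# with an all-zero mask (the current behavior), nothing is masked
print(torch.softmax(attn_weights + torch.zeros(3), dim=-1))
# tensor([0.0900, 0.2447, 0.6652])
```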

The full_text_row_masked_out_mask itself still stays the same, so it keeps "masking" after the MLP.

wdyt ? :)

zucchini-nlp commented 1 month ago

Hmmm, you're right! I just checked out the very first version we had, and that is apparently what was intended, since that impl also applies full_text_row_masked_out_mask on the inverted cross attn mask. We would only get an attn mask that is masked before the image token if it were applied before the inversion.
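To make the order-of-operations point concrete, a toy comparison (made-up shapes, single image; only the inversion, finfo-min fill and row-mask steps mirror the modeling code):

```python
import torch

dtype = torch.float32
min_val = torch.finfo(dtype).min

# 0/1 mask for 3 text tokens x 1 image: the first token comes before <|image|>
mask = torch.tensor([[0.0], [1.0], [1.0]])
row_keep = mask.any(dim=-1, keepdim=True).to(dtype)  # 0 for fully-masked rows


def invert(m):
    inv = 1.0 - m
    return inv.masked_fill(inv.bool(), min_val)


# current order: invert first, then multiply -> the masked row collapses to (-)0
print(invert(mask) * row_keep)

# applying the row mask before the inversion instead keeps that row masked (finfo.min)
print(invert(mask * row_keep))
```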

My personal opinion is that it is weird, yes, because I was expecting the model to mask out the tokens before the image. But I am quite unsure, because our impl matched the original impl at the logit level. Let me get back to you on this, and you can also try asking in the hub discussions in case the authors reply there.

github-actions[bot] commented 2 days ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.