xgal opened this issue 1 month ago
Would be nice if you could provide a repro 😉 but cc @zucchini-nlp anyways, as I think it should be fairly easy to reproduce! 🤗
Hey @xgal !
If I understand you correctly, you mean that masking all tokens before the special image token is not correct. In fact we mask all tokens according to the original Meta implementation + paper, where each image should attend only to the text chunk in which it is referred to. The heuristic is that the special image token acts as an indicator of which image each text chunk refers to.
It does look a bit weird that in the single-image case the mask is not all 1s, but I believe that is the intended behavior, since Meta did the same thing in their Mllama inference implementation.
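As a rough illustration of that heuristic (a toy sketch, not the actual processor code; the function and shapes here are made up for clarity), each text token gets to cross-attend only to the image introduced by the closest preceding <|image|> token:

```python
import torch

def toy_cross_attention_mask(image_token_positions, text_len, num_images):
    # Illustrative only: mask[t, i] == 1 means text token t may cross-attend to image i.
    mask = torch.zeros(text_len, num_images, dtype=torch.long)
    for img_idx, start in enumerate(image_token_positions):
        # each image's text chunk runs until the next <|image|> token (or the end of the text)
        end = image_token_positions[img_idx + 1] if img_idx + 1 < num_images else text_len
        mask[start:end, img_idx] = 1
    return mask

# "A <|image|> cat sat <|image|> dog ran": token 0 precedes every <|image|> token,
# so it attends to no image at all -- which is the case discussed in this issue.
print(toy_cross_attention_mask(image_token_positions=[1, 4], text_len=7, num_images=2))
```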
👋 @zucchini-nlp
I actually think the text tokens before the image token should be masked out, but after cross_attention_mask *= full_text_row_masked_out_mask you get the whole cross_attn_mask back as all 0's from the _prepare_cross_attention_mask call.
Example:
Take the prompt from the model card example, but shift the image token, using the same rabbit image and the same generate settings as in the example:
prompt = "If I had to write a haiku <|image|> for this one"
Notice that the whole mask matrix is 0's (-0 and 0).
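A self-contained way to see this without downloading the checkpoint is a hedged sketch like the following: it imports the private helper directly and fakes the processor output, assuming _prepare_cross_attention_mask keeps the (cross_attention_mask, num_vision_tokens, dtype) signature it has at the linked commit.

```python
import torch
from transformers.models.mllama.modeling_mllama import _prepare_cross_attention_mask

# Fake processor output for a prompt where <|image|> is NOT the first token:
# shape [batch, text_len, num_images, num_tiles]; 1 = "token may attend to this image".
# Tokens 0-4 come before <|image|>, tokens 5-9 come after it.
text_len, num_images, num_tiles = 10, 1, 1
cross_attention_mask = torch.zeros(1, text_len, num_images, num_tiles)
cross_attention_mask[:, 5:, 0, 0] = 1

cross_attn_mask, full_text_row_masked_out_mask = _prepare_cross_attention_mask(
    cross_attention_mask,
    num_vision_tokens=4,  # arbitrary small value, only affects the size of the last dim
    dtype=torch.float32,
)

print(cross_attn_mask[0, 0])                      # every row is 0 / -0: nothing is masked
print(full_text_row_masked_out_mask[0, 0, :, 0])  # 0 for tokens before <|image|>, 1 after
```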
Now, if you don't do cross_attention_mask *= full_text_row_masked_out_mask
https://github.com/huggingface/transformers/blob/c1c7e89620b43f0924148e267a6da5d38450ce1f/src/transformers/models/mllama/modeling_mllama.py#L67
you'll get a mask that keeps the large negative values for the rows before the image token, which IIUC makes more sense, because when you add it to the attn_weights in the attention forward pass and then apply softmax, those positions get 0 weight.
The full_text_row_masked_out_mask still stays the same, so the "masking" after the mlp is preserved.
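For reference, the softmax mechanism being invoked here, as a tiny standalone sketch with made-up scores (note it zeroes the masked positions of a row that still has at least one unmasked position):

```python
import torch

# One text token's raw attention scores over 5 image positions.
scores = torch.tensor([0.3, 1.2, -0.5, 0.8, 0.1])

# Additive mask: large negative bias on the first three positions.
bias = torch.tensor([torch.finfo(torch.float32).min] * 3 + [0.0, 0.0])

weights = torch.softmax(scores + bias, dim=-1)
print(weights)  # ~[0, 0, 0, 0.67, 0.33] -- the biased positions get (near-)zero weight
```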
wdyt ? :)
Hmmm, you're right! I just checked out the very first version we had, and indeed that is apparently what was intended, since that impl also applies full_text_row_masked_out_mask on the inverted cross attn mask. We would actually get an attn mask that masks the tokens before the image token if it were applied before the inversion.
My personal opinion is that it is weird, yes, because I was expecting the model to mask out the tokens before the image token. But I am quite unsure, because our impl matched the original impl at the logit level. Let me get back on this; you can also try asking in the hub discussion in case the authors reply there.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
System Info
transformers==4.45.2
When preparing the cross_attention_mask in the _prepare_cross_attention_mask function, we get the cross_attn_mask in the shape [batch, text_tokens, image_tokens], and the values are -inf for text tokens before the <|image|> token and 0's afterwards.

Now, https://github.com/huggingface/transformers/blob/c1c7e89620b43f0924148e267a6da5d38450ce1f/src/transformers/models/mllama/modeling_mllama.py#L64 this line builds a mask that will be used for the mlp, with 0's for the tokens before the <|image|> token and 1's afterwards, and then we multiply the cross_attention_mask by this new full_text mask we built. So effectively the cross_attention_mask is always all 0's (actually -0 and 0, but that is of course the same).

I thought the cross_attention_mask *= full_text_row_masked_out_mask should be removed (https://github.com/huggingface/transformers/blob/c1c7e89620b43f0924148e267a6da5d38450ce1f/src/transformers/models/mllama/modeling_mllama.py#L67): as it is, when we get to the forward pass and reach the cross-attention calculation, the mask is never taken into account before the softmax. There is still some use for the attention mask coming from the full_text_row_masked_out_mask I mentioned before, because after the mlp we multiply by this vector, so we actually zero all hidden_states for the tokens before <|image|>; still, there is some effect coming from the hidden_states added in the residual before the mlp. 🤷‍♂️
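A hedged numeric walk-through of those two lines, re-implementing the described logic inline (not importing the actual helper) on a tiny 4-text-token x 2-image-token mask with the <|image|> token at position 2:

```python
import torch

dtype = torch.float32
min_val = torch.finfo(dtype).min

# Inverted cross-attention mask after the reshape/inversion steps:
# shape [batch=1, 1, text_tokens=4, image_tokens=2];
# rows 0-1 (text before <|image|>) are fully masked, rows 2-3 are not.
cross_attention_mask = torch.tensor(
    [[[[min_val, min_val],
       [min_val, min_val],
       [0.0, 0.0],
       [0.0, 0.0]]]],
    dtype=dtype,
)

# The step at #L64 described above: 0 for rows that are fully masked, 1 otherwise.
full_text_row_masked_out_mask = (
    (cross_attention_mask != min_val).any(dim=-1).type_as(cross_attention_mask)[..., None]
)
print(full_text_row_masked_out_mask[0, 0, :, 0])  # tensor([0., 0., 1., 1.])

# The step at #L67: the multiply turns the fully-masked rows into -0.0, so the final
# additive mask is all zeros and nothing is masked out before the softmax.
cross_attention_mask = cross_attention_mask * full_text_row_masked_out_mask
print(cross_attention_mask[0, 0])  # all 0 / -0
```

The second print matches what is described above: the additive mask itself ends up masking nothing, and only the post-mlp multiplication by full_text_row_masked_out_mask zeroes those rows.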
Who can help?

No response
Information

Tasks

An officially supported task in the examples folder (such as GLUE/SQuAD, ...)

Reproduction
Run generate with MllamaForConditionalGeneration using a prompt where the <|image|> token is not the first token; you will notice that the cross attention mask becomes all zeros after _prepare_cross_attention_mask is called.
Expected behavior
Partially apply the cross attention mask (mask only the text tokens before <|image|>) instead of ending up with an all-zero mask.