huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

RecurrentGemma Doesn't Support left padding? #31201

Closed godjw closed 2 months ago

godjw commented 4 months ago

System Info

Who can help?

@ArthurZucker @younesbelkada

Information

Tasks

Reproduction

I am currently trying to finetune the RecurrentGemmaModel.

However, when I trained the model with left padding, the training outputs were quite strange, so I tried to debug it.

I checked this part about the causal mask. https://github.com/huggingface/transformers/blob/96eb06286b63c9c93334d507e632c175d6ba8b28/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py#L753-L779

When I pass a left-padded dummy attention mask like the one below, the resulting causal mask looks strange.

import torch

# Left-padded dummy attention mask (the first sequence has two padding tokens on the left)
dummy_attention_mask = torch.tensor([[0, 0, 1, 1],
                                     [1, 1, 1, 1]], device="cuda:0")

# model is the loaded RecurrentGemma model, input_ids the tokenized batch from my setup
inputs_embeds = model.model.embed_tokens(input_ids.to("cuda"))
hidden_states = inputs_embeds
cache_position = torch.arange(hidden_states.shape[1], device="cuda")

mask_output = model.model._update_causal_mask(dummy_attention_mask, inputs_embeds, cache_position)
print(mask_output)
tensor([[[[-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],
          [-3.4028e+38, -3.4028e+38, -0.0000e+00, -3.4028e+38],
          [-3.4028e+38, -3.4028e+38, -0.0000e+00, -0.0000e+00]]],

        [[[-0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38],
          [-0.0000e+00, -0.0000e+00, -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00, -0.0000e+00, -0.0000e+00, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00]]]],
       device='cuda:0')

However, passing a right-padded attention mask produces the correct causal mask.

dummy_attention_mask = torch.tensor([[1, 1, 0, 0],
                                     [1, 1, 1, 1]], device="cuda:0")

# output
tensor([[[[-0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38],
          [-0.0000e+00, -0.0000e+00, -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00, -0.0000e+00, -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38]]],

        [[[-0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38],
          [-0.0000e+00, -0.0000e+00, -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00, -0.0000e+00, -0.0000e+00, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00]]]],
       device='cuda:0')

Expected behavior

I think the expected first mask (the left-padded one) should look like the one below, because the padded positions should not be attended to.

# Expected Mask
[[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
   [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
   [-3.4028e+38, -3.4028e+38, -0.0000e+00, -3.4028e+38],
   [-3.4028e+38, -3.4028e+38, -0.0000e+00, -0.0000e+00]]],
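
For reference, here is a minimal sketch (not the actual transformers implementation; build_causal_mask is a made-up helper) of how a 2-D padding mask is usually folded into the 4-D additive causal mask, which reproduces the expected mask above:

import torch

def build_causal_mask(attention_mask: torch.Tensor, dtype=torch.float32):
    """Combine a [batch, seq] padding mask with a causal mask into a
    [batch, 1, seq, seq] additive mask (0 = attend, min_dtype = masked)."""
    batch, seq_len = attention_mask.shape
    min_dtype = torch.finfo(dtype).min
    # Causal part: position i may only attend to positions <= i
    causal = torch.triu(torch.full((seq_len, seq_len), min_dtype, dtype=dtype), diagonal=1)
    mask = causal[None, None, :, :].expand(batch, 1, seq_len, seq_len).clone()
    # Padding part: columns of padding tokens are masked for every query position
    mask = mask.masked_fill(attention_mask[:, None, None, :] == 0, min_dtype)
    return mask

left_padded = torch.tensor([[0, 0, 1, 1], [1, 1, 1, 1]])
print(build_causal_mask(left_padded))  # first two rows of batch 0 come out fully masked

Note that the first two query rows of the left-padded sequence end up fully masked, which is exactly the situation the sdpa discussion below is about.
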
molbap commented 4 months ago

Hi @godjw, and thanks for the issue. Indeed, it seems (at least on my local setup) that left padding has issues in float16 precision; does that match your finetuning setup? The issue seems to have been present since v4.40, where we introduced RecurrentGemma, despite a modeling test that checks left-padded generation. ccing @ArthurZucker for visibility, and I will take a look tomorrow!

godjw commented 4 months ago

@molbap Thanks for the reply! I am using float32 precision for my finetuning setup, but I don't think the precision makes much difference for the padding mask.

Currently, finetuning only works with right padding.
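
For reference, a minimal sketch of forcing right padding with a standard 🤗 tokenizer (tokenizer and texts stand in for my actual setup):

# Assumes a standard Hugging Face tokenizer; texts is a list of training strings
tokenizer.padding_side = "right"
batch = tokenizer(texts, padding=True, return_tensors="pt")
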

ArthurZucker commented 4 months ago

This seems expected in terms of mask creation.

Now, the causal mask is not going to be an issue TBH, but the way the model works might be. Since there is a convolution layer in the RNN part, we might need to make it ignore padding. One more thing is which padding token is used: if the embedding was resized for it, the new embedding value needs to be the average of all of the model's embeddings.
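
To illustrate both points, a minimal sketch, not the actual RecurrentGemma code (temporal_conv, hidden_states and the shapes are made up), of zeroing padded positions before a causal temporal convolution, plus mean-initializing a newly added padding token embedding:

import torch
import torch.nn as nn

# 1) Make a temporal convolution ignore padding: zero out padded positions
#    so they cannot leak into the conv receptive field of real tokens.
batch, seq_len, hidden = 2, 4, 8
hidden_states = torch.randn(batch, seq_len, hidden)
attention_mask = torch.tensor([[0, 0, 1, 1], [1, 1, 1, 1]])  # left padding in row 0
hidden_states = hidden_states * attention_mask[..., None]

temporal_conv = nn.Conv1d(hidden, hidden, kernel_size=4, padding=3, groups=hidden)
conv_out = temporal_conv(hidden_states.transpose(1, 2))[..., :seq_len].transpose(1, 2)

# 2) Mean-initialize the embedding row of a newly added padding token
#    (assumes model / tokenizer from the finetuning setup):
# model.resize_token_embeddings(len(tokenizer))
# embeddings = model.get_input_embeddings().weight.data
# embeddings[-1] = embeddings[:-1].mean(dim=0)  # new pad row = average of existing rows
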

godjw commented 4 months ago

@ArthurZucker I'm sorry, but I'm having some difficulty understanding a few points. Could you please clarify why the first two rows becoming all -3.4028e+38 would not work with sdpa, and why switching to full attention wouldn't affect the result? It seems quite unusual to me that the padding tokens would be attended to.

I was able to train and run inference with the model by passing no attention mask and just using the default causal mask, which ignores the padding tokens. With the left-padded attention mask, however, only training was possible, and the inference results were really bad.

Additionally, I wanted to mention that I am using sdpa because it is the default and only supported attention implementation in RecurrentGemmaModel.

Thank you very much for your help!
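
For context, a minimal standalone illustration of why fully masked rows are problematic for F.scaled_dot_product_attention (using -inf here to make the failure obvious on the math backend; the memory-efficient CUDA path has a related problem even with finite fill values, see https://github.com/pytorch/pytorch/issues/110213):

import torch
import torch.nn.functional as F

q = torch.randn(1, 1, 4, 8)
k = torch.randn(1, 1, 4, 8)
v = torch.randn(1, 1, 4, 8)

# Additive mask whose first query row is fully masked (nothing it may attend to),
# which is what happens for left-padded positions in the expected mask above.
mask = torch.zeros(1, 1, 4, 4)
mask[..., 0, :] = float("-inf")

out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(out[0, 0, 0])  # NaNs: softmax over an all -inf row is 0/0
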

ArthurZucker commented 4 months ago

This piece of code in modeling_llama.py should help you:

# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

🤗
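
Applied to the left-padded example from the reproduction above, a sketch of what that call does (the import path matches transformers' modeling_attn_mask_utils; the 4x4 mask below is the expected left-padded mask, rebuilt by hand):

import torch
from transformers.modeling_attn_mask_utils import AttentionMaskConverter

min_dtype = torch.finfo(torch.float32).min

# Additive mask for the left-padded sequence [0, 0, 1, 1]:
# query rows 0 and 1 are fully masked, which the memory-efficient SDPA kernel cannot handle.
causal_mask = torch.full((1, 1, 4, 4), min_dtype)
causal_mask[0, 0, 2, 2] = 0.0
causal_mask[0, 0, 3, 2:] = 0.0

# Rows that are entirely min_dtype get reset to fully unmasked; their outputs are
# meaningless but correspond to padding, so they are ignored downstream anyway.
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
print(causal_mask)
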

godjw commented 4 months ago

Thank you very much! Then the problem could be that the padding tokens are not handled properly, as you mentioned.

github-actions[bot] commented 3 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.