huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Gemma 2 returns NaN when using default attn (sdpa) with padding #32390

Closed: chanind closed this issue 1 month ago

chanind commented 3 months ago

System Info

Python 3.10, Transformers 4.43.3, Linux (Colab notebook)

Who can help?

@ArthurZucker

Reproduction

With the default attention implementation (sdpa), Gemma 2 2B returns NaN when the inputs are padded. A simple demo is below (also reproduced in this Colab notebook):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")

# Two prompts of different lengths, so the shorter one ("cats") gets padded
inputs = tokenizer(["Hello I am a couch", "cats"], return_tensors="pt", padding=True).to('cuda')
with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)

print(outputs.logits)

This prints the following:

tensor([[[-24.3121,  -8.7513,  -6.9736,  ..., -18.3960, -17.4268, -24.3171],
         [-16.8873,  -4.7767,   5.8828,  ...,  -9.4981,  -9.3307, -16.7723],
         [-18.3313,   1.3191,  -4.6598,  ...,  -2.4244,   1.6774, -18.2153],
         [-18.9110,  -5.8708, -11.7827,  ...,  -5.6606,  -4.2607, -18.8535],
         [-20.1359,  -8.4194, -15.1834,  ..., -13.0231, -11.8288, -19.9716],
         [-16.8807,   5.8885,   0.1881,  ...,  -3.7045,  -6.0659, -16.8421]],
        [[     nan,      nan,      nan,  ...,      nan,      nan,      nan],
         [     nan,      nan,      nan,  ...,      nan,      nan,      nan],
         [     nan,      nan,      nan,  ...,      nan,      nan,      nan],
         [     nan,      nan,      nan,  ...,      nan,      nan,      nan],
         [     nan,      nan,      nan,  ...,      nan,      nan,      nan],
         [     nan,      nan,      nan,  ...,      nan,      nan,      nan]]],
       device='cuda:0')
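Note that every position of the second sequence is NaN, not only the padding positions. A minimal pure-Python sketch of one way this can happen, assuming NaN first appears at the padding positions (the replies below attribute it to fully masked rows): even when a real token gives a padding key an attention weight of exactly 0, 0 * NaN is still NaN, so the NaN leaks into the real tokens in the next layer. The values are made up, and scalars stand in for value vectors.

```python
import math

nan = float("nan")

def attend(weights, values):
    """Weighted sum over positions; scalars stand in for value vectors."""
    return sum(w * v for w, v in zip(weights, values))

# Hypothetical toy layer: position 0 is padding whose hidden state is already NaN.
values = [nan, 0.5, -1.2]
# A real token's softmax weights: the padding key is masked out to exactly 0.0.
weights = [0.0, 0.3, 0.7]

out = attend(weights, values)
print(math.isnan(out))  # True: a zero weight does not cancel a NaN value
```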

This can be fixed by setting attn_implementation to anything other than "sdpa".

Expected behavior

Padding should not produce NaN for normal inputs to Gemma 2 2B.

qubvel commented 3 months ago

Hi @chanind, thanks for reporting the issue!

This is indeed a problem with scaled_dot_product_attention in PyTorch.

The NaN comes from how softmax is computed over fully masked rows of the attention mask. I hope this will be fixed in a future version of PyTorch; here is a related PR.
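To make "fully masked rows" concrete, here is a minimal sketch assuming left padding and the 6-position batch shown above (4 padding tokens before "cats" is a guess, not something the thread states). With a causal mask, each padding query may only look at earlier keys, which are all padding, so its whole row is masked. If those entries end up as -inf, the max-subtracted softmax computes -inf - (-inf) = NaN. Whether a particular SDPA kernel actually produces -inf internally is an assumption; the sketch only shows the arithmetic.

```python
import math

NEG_INF = float("-inf")

def causal_padding_mask(seq_len, num_left_pad):
    """Additive mask: 0.0 where query i may see key j, -inf otherwise (left padding assumed)."""
    return [
        [0.0 if num_left_pad <= j <= i else NEG_INF for j in range(seq_len)]
        for i in range(seq_len)
    ]

def softmax(row):
    """Softmax with the usual max subtraction for numerical stability."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

seq_len, num_left_pad = 6, 4  # hypothetical sizes for the padded "cats" sequence
mask = causal_padding_mask(seq_len, num_left_pad)
scores = [[0.1 * (i + j) for j in range(seq_len)] for i in range(seq_len)]  # arbitrary logits

probs = [softmax([s + m for s, m in zip(scores[i], mask[i])]) for i in range(seq_len)]
for i, row in enumerate(probs):
    print(i, row)  # rows 0-3 (padding queries) are all nan; rows 4-5 are valid
```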

Also, a similar issue has been reported previously.

Besides switching to eager or flash_attention_2, you could also try:

  1. Use the float16 dtype:

    model = AutoModelForCausalLM.from_pretrained(
        "google/gemma-2-2b", device_map="auto", torch_dtype=torch.float16
    )
  2. Modify the attn_mask minimum value.

As suggested in the issue above, we can build attn_mask with a different minimum value instead of torch.finfo(dtype).min, for example torch.finfo(dtype).min / 2. To apply this, find min_dtype = torch.finfo(dtype).min in the Gemma 2 modeling file and replace it with torch.finfo(dtype).min / 2.
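Why halving can help is easy to sketch with plain Python floats (float64, standing in for torch.finfo(dtype)). This assumes the NaN arises when the mask value overflows to -inf after being combined with another large negative term; the thread does not pin down exactly where that happens. torch.finfo(dtype).min is already the most negative finite value, so adding anything sufficiently negative to it overflows, while min / 2 leaves headroom.

```python
import math
import sys

MIN = -sys.float_info.max  # float64 analog of torch.finfo(dtype).min

full = MIN + MIN           # e.g. two fully masked contributions added together
half = MIN / 2 + MIN / 2   # the same, using min / 2 instead

print(full)                # -inf: the sum overflows
print(half == MIN)         # True: still finite

# Softmax subtracts the row max first; on an all-masked row that is x - x:
print(math.exp(full - full))  # nan, since -inf - (-inf) is undefined
print(math.exp(half - half))  # 1.0, so the row just becomes uniform
```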

Meanwhile, we will try to fix it on our side, thanks!

ArthurZucker commented 3 months ago

Beyond that, this is expected: the sdpa path does not support logit soft-capping (for Gemma2). @qubvel, we already account for the sdpa bug when creating the mask, see here: https://github.com/huggingface/transformers/blob/c1aa0edb48217f416f4bbe6e3a9db1500284513b/src/transformers/models/llama/modeling_llama.py#L1063-L1072

This should be propagated to Gemma2 (it was missing there for some reason, my bad).
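The linked Llama workaround lets fully masked rows attend to every position, so softmax never sees an all-masked row. Below is a minimal pure-Python sketch of that idea, assuming left padding; unmask_fully_masked_rows is a hypothetical helper name, not the library's own function, and a float64 value stands in for torch.finfo(dtype).min.

```python
import sys

MIN = -sys.float_info.max  # stand-in for torch.finfo(dtype).min

def causal_padding_mask(seq_len, num_left_pad):
    """Additive mask: 0.0 = may attend, MIN = masked. Assumes left padding."""
    return [
        [0.0 if num_left_pad <= j <= i else MIN for j in range(seq_len)]
        for i in range(seq_len)
    ]

def unmask_fully_masked_rows(mask, min_value):
    """Hypothetical helper: reset rows that mask every key to 0.0 (attend to all keys).

    In this left-padded setup only padding queries have such rows. Real tokens
    never attend to padding keys, so their outputs are unchanged; the change
    only keeps softmax away from an all-masked row.
    """
    return [
        [0.0] * len(row) if all(v == min_value for v in row) else list(row)
        for row in mask
    ]

mask = causal_padding_mask(seq_len=6, num_left_pad=4)
fixed = unmask_fully_masked_rows(mask, MIN)
for row in fixed:
    print("".join("." if v == 0.0 else "x" for v in row))  # "." = attend, "x" = masked
```

Only the first four (padding) rows change; the rows for the real tokens keep their original causal and padding structure.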

ArthurZucker commented 3 months ago

Related to #31303

qubvel commented 3 months ago

@ArthurZucker thanks for the updated info!

yaolu-zjut commented 3 months ago

Hi, I have run into a problem: when I fine-tune Gemma2-2b using transformers.Trainer, the lr is always 0 and grad_norm is nan [screenshot]. What's wrong? I use the same code to fine-tune llama3-8b and it works well. These are my settings: [screenshot]

EMZEDI commented 3 months ago

Same issue here when running code that hooks the model's activations. Using float16 made it work.

ArthurZucker commented 3 months ago

Hey! Make sure you are using eager or flash_attention_2, not sdpa!
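For context on the soft-capping point raised earlier: as we understand Gemma 2, attention scores are squashed with cap * tanh(score / cap) before the mask and softmax are applied. F.scaled_dot_product_attention takes Q, K, V and a mask and computes the softmax internally, so there is no step at which such a nonlinear transform can be inserted; an eager implementation computes the scores explicitly and can. A minimal pure-Python sketch of that eager order follows. The cap of 50.0 is an assumed value (we believe Gemma 2 configs expose it as attn_logit_softcapping), and the scores are made up.

```python
import math

def soft_cap(x, cap):
    """Smoothly squash x into the open interval (-cap, cap)."""
    return cap * math.tanh(x / cap)

def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def eager_attention_weights(scores, mask_row, cap):
    """Assumed eager order: cap the raw scores, add the mask, then softmax.

    The tanh step sits between the QK^T matmul and the softmax, which is
    exactly where a fused scaled_dot_product_attention call offers no hook.
    """
    capped = [soft_cap(s, cap) for s in scores]
    return softmax([c + m for c, m in zip(capped, mask_row)])

CAP = 50.0       # assumed attn_logit_softcapping value
MASKED = -1e30   # large negative additive mask entry
weights = eager_attention_weights([120.0, 3.0, -400.0], [0.0, 0.0, MASKED], CAP)
print(weights)   # the masked third key gets weight 0.0
```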

GeekerSsy commented 3 months ago

Hi, I have the same issue as @yaolu-zjut. How did you solve it? 😊

yaolu-zjut commented 2 months ago

Hi, I just use eager instead of sdpa, like this:

    model = AutoModelForCausalLM.from_pretrained(
        args.prune_model_path,
        trust_remote_code=True,
        device_map=device_map,
        attn_implementation="eager",
    )

github-actions[bot] commented 2 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.