huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Gemma 2 returns NaN when using default attn (sdpa) with padding #32390

Open chanind opened 1 month ago

chanind commented 1 month ago

System Info

Python 3.10, Transformers 4.43.3, Linux (Colab notebook)

Who can help?

@ArthurZucker


Reproduction

The default Gemma 2 2B attention implementation (sdpa) produces NaN logits when the batch contains padding tokens. A simple demo is below (also reproduced in this Colab notebook):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")

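# padding=True pads the shorter sequence ("cats") up to the length of the longer one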
inputs = tokenizer(["Hello I am a couch", "cats"], return_tensors="pt", padding=True).to('cuda')
with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)

print(outputs.logits)

This returns the following:

tensor([[[-24.3121,  -8.7513,  -6.9736,  ..., -18.3960, -17.4268, -24.3171],
         [-16.8873,  -4.7767,   5.8828,  ...,  -9.4981,  -9.3307, -16.7723],
         [-18.3313,   1.3191,  -4.6598,  ...,  -2.4244,   1.6774, -18.2153],
         [-18.9110,  -5.8708, -11.7827,  ...,  -5.6606,  -4.2607, -18.8535],
         [-20.1359,  -8.4194, -15.1834,  ..., -13.0231, -11.8288, -19.9716],
         [-16.8807,   5.8885,   0.1881,  ...,  -3.7045,  -6.0659, -16.8421]],
        [[     nan,      nan,      nan,  ...,      nan,      nan,      nan],
         [     nan,      nan,      nan,  ...,      nan,      nan,      nan],
         [     nan,      nan,      nan,  ...,      nan,      nan,      nan],
         [     nan,      nan,      nan,  ...,      nan,      nan,      nan],
         [     nan,      nan,      nan,  ...,      nan,      nan,      nan],
         [     nan,      nan,      nan,  ...,      nan,      nan,      nan]]],
       device='cuda:0')

This can be fixed by changing attn_implementation to anything other than sdpa.
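
For example, a minimal sketch of that workaround with the same checkpoint as above ("flash_attention_2" also works if flash-attn is installed):

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b",
    device_map="auto",
    attn_implementation="eager",  # anything except the default "sdpa"
)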

Expected behavior

Using padding should not result in NaN logits for normal inputs to Gemma 2 2B.

qubvel commented 1 month ago

Hi @chanind, thanks for reporting the issue!

This is indeed a problem with scaled_dot_product_attention in PyTorch.

The NaN values come from how softmax is computed over fully-masked rows of the attention mask. I hope it will be fixed in a future version of PyTorch; here is a related PR.
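
To illustrate (a minimal, Gemma-independent sketch): torch.nn.functional.scaled_dot_product_attention returns NaN for any query row whose mask is entirely False, which is exactly what a fully-padded row looks like:

import torch
import torch.nn.functional as F

# batch 1, 1 head, 2 query/key positions, head dim 4
q = torch.randn(1, 1, 2, 4)
k = torch.randn(1, 1, 2, 4)
v = torch.randn(1, 1, 2, 4)

# the second query row attends to nothing, like a position made entirely of padding
mask = torch.tensor([[[[True, True],
                       [False, False]]]])

out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(out[0, 0, 1])  # all NaN: softmax over a row of -inf is NaN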

Also, a similar issue has been reported previously.

Besides switching to eager/flash_attention_2, you could also try one of the following:

  1. Use the float16 dtype:

    model = AutoModelForCausalLM.from_pretrained(
        "google/gemma-2-2b", device_map="auto", torch_dtype=torch.float16
    )
  2. Modify the attn_mask min value.

As suggested in the issue above, we can modify attn_mask to use a different min value instead of torch.finfo(dtype).min, for example torch.finfo(dtype).min / 2. To apply this, find min_dtype = torch.finfo(dtype).min in the Gemma2 modeling file and replace it with torch.finfo(dtype).min / 2, as sketched below.
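
A rough sketch of that edit (illustrative only; in transformers 4.43 the relevant line lives in src/transformers/models/gemma2/modeling_gemma2.py):

# before:
# min_dtype = torch.finfo(dtype).min
# after: halving leaves fully-masked rows finite, so sdpa's softmax no longer produces NaN
min_dtype = torch.finfo(dtype).min / 2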

Meanwhile, we will try to fix it on our side, thanks!

ArthurZucker commented 1 month ago

More than this, it's expected, as the sdpa path does not support logit soft-capping (for Gemma2). We do already take the sdpa bug into account when creating the mask, @qubvel, see here: https://github.com/huggingface/transformers/blob/c1aa0edb48217f416f4bbe6e3a9db1500284513b/src/transformers/models/llama/modeling_llama.py#L1063-L1072

This should be propagated to Gemma2 (it was not there for some reason, my bad).
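
For reference, a paraphrase of the linked llama block (not a verbatim copy; names as in the 4.43 code) that re-enables attention on fully masked rows before sdpa runs, which is what would need to be mirrored in Gemma2's mask creation:

if (
    self.config._attn_implementation == "sdpa"
    and attention_mask is not None
    and attention_mask.device.type == "cuda"
    and not output_attentions
):
    # Attend to all tokens in fully masked rows (e.g. the first rows under left padding),
    # working around the PyTorch sdpa NaN issue for fully masked rows.
    causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)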

ArthurZucker commented 1 month ago

Related to #31303

qubvel commented 1 month ago

@ArthurZucker thanks for the updated info!

yaolu-zjut commented 1 month ago

Hi, I have run into a problem: when I fine-tune Gemma2-2b using transformers.Trainer, the lr is always 0 and grad_norm is nan (screenshot attached). What's wrong? I use the same code to fine-tune Llama3-8b and it works well. These are my settings (screenshot attached).

EMZEDI commented 1 month ago

Same issue here when running code that hooks the model's activations. Using float16 made it work.

ArthurZucker commented 3 weeks ago

Hey! Make sure you are using eager or flash_attention_2, not sdpa!

GeekerSsy commented 3 weeks ago

> Hi, I have run into a problem: when I fine-tune Gemma2-2b using transformers.Trainer, the lr is always 0 and grad_norm is nan (screenshot attached). What's wrong? I use the same code to fine-tune Llama3-8b and it works well. These are my settings (screenshot attached).

Hi, I have the same issue. How did you solve it? 😊

yaolu-zjut commented 2 weeks ago

> Hi, I have run into a problem: when I fine-tune Gemma2-2b using transformers.Trainer, the lr is always 0 and grad_norm is nan (screenshot attached). What's wrong? I use the same code to fine-tune Llama3-8b and it works well. These are my settings (screenshot attached).
>
> Hi, I have the same issue. How did you solve it? 😊

Hi, I just use eager instead of sdpa, like this:

model = AutoModelForCausalLM.from_pretrained(
    args.prune_model_path,
    trust_remote_code=True,
    device_map=device_map,
    attn_implementation="eager",
)