huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

sdpa for bert causes nan when using bfloat16 with padding. #31035

Open Leoyzen opened 5 months ago

Leoyzen commented 5 months ago

Who can help?

No response

Reproduction


In [1]: import torch.nn.functional as F

In [2]: import torch

In [3]: data = torch.load("reproduce_data.pt", map_location='cuda')

In [4]: with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
   ...:     print(F.scaled_dot_product_attention(data['q'], data['k'], data['v'], data['attn_mask'], data['dropout_p'], data['is_causal']).isnan().any((1, 2)))
   ...:
tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],               # <----- without padding
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],               # <----- with padding
        [ True,  True,  True,  ...,  True,  True,  True],               # <----- with padding
        [ True,  True,  True,  ...,  True,  True,  True]], device='cuda:0')

In [5]: with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=True):
   ...:     print(F.scaled_dot_product_attention(data['q'], data['k'], data['v'], data['attn_mask'], data['dropout_p'], data['is_causal']).isnan().any((1, 2)))
   ...:
tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]], device='cuda:0')

# it's fine when only the math kernel is enabled
In [6]: with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
   ...:     print(F.scaled_dot_product_attention(data['q'], data['k'], data['v'], data['attn_mask'], data['dropout_p'], data['is_causal']).isnan().any((1, 2)))
   ...:
/root/.local/share/conda/envs/bytednlp/lib/python3.10/site-packages/torch/backends/cuda/__init__.py:342: FutureWarning: torch.backends.cuda.sdp_kernel() is deprecated. In the future, this context manager will be removed. Please see, torch.nn.attention.sdpa_kernel() for the new context manager, with updated signature.
  warnings.warn(
tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [False, False, False,  ..., False, False, False],               # <----- with padding: correct
        [False, False, False,  ..., False, False, False],               # <----- with padding: correct
        [False, False, False,  ..., False, False, False]], device='cuda:0')

In [7]: mask = data['attn_mask']

# use finfo(dtype).min / 2 instead of finfo(dtype).min as the "minus infinity" fill value
In [10]: mask2 = mask.masked_fill(mask.bool(), torch.finfo(mask.dtype).min / 2)

In [11]: with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=True):
    ...:     print(F.scaled_dot_product_attention(data['q'], data['k'], data['v'], mask2, data['dropout_p'], data['is_causal']).isnan().any((1,2)))
    ...:
/root/.local/share/conda/envs/bytednlp/lib/python3.10/site-packages/torch/backends/cuda/__init__.py:342: FutureWarning: torch.backends.cuda.sdp_kernel() is deprecated. In the future, this context manager will be removed. Please see, torch.nn.attention.sdpa_kernel() for the new context manager, with updated signature.
  warnings.warn(
tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [False, False, False,  ..., False, False, False],               # <----- with padding: correct result
        [False, False, False,  ..., False, False, False],               # <----- with padding: correct result
        [False, False, False,  ..., False, False, False]], device='cuda:0')

Expected behavior

The output should not contain NaN when using bfloat16 with SDPA enabled.

I think it is safe to use torch.finfo(dtype).min / 2 instead of torch.finfo(dtype).min.
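A minimal sketch of the bfloat16 arithmetic behind this suggestion. It only illustrates the numbers; whether the mem-efficient kernel fails in exactly this way is an assumption. Adding finfo.min to another value of similar magnitude overflows to -inf, finfo.min / 2 leaves headroom for one such addition, and a softmax over a row that is entirely -inf returns NaN.

```python
import torch

# Most negative finite value representable in bfloat16 (a Python float)
bf16_min = torch.finfo(torch.bfloat16).min

full = torch.tensor(bf16_min, dtype=torch.bfloat16)
half = torch.tensor(bf16_min / 2, dtype=torch.bfloat16)

# finfo.min + finfo.min exceeds the bfloat16 range and overflows to -inf
overflowed = full + full
# finfo.min / 2 + finfo.min / 2 lands exactly on finfo.min and stays finite
with_headroom = half + half

# Softmax over a row whose entries are all -inf: exp(-inf - (-inf)) is NaN
all_masked_row = torch.full((4,), float("-inf"))
probs = torch.softmax(all_masked_row, dim=-1)

print(overflowed, with_headroom, probs)
```

The same headroom argument is why halving the fill value keeps the mask additive without ever producing -inf.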

amyeroberts commented 5 months ago

cc @ArthurZucker @fxmarty

ArthurZucker commented 4 months ago

Yep, I think this has always been there; it's related to the mask creation, which probably overflows. Do you want to open a PR for a fix? 🤗
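A minimal sketch of mask creation that follows the finfo.min / 2 suggestion from this issue. make_additive_mask is a hypothetical helper, not the transformers API; it follows the common invert-and-fill pattern for turning a [batch, src_len] padding mask (1 = keep, 0 = pad) into a 4D additive mask for F.scaled_dot_product_attention.

```python
import torch
import torch.nn.functional as F


def make_additive_mask(attention_mask: torch.Tensor, dtype: torch.dtype, tgt_len: int) -> torch.Tensor:
    """Hypothetical helper: [batch, src_len] padding mask -> [batch, 1, tgt_len, src_len]
    additive mask, with padded keys set to finfo(dtype).min / 2 instead of finfo(dtype).min."""
    batch, src_len = attention_mask.shape
    expanded = attention_mask[:, None, None, :].expand(batch, 1, tgt_len, src_len).to(dtype)
    inverted = 1.0 - expanded  # 1.0 at padded positions, 0.0 elsewhere
    return inverted.masked_fill(inverted.bool(), torch.finfo(dtype).min / 2)


# Usage: two sequences of length 4, the first with one pad token, the second with two
pad = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
mask = make_additive_mask(pad, torch.bfloat16, tgt_len=4)

q = k = v = torch.randn(2, 2, 4, 8, dtype=torch.bfloat16)  # [batch, heads, seq_len, head_dim]
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)  # mask broadcasts over heads
print(out.isnan().any())
```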

fxmarty commented 4 months ago

@Leoyzen can you share the repro tensors and/or reproduction with a transformers example?

Leoyzen commented 4 months ago

> @Leoyzen can you share the repro tensors and/or reproduction with a transformers example?

The reproduce_data.pt file, dumped from our private code repo, is quite large (torch.Size([31, 1, 2000, 2000]), almost 1 GB).

We use BERT from transformers with weights from stella-v2 (https://huggingface.co/infgrad/stella-large-zh-v2) to do some fine-tuning work.

Training BERT-large in bfloat16 should reproduce the bug.

reproduce.zip

ArthurZucker commented 2 months ago

Mmmm, if this is still a problem, we need to propagate the changes from #32227 to BERT and BERT SDPA!

ArthurZucker commented 2 months ago

Leaving it to the community unless I get time!