huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Difference in LlamaAttention & LlamaFlashAttention2 attn_output #27050

Open ringohoffman opened 10 months ago

ringohoffman commented 10 months ago

System Info

Who can help?

@ArthurZucker and @younesbelkada

Information

Tasks

Reproduction

We noticed that LlamaFlashAttention2._flash_attention_forward returns a different attn_output from the one LlamaAttention computes.

flash_attn_non_determinism.py:

import argparse

import torch
import torch.backends.cudnn
import transformers
from transformers.models import llama

def main() -> None:
    torch.backends.cudnn.deterministic = True

    parser = argparse.ArgumentParser()
    parser.add_argument("--use-flash-attention-2", action="store_true")
    args = parser.parse_args()
    use_flash_attention_2 = args.use_flash_attention_2

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        "/models/huggingface/meta-llama/llama-2-7b-chat-hf", local_files_only=True, use_safetensors=True, device_map=torch.device("cuda")
    )
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    text = "Hello world!"
    tokenized_text = tokenizer(text)
    tokenized_text = {key: torch.tensor(value).unsqueeze(dim=0).to(torch.device("cuda")) for key, value in tokenized_text.items()}
    tokenized_text["labels"] = tokenized_text["input_ids"].clone()

    torch.manual_seed(0)
    model = llama.LlamaForCausalLM.from_pretrained(
        "/models/huggingface/meta-llama/llama-2-7b-chat-hf",
        local_files_only=True,
        use_safetensors=True,
        device_map=torch.device("cuda"),
        use_flash_attention_2=use_flash_attention_2,
        torch_dtype=torch.bfloat16,
    )
    assert isinstance(model, llama.LlamaForCausalLM)
    model.eval()
    for param in model.parameters():
        param.requires_grad = False

    model.model.layers[0].train()
    for param in model.model.layers[0].parameters():
        param.requires_grad = True

    optim = torch.optim.AdamW(model.parameters())

    torch.manual_seed(0)

    for i in range(10):
        output = model(**tokenized_text)
        loss = output["loss"]
        if i in (0, 9):
            print(loss)
        loss.backward()
        optim.step()
        optim.zero_grad()

if __name__ == "__main__":
    main()
$ python flash_attn_non_determinism.py --use-flash-attention-2
tensor(5.6612, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.3542, device='cuda:0', grad_fn=<NllLossBackward0>)
$ python flash_attn_non_determinism.py
tensor(5.6589, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.2275, device='cuda:0', grad_fn=<NllLossBackward0>)

Expected behavior

I did not expect a difference of this magnitude between the 2 implementations. A difference of 0.1267 in the final loss (0.2275 vs. 0.3542) seems very large.

ArthurZucker commented 10 months ago

Hey, I think this is related to the flash attention version; could you have a look at #26697?

KyleMylonakisProtopia commented 10 months ago

We are currently using flash-attn==2.3.2. There was a minor version release of flash attention literally yesterday.

The problem persists with flash-attn==2.3.3.

Are you able to reproduce on your end with the supplied script?

ArthurZucker commented 10 months ago

cc @younesbelkada if you can have a look 😉

younesbelkada commented 10 months ago

Hi @KyleMylonakisProtopia! I think that difference is expected. I am not sure flash-attn guarantees full reproducibility for gradient computation; note also that some slight differences in logits are expected between FA-2 and non-FA-2 models.

KyleMylonakisProtopia commented 10 months ago

The code demonstrates non-trivial differences in the loss prior to even the first backwards call. Flash attention and flash attention 2 are supposed to be exact algorithms for computing attention.

From the Flash attention 2 paper "To speed up attention on hardware accelerators such as GPU, [5] proposes an algorithm to reduce the memory reads/writes while maintaining the same output (without approximation)." That seems pretty unambiguous to me.

The slight differences arising from differences in parallelization should not show up in the third significant digit of the very first loss. This points to some other kind of issue.
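To make the "exact" claim concrete: FlashAttention computes the same softmax as vanilla attention, but it visits the keys block by block while keeping a running max and a running softmax denominator (the "online softmax" rescaling), so the two agree up to floating-point rounding rather than bit for bit. Below is a minimal pure-Python sketch of that idea for a single query; it is not the actual CUDA kernel, and the function names are hypothetical.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def naive_attention_row(q, keys, values):
    """Standard softmax attention for one query: softmax over all scores, then weighted sum of values."""
    scores = [dot(q, k) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(len(values[0]))]

def online_softmax_attention_row(q, keys, values, block_size=2):
    """Same quantity, visiting keys block by block and keeping only a running max,
    a running softmax denominator and an unnormalized output accumulator."""
    running_max = -math.inf
    running_sum = 0.0
    acc = [0.0] * len(values[0])
    for start in range(0, len(keys), block_size):
        block_scores = [dot(q, k) for k in keys[start:start + block_size]]
        new_max = max(running_max, max(block_scores))
        # Rescale everything accumulated so far from the old max to the new max.
        scale = math.exp(running_max - new_max)  # exp(-inf) == 0.0 on the first block
        running_sum *= scale
        acc = [a * scale for a in acc]
        for s, v in zip(block_scores, values[start:start + block_size]):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * vd for a, vd in zip(acc, v)]
        running_max = new_max
    return [a / running_sum for a in acc]

q = [0.5, -1.0, 2.0]
keys = [[1.0, 0.0, 0.5], [0.2, 0.3, -1.0], [-0.7, 1.5, 0.1], [2.0, -0.4, 0.9], [0.0, 0.0, 1.0]]
values = [[1.0, 2.0], [0.0, -1.0], [3.0, 0.5], [-2.0, 1.0], [0.5, 0.5]]

reference = naive_attention_row(q, keys, values)
for block_size in (1, 2, 3, 5):
    tiled = online_softmax_attention_row(q, keys, values, block_size)
    # The rescaling is algebraically exact, so only float64 rounding separates the two.
    print(block_size, max(abs(a - b) for a, b in zip(reference, tiled)))
```

In float64 the gap is at rounding level; with the bf16/fp16 tensors used in this thread, the same reordering leaves a larger rounding footprint.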

younesbelkada commented 10 months ago

Flash attention and flash attention 2 are supposed to be exact algorithms for computing attention.

Yes, but in the script above you are comparing vanilla attention vs. FA-2, no?

KyleMylonakisProtopia commented 10 months ago

That sentence compares Flash attention (and, implicitly, flash attention 2) to "vanilla" attention. That is what our script is showing.

younesbelkada commented 10 months ago

ah correct yes you are right, sorry for the confusion, I'll have a deeper look !

cckao commented 10 months ago

I also encountered the same problem at inference. Environment: transformers==4.34.0, flash-attn==2.3.3, torch==2.0.1+cu117.

import numpy as np
import torch
from transformers import CodeLlamaTokenizer, LlamaForCausalLM, TextStreamer

use_flash_attention_2 = False  # toggled between the two runs shown below

seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
prompt = """<s>[INST]Tell me the story about a dog.[/INST]"""
d_model = "/path/to/CodeLlama-13b-Instruct-hf"
tokenizer = CodeLlamaTokenizer.from_pretrained(d_model)
model = LlamaForCausalLM.from_pretrained(
    d_model, device_map="auto", torch_dtype=torch.bfloat16, use_flash_attention_2=use_flash_attention_2
)
tokenized = tokenizer(prompt, return_tensors="pt", truncation=False).to("cuda")
generated_ids = model.generate(**tokenized, max_new_tokens=1024, do_sample=True, streamer=TextStreamer(tokenizer, skip_prompt=True))

use-flash-attention-2=False:

Once upon a time, there was a dog named Max. Max was a lovable golden retriever who loved nothing more than to go for walks with his owner, Sarah. One day, while they were out on a walk,

use-flash-attention-2=True:

Once upon a time, there was a dog named Max. Max was a lovable golden retriever who loved nothing more than to go for walks with his owner, Sarah. One day, while they were out on their usual stroll,

wizyoung commented 10 months ago

Here is my minimal reproducible script:

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaModel, _make_causal_mask

device = torch.device("cuda")
dtype = torch.float16

config_ori = LlamaConfig(
    hidden_size=1024,
    intermediate_size=128,
    num_hidden_layers=1,
    num_attention_heads=8,
    max_position_embeddings=16,
    _flash_attn_2_enabled=False
)

config_new = LlamaConfig(
    hidden_size=1024,
    intermediate_size=128,
    num_hidden_layers=1,
    num_attention_heads=8,
    max_position_embeddings=16,
    _flash_attn_2_enabled=True
)

model_ori = LlamaModel(config_ori)
model_new = LlamaModel(config_new)

model_new.load_state_dict(model_ori.state_dict())

model_ori.to(dtype).to(device)
model_new.to(dtype).to(device)

attn_ori = model_ori.layers[0].self_attn
attn_new = model_new.layers[0].self_attn

bsz, hs, seqlen = 2, config_ori.hidden_size, 4
inputs_embeds = torch.randn((bsz, seqlen, hs), dtype=dtype, device=device)

padding_mask = torch.full((bsz, seqlen), 1, dtype=torch.long, device=device)
# or pad part of the first sequence:
# padding_mask[0, 2:] = 0

out_ori = model_ori(attention_mask=padding_mask, inputs_embeds=inputs_embeds, use_cache=False)['last_hidden_state']
out_new = model_new(attention_mask=padding_mask, inputs_embeds=inputs_embeds, use_cache=False)['last_hidden_state']

print(out_ori.sum(), out_new.sum(), (out_ori - out_new).mean().item(), (out_ori - out_new).abs().max().item(), (out_ori - out_new).abs().mean().item())

I noticed that the numerical difference mainly comes from the padding_mask. If the padding_mask is None, only the causal mask is used and the difference is small. However, if we set the padding_mask, the difference can no longer be ignored. If we run pytest from the official flash-attn repo, diff.abs().max().item() is always small.

The difference comes from the attention module. Here is a more fine-grained snippet:

bsz, hs, seqlen = 2, config_ori.hidden_size, 4
hidden = torch.rand((bsz, seqlen, hs), dtype=dtype, device=device)

padding_mask = torch.full((bsz, seqlen), 1, dtype=torch.long, device=device)
# padding_mask[0, 2:] = 0

past_key_values_length = 0
key_value_length = seqlen + past_key_values_length

position_ids = torch.arange(past_key_values_length, key_value_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0)

if padding_mask is not None:
    attention_mask_ori = model_ori.attn_mask_converter.to_4d(
        padding_mask, seqlen, key_value_length, dtype=hidden.dtype
    )
else:
    attention_mask_ori = model_ori.attn_mask_converter.to_causal_4d(
        bsz, seqlen, key_value_length, dtype=hidden.dtype, device=hidden.device
    )

out_ori, _, _ = attn_ori.forward(
    hidden, attention_mask=attention_mask_ori, position_ids=position_ids
)

out_new, _, _ = attn_new.forward(
    hidden, attention_mask=padding_mask, position_ids=position_ids
)

print(out_ori.sum(), out_new.sum(), (out_ori - out_new).mean().item(), (out_ori - out_new).abs().max().item(), (out_ori - out_new).abs().mean().item())

UPDATE: It seems the diff lies in the padded part in the final attn weights? So maybe this should not affect the final training loss and the inference results?

my env:

hope this helps!

KyleMylonakisProtopia commented 10 months ago

Thanks for the deep dive @wizyoung! This thread already shows differences in the loss and the inference results, so something is afoot.

ArthurZucker commented 10 months ago

cc @younesbelkada If I remember correctly, when we debugged the flash attention tests we found out that the attention mask was not properly taken into account, and the attention weights for pad tokens were non-zero in vanilla attention and zero in flash attention. This came from the way we create our attention mask, which adds two inf values, creating overflows. We should be able to fix this easily! cc @patrickvonplaten as we talked about this
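Both sides of the overflow argument can be seen in a minimal sketch using Python floats (float64), where `-sys.float_info.max` plays the role of `torch.finfo(dtype).min`; this is an illustration, not the transformers mask code. Adding two "minimum" values overflows to -inf. A row in which every key is -inf turns into NaN under softmax, while a row with at least one unmasked key gets the same weights whether the masked entries are -inf or the finite minimum.

```python
import math
import sys

MASK_MIN = -sys.float_info.max  # stand-in for torch.finfo(dtype).min

def softmax(row):
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

# Combining two masks by addition: min + min overflows to -inf.
combined = MASK_MIN + MASK_MIN
print(combined)  # -inf

# Fully masked row (e.g. a padded query under left padding):
print(softmax([combined] * 3))  # NaNs, since (-inf) - (-inf) is NaN
print(softmax([MASK_MIN] * 3))  # finite minimum instead: uniform weights of 1/3

# Row with at least one visible key: -inf and the finite minimum behave the same.
print(softmax([combined, 0.0, 0.0]))  # [0.0, 0.5, 0.5]
print(softmax([MASK_MIN, 0.0, 0.0]))  # [0.0, 0.5, 0.5]
```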

wizyoung commented 10 months ago

I think this may not be the actual cause, as two inf values will not cause much numerical difference after softmax. After applying your fix above, the output for the padded part still differs. The results indicate that the padding mask does not take effect when computing the attention weights.

The problem likely comes from pad_input after the flash attention results are computed.

Update: I ran a quick test on my work projects. In the baseline scenario, I trained and tested everything without using flash attention. For Experiment 1 (Exp1), I trained and tested while using flash attention. The evaluation process involved periodically switching to the test dataset, enabling use_cache=True, and performing batch inference. I noticed that the evaluation metrics in Exp1 were around 20% lower compared to the baseline. However, when I loaded the checkpoint from Exp1 without flash attention, the results were nearly identical to the baseline. This outcome matches my expectations because the discrepancies are mainly caused by padding, which is disregarded during the loss backward process and does not affect convergence. Nevertheless, I'm puzzled about why this would impact inference, as I believe that once the EOS token is predicted in the generation process, the process should be finished.

younesbelkada commented 10 months ago

Thanks a lot @wizyoung for the deep dive!

@ArthurZucker, indeed we noticed some discrepancies with respect to padding tokens, and I think at that time our conclusion was that

UPDATE: It seems the diff lies in the padded part in the final attn weights? So maybe this should not affect the final training loss and the inference results?

as stated by @wizyoung

younesbelkada commented 10 months ago

The difference clearly resides on the padding tokens.

With FA-2:

(Pdb) self.o_proj(attn_output)
tensor([[[ 0.6187, -0.9595, -0.2783,  ...,  0.1057, -0.5645, -0.3220],
         [ 0.4392, -0.5137, -0.5078,  ...,  0.0863, -0.3232,  0.1931],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

        [[ 0.1334,  0.1556, -0.5737,  ..., -0.1802,  0.2262, -0.6035],
         [-0.2883, -0.1821, -0.5303,  ...,  0.2157,  0.0258, -0.0304],
         [-0.4187, -0.1300, -0.2747,  ...,  0.3828,  0.0053, -0.3252],
         [-0.1055,  0.0997, -0.1527,  ...,  0.3984, -0.1208, -0.1553]]],
       device='cuda:0', dtype=torch.float16, grad_fn=<UnsafeViewBackward0>)

Without FA-2:

tensor([[[ 0.6187, -0.9595, -0.2783,  ...,  0.1057, -0.5645, -0.3220],
         [ 0.4392, -0.5137, -0.5078,  ...,  0.0862, -0.3232,  0.1930],
         [ 0.4172, -0.4719, -0.4473,  ..., -0.1212, -0.3323,  0.0089],
         [ 0.5713, -0.4893, -0.4084,  ..., -0.0648, -0.3967, -0.0724]],

        [[ 0.1334,  0.1556, -0.5737,  ..., -0.1802,  0.2262, -0.6035],
         [-0.2883, -0.1821, -0.5303,  ...,  0.2156,  0.0258, -0.0306],
         [-0.4187, -0.1299, -0.2747,  ...,  0.3828,  0.0053, -0.3252],
         [-0.1055,  0.0997, -0.1527,  ...,  0.3987, -0.1210, -0.1554]]],
       device='cuda:0', dtype=torch.float16, grad_fn=<UnsafeViewBackward0>)

As you can see, the hidden states that correspond to the padded positions of the attention mask

(Pdb) attention_mask
tensor([[1, 1, 0, 0],
        [1, 1, 1, 1]], device='cuda:0')

are zeroed out for FA-2, whereas they are not for the non-FA-2 model. I also tried #27114. Will investigate more.

younesbelkada commented 10 months ago

Hi everyone we had a deeper look with @ArthurZucker and here are our findings:

1- #27114 fixes another issue we have with all attention modules in transformers when combining attention masks together, leading sometimes to have undesired inf values inside these masks.

2- To resolve the issue shown in @wizyoung's snippet, adding the following inside the attention module:

attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

+ if attention_mask is not None:
+    sliced_attention_mask = attention_mask[:, 0, -1, :]
+    attention_mask_2d = (1.0 * ~sliced_attention_mask.bool()).to(attn_output.dtype)
+    attn_output = attn_output * attention_mask_2d.unsqueeze(-1)

if self.config.pretraining_tp > 1:
    attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
    o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
    attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
else:
    attn_output = self.o_proj(attn_output)

if not output_attentions:
    attn_weights = None

return attn_output, attn_weights, past_key_value

Fixes the issue, as this snippet correctly zeroes out all the hidden states that correspond to padding tokens. I am not sure this has any impact on generation. Also, given that the slicing + cast operations can add considerable overhead in the attention module (as they have to be done for every layer), I am not sure we should upstream these changes into transformers core.
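The masking trick in the diff above can be sketched without torch. This is a minimal pure-Python sketch assuming a 4D additive mask laid out as [batch][1][q_len][kv_len], with 0.0 for visible keys and a large negative value for masked ones; `build_additive_mask`, `keep_mask_from_4d` and `zero_padded_positions` are hypothetical helpers. Slicing the last query row works because, under a causal mask, the last query may look at every key position, so its row is 0.0 exactly at the non-padded tokens.

```python
import sys

MASK_MIN = -sys.float_info.max  # stand-in for torch.finfo(dtype).min

def build_additive_mask(padding_mask):
    """Hypothetical helper: causal + padding additive mask, shape [batch][1][q_len][kv_len]."""
    masks = []
    for row in padding_mask:
        n = len(row)
        mat = [[0.0 if (k <= q and row[k] == 1) else MASK_MIN for k in range(n)] for q in range(n)]
        masks.append([mat])
    return masks

def keep_mask_from_4d(mask_4d):
    """Analogue of attention_mask[:, 0, -1, :] followed by 1.0 * ~mask.bool():
    entries equal to 0.0 (visible keys) become 1.0, masked entries become 0.0."""
    return [[1.0 if v == 0.0 else 0.0 for v in batch[0][-1]] for batch in mask_4d]

def zero_padded_positions(hidden, keep):
    """Multiply each position's hidden vector by its 0/1 keep flag."""
    return [[[x * flag for x in vec] for vec, flag in zip(seq, flags)] for seq, flags in zip(hidden, keep)]

padding_mask = [[1, 1, 0, 0], [1, 1, 1, 1]]  # same pattern as the (Pdb) attention_mask printout
hidden = [[[0.5, -1.0], [2.0, 0.25], [0.4, 0.4], [0.6, -0.5]],
          [[1.0, 1.0], [-0.3, 0.2], [0.7, 0.1], [0.0, 2.0]]]

keep = keep_mask_from_4d(build_additive_mask(padding_mask))
print(keep)  # [[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0]]
print(zero_padded_positions(hidden, keep)[0])
```

With left padding (e.g. [[0, 0, 1, 1]]) the last query is a real token, and the same slice again marks exactly the real positions.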

However, the issue mentioned by @KyleMylonakisProtopia still persists (I am able to reproduce it even with the fix) and needs further investigation.

KyleMylonakisProtopia commented 10 months ago

Thanks for the continued look!

patrickvonplaten commented 10 months ago

I think the reason for the discrepancy between FA-2 and non-FA-2 here comes solely from the fact that we're comparing padded output tensors and/or included padded hidden states vectors in our results. Padded hidden states vectors are useless/moot and should never influence a loss or be compared.

Let's explain a bit:

  1. Padded hidden states vectors are vectors that correspond to a sequence index i that is not attended to meaning attention_mask[i] is 0. This corresponds to the outer-most left tokens here: https://github.com/huggingface/transformers/issues/27050#issue-1960195010 since we use left-padding or all tokens after 2: here: https://github.com/huggingface/transformers/issues/27050#issuecomment-1782529853 when doing right padding.

  2. One should never take padded hidden states vectors into account! One should never compare padded hidden states vectors to each other, because they are moot / useless and should never be used. This means that when comparing the loss here: https://github.com/huggingface/transformers/issues/27050#issue-1960195010, one has to set the labels at the indices that correspond to padding tokens to -100 to make sure they don't influence the loss. See this issue as well. Similarly, it doesn't make much sense to do this fix here: https://github.com/huggingface/transformers/issues/27050#issuecomment-1795089451 and to compare the outputs of padded tokens, because they are useless anyway.

  3. What is going on here?!

Let's look at a tiny code example that explains the behavior of non-FA2 code.

import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

attention_mask = torch.tensor([[0, 1, 1]])  # left padding

# the third argument only supplies the dtype/device of the mask (float32 here)
print(_prepare_4d_causal_attention_mask(attention_mask, (1, 3), attention_mask.float(), 0))

We get

tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38],
          [-3.4028e+38,  0.0000e+00, -3.4028e+38],
          [-3.4028e+38,  0.0000e+00,  0.0000e+00]]]])

as expected. We see the causal mask, and in addition the first column (the padded key position) has large negative values.

Now let's run a softmax on this mask, which corresponds to the attention weights Softmax(QK^T + mask) when QK^T is all ones (adding the same constant to every score does not change the softmax).

print(_prepare_4d_causal_attention_mask(attention_mask, (1, 3), attention_mask.float(), 0).softmax(-1))
tensor([[[[0.3333, 0.3333, 0.3333],
          [0.0000, 1.0000, 0.0000],
          [0.0000, 0.5000, 0.5000]]]])

As we can see, we put equal weight on all input tokens for the output of the padded hidden states vector, which means that this output is very much not 0.

FA-2, on the other hand, just doesn't compute these outputs at all, or forces them to be 0, which creates the difference.
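The "doesn't compute these outputs at all" part can be pictured as an unpad/pad round trip. Below is a minimal pure-Python sketch loosely modeled on the idea behind `unpad_input` / `pad_input` from flash-attn's `bert_padding` module, which the FA-2 code path relies on; `unpad`, `pad` and the stand-in kernel are hypothetical simplifications, not the real signatures. Only real tokens are packed and processed, and positions that were never computed come back as zeros.

```python
def unpad(hidden, padding_mask):
    """Pack the hidden vectors of real tokens (mask == 1) into one flat list.
    Returns the packed vectors, their flat indices and cumulative sequence lengths."""
    seq_len = len(padding_mask[0])
    packed, indices, cu_seqlens = [], [], [0]
    for b, row in enumerate(padding_mask):
        for t, m in enumerate(row):
            if m == 1:
                packed.append(hidden[b][t])
                indices.append(b * seq_len + t)
        cu_seqlens.append(len(packed))
    return packed, indices, cu_seqlens

def pad(packed, indices, batch_size, seq_len, dim):
    """Scatter packed vectors back to [batch][seq_len][dim]; untouched slots stay zero."""
    flat = [[0.0] * dim for _ in range(batch_size * seq_len)]
    for vec, idx in zip(packed, indices):
        flat[idx] = vec
    return [flat[b * seq_len:(b + 1) * seq_len] for b in range(batch_size)]

padding_mask = [[1, 1, 0, 0], [1, 1, 1, 1]]
hidden = [[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]],
          [[0.5, 0.5], [1.5, 1.5], [2.5, 2.5], [3.5, 3.5]]]

packed, indices, cu_seqlens = unpad(hidden, padding_mask)
print(cu_seqlens)  # [0, 2, 6]: sequence boundaries inside the packed batch
outputs = [[2.0 * x for x in vec] for vec in packed]  # stand-in for the varlen attention kernel
print(pad(outputs, indices, batch_size=2, seq_len=4, dim=2)[0])
# [[2.0, 4.0], [6.0, 8.0], [0.0, 0.0], [0.0, 0.0]]
```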

Summary

Long story short, let's make sure to not compare outputs of padded hidden states. These states are moot no matter what and should not be used for anything.
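Following that advice, a comparison between the two implementations should skip padded positions. A minimal pure-Python sketch (`max_abs_diff_on_real_tokens` is a hypothetical helper), fed with only the six columns that are visible in the FA-2 and non-FA-2 printouts earlier in this thread; the elided `...` columns are not included, so this is illustrative only.

```python
def max_abs_diff_on_real_tokens(out_a, out_b, padding_mask):
    """Largest elementwise |a - b| over positions where padding_mask == 1."""
    worst = 0.0
    for seq_a, seq_b, row in zip(out_a, out_b, padding_mask):
        for vec_a, vec_b, m in zip(seq_a, seq_b, row):
            if m == 1:
                worst = max(worst, max(abs(x - y) for x, y in zip(vec_a, vec_b)))
    return worst

# Visible columns of self.o_proj(attn_output) from the (Pdb) printouts above.
fa2 = [[[0.6187, -0.9595, -0.2783, 0.1057, -0.5645, -0.3220],
        [0.4392, -0.5137, -0.5078, 0.0863, -0.3232, 0.1931],
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
       [[0.1334, 0.1556, -0.5737, -0.1802, 0.2262, -0.6035],
        [-0.2883, -0.1821, -0.5303, 0.2157, 0.0258, -0.0304],
        [-0.4187, -0.1300, -0.2747, 0.3828, 0.0053, -0.3252],
        [-0.1055, 0.0997, -0.1527, 0.3984, -0.1208, -0.1553]]]
eager = [[[0.6187, -0.9595, -0.2783, 0.1057, -0.5645, -0.3220],
          [0.4392, -0.5137, -0.5078, 0.0862, -0.3232, 0.1930],
          [0.4172, -0.4719, -0.4473, -0.1212, -0.3323, 0.0089],
          [0.5713, -0.4893, -0.4084, -0.0648, -0.3967, -0.0724]],
         [[0.1334, 0.1556, -0.5737, -0.1802, 0.2262, -0.6035],
          [-0.2883, -0.1821, -0.5303, 0.2156, 0.0258, -0.0306],
          [-0.4187, -0.1299, -0.2747, 0.3828, 0.0053, -0.3252],
          [-0.1055, 0.0997, -0.1527, 0.3987, -0.1210, -0.1554]]]
mask = [[1, 1, 0, 0], [1, 1, 1, 1]]

print(max_abs_diff_on_real_tokens(fa2, eager, mask))            # about 3e-4 on real tokens
print(max_abs_diff_on_real_tokens(fa2, eager, [[1] * 4] * 2))   # 0.5713, dominated by padded rows
```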

It would be great to re-run the little experiment here but making sure that -100 is provided for padded out tokens.
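Setting labels to -100 works because -100 is the default `ignore_index` of PyTorch's `CrossEntropyLoss`, which the transformers causal-LM loss uses, so those positions contribute nothing to the mean. A minimal pure-Python sketch of the effect on per-token losses (token ids and loss values are made up, and the one-position label shift of the real causal-LM loss is omitted). With a single unpadded sequence, as in the script at the top of the thread, every attention_mask entry is 1 and this masking is a no-op.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss

def mask_padded_labels(labels, attention_mask):
    """Python analogue of torch.where(attention_mask == 0, -100, labels)."""
    return [[lab if m == 1 else IGNORE_INDEX for lab, m in zip(row, mrow)]
            for row, mrow in zip(labels, attention_mask)]

def mean_loss(per_token_loss, labels):
    """Average per-token losses over positions whose label is not IGNORE_INDEX."""
    kept = [l for lrow, row in zip(per_token_loss, labels)
            for l, lab in zip(lrow, row) if lab != IGNORE_INDEX]
    return sum(kept) / len(kept)

input_ids = [[2, 2, 101, 102]]           # hypothetical ids, left-padded with the pad id
attention_mask = [[0, 0, 1, 1]]
per_token_loss = [[9.0, 9.0, 1.5, 0.5]]  # made-up values; padded positions carry garbage

labels = mask_padded_labels(input_ids, attention_mask)
print(labels)                                # [[-100, -100, 101, 102]]
print(mean_loss(per_token_loss, labels))     # 1.0
print(mean_loss(per_token_loss, input_ids))  # 5.0 when padding is not masked
```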

KyleMylonakisProtopia commented 10 months ago

@patrickvonplaten Would you recommend us using Flash Attention 2 then over the default attention until this bug fix lands?

patrickvonplaten commented 10 months ago

I don't think there is a bug at all tbh. Padding tokens are expected to differ between FA2 and vanilla attention. Even when only comparing non-padding tokens there will be minor differences due to the different CUDA kernels being used (but they should not be as big as shown here: https://github.com/huggingface/transformers/issues/27050#issue-1960195010)

Generally, I always recommend using FA2 if you can use it

wizyoung commented 10 months ago

I agree. The padded part should not affect the training loss or the inference results. In my experiments, training with FA2 and testing with vanilla attention does not change the results at all. But the creepy thing is that training and testing with FA2 yields poor results (while the weights are fine if I switch to vanilla attention at test time). I see that many issues also report test-result discrepancies when using model.generate. Just a guess: maybe we should investigate the post-processing in model.generate more closely?

KyleMylonakisProtopia commented 10 months ago

So I understand what you are saying and agree, the padded tokens and hidden states should not be used at any point. However I disagree with your conclusion that no bug is necessarily present.

The example provided at the top of this thread does not have padding. If padding is being added and used anywhere, that is happening in the Huggingface code. Moreover, the loss function we are reporting is the loss function of the Huggingface Llama 2 model, again not something that we are writing. If there is a mistake in what we are doing, then we should be able to point to a specific line in the script at the top of the page where the mistake is made, but I am really having a hard time finding one there. Otherwise, whatever is causing the discrepancy would be part of either the Huggingface code or the code distributed by Meta and hosted on Huggingface.

zhipeng93 commented 10 months ago

Hi @patrickvonplaten, thanks for the detailed explanation :) I agree that the attention output differs with and without FA.

However, as we know that we are doing a linear projection for the attention output output = linear_proj(attn_output), which is essentially a matmul, output = matmul(attn_output, weight). So the output is indeed affected by the moot part.

zhipeng93 commented 10 months ago

cc https://github.com/huggingface/transformers/pull/26421

patrickvonplaten commented 10 months ago

Trying to narrow down the problem:

It would be great to test: a) Training. @KyleMylonakisProtopia, to check whether FA2 vs. no-FA2 influences training, we need to make sure the labels of padded tokens are set to -100 as follows.

import argparse

import torch
import torch.backends.cudnn
import transformers
from transformers.models import llama

def main() -> None:
    torch.backends.cudnn.deterministic = True

    parser = argparse.ArgumentParser()
    parser.add_argument("--use-flash-attention-2", action="store_true")
    args = parser.parse_args()
    use_flash_attention_2 = args.use_flash_attention_2

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        "/models/huggingface/meta-llama/llama-2-7b-chat-hf", local_files_only=True, use_safetensors=True, device_map=torch.device("cuda")
    )
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    text = "Hello world!"
    tokenized_text = tokenizer(text)
    tokenized_text = {key: torch.tensor(value).unsqueeze(dim=0).to(torch.device("cuda")) for key, value in tokenized_text.items()}
    tokenized_text["labels"] = tokenized_text["input_ids"].clone()
+   tokenized_text["labels"] = torch.where(tokenized_text["attention_mask"] == 0, -100, tokenized_text["labels"])  # make sure to not apply loss on padded tokens

    torch.manual_seed(0)
    model = llama.LlamaForCausalLM.from_pretrained(
        "/models/huggingface/meta-llama/llama-2-7b-chat-hf",
        local_files_only=True,
        use_safetensors=True,
        device_map=torch.device("cuda"),
        use_flash_attention_2=use_flash_attention_2,
        torch_dtype=torch.bfloat16,
    )
    assert isinstance(model, llama.LlamaForCausalLM)
    model.eval()
    for param in model.parameters():
        param.requires_grad = False

    model.model.layers[0].train()
    for param in model.model.layers[0].parameters():
        param.requires_grad = True

    optim = torch.optim.AdamW(model.parameters())

    torch.manual_seed(0)

    for i in range(10):
        output = model(**tokenized_text)
        loss = output["loss"]
        if i in (0, 9):
            print(loss)
        loss.backward()
        optim.step()
        optim.zero_grad()

if __name__ == "__main__":
    main()

In addition, it won't be enough to look at the loss curve and say there is a bug if they differ. They will surely differ, since the backward method of FA2 is very different from that of no FA2. We need to actually train Llama on a bit of data and see how quickly the models learn depending on whether FA2 is used or not.

b) It would be great to have a fully reproducible code snippet where we clearly see different results between FA2 and no FA2 for generate. This should be easy to do: just run both models with the same seed, generate, and find an example where they differ significantly or where one is clearly better than the other.

However, as we know that we are doing a linear projection for the attention output output = linear_proj(attn_output), which is essentially a matmul, output = matmul(attn_output, weight). So the output is indeed affected by the moot part.

This is not really correct, because linear_proj is "seq-len"-independent. Imagine that the Softmax(QK^T) output is as follows:

attn_output = [vec1, vec2, vec3, pad_vec, pad_vec, vec4]

Now doing:

output = matmul([vec1, vec2, vec3, pad_vec, pad_vec, vec4], weight)

will give you again:

[new_vec1, new_vec2, new_vec3, new_pad_vec, new_pad_vec, new_vec4]

whereby importantly new_pad_vec did not influence the computation of new_vec1 at all. The linear projection is only applied over the feature dimension not the seq dimensions, hence you could also do the following and get the same results:

for i in range(6):
    new_vec[i] = matmul(vec[i], weight)
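The position-wise argument is easy to check numerically. A minimal pure-Python sketch with arbitrary made-up numbers: the same projection is applied once with zeroed padded rows (FA-2-like) and once with arbitrary values in those rows (vanilla-like), and the rows for real tokens come out identical.

```python
def project(rows, weight):
    """Apply rows @ weight; each output row depends only on the matching input row."""
    return [[sum(x * weight[i][j] for i, x in enumerate(row)) for j in range(len(weight[0]))]
            for row in rows]

weight = [[0.5, -1.0], [2.0, 0.25], [1.0, 1.0]]  # in_dim=3, out_dim=2
vec1, vec2, vec3, vec4 = [1.0, 0.0, 2.0], [0.5, 0.5, 0.5], [-1.0, 3.0, 0.0], [2.0, -2.0, 1.0]

zero_pad = [0.0, 0.0, 0.0]
out_fa2_like = project([vec1, vec2, vec3, zero_pad, zero_pad, vec4], weight)
out_eager_like = project([vec1, vec2, vec3, [9.0, -7.0, 4.0], [1.0, 1.0, 1.0], vec4], weight)

real_positions = [0, 1, 2, 5]
print(all(out_fa2_like[i] == out_eager_like[i] for i in real_positions))  # True
print(out_fa2_like[3], out_eager_like[3])  # only the padded rows differ
```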

To make some progress here it would be really really great if someone could provide a reproducible code snippet of either a) or b)

patrickvonplaten commented 10 months ago

This example of cckao is a great example of b): https://github.com/huggingface/transformers/issues/27050#issuecomment-1780424209

But I don't think it's a bug; it's simply due to different CUDA kernels being used. Note how similar the generations are. You would find similar differences just by running the same algorithm on different hardware. If there were a bug with the attention mask, the differences would be much starker.

KyleMylonakisProtopia commented 10 months ago

Just ran with the additional line of code you suggested and unfortunately there was no change in the behavior. The discrepancy remains exactly as it was.

You mean that the implementation of the backward pass for FA2 is very different from the implementation without FA2. The implementations with and without FA2 are both exact, in the sense that neither performs a numerical approximation of the derivative. The sources of error would be truncation error and the non-associativity of floating-point arithmetic. Now, it could be that this very small error accumulates rapidly due to a lack of stability. If that were the case, then decreasing the lr, say to 5e-5, and running out to 50 iterations should diminish the discrepancy. However, when I do that I see even starker differences at iteration 50.
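On the non-associativity point (floating-point addition is commutative but not associative), here is a minimal pure-Python sketch. `to_bf16` is a hypothetical helper that emulates bfloat16 rounding through the float32 bit pattern, and the bf16 accumulator is deliberately exaggerated: real GPU kernels usually accumulate in float32, so actual gaps are far smaller. It only illustrates that summing the same numbers in a different order, as two different attention kernels may do, can change the result, and that lower precision amplifies the effect.

```python
import struct

def to_bf16(x):
    """Round a float to the nearest bfloat16 value (round half to even), returned as a float.
    Emulation via the float32 bit pattern; NaN and overflow handling are omitted."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits = (bits >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]

def bf16_sum(values):
    """Sequential sum where every partial result is rounded to bfloat16."""
    acc = 0.0
    for v in values:
        acc = to_bf16(acc + to_bf16(v))
    return acc

# Even in float64, grouping changes the result.
print((0.1 + 0.2) + 0.3, 0.1 + (0.2 + 0.3))  # 0.6000000000000001 0.6

values = [1.0] + [0.001] * 1000
print(sum(values))                       # close to 2.0 in float64
print(bf16_sum(values))                  # 1.0: each 0.001 is below half a bf16 ulp of 1.0
print(bf16_sum(list(reversed(values))))  # 1.5: small terms accumulate until the sum stalls at 0.5
```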

python flash_attn_non_determinism.py
tensor(5.6589, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(2.5236, device='cuda:0', grad_fn=<NllLossBackward0>)

python flash_attn_non_determinism.py --use-flash-attention-2
tensor(5.6612, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.4144, device='cuda:0', grad_fn=<NllLossBackward0>)

We already know the models are training differently with and without FA2, because that's why we made this ticket in the first place: after enabling FA2, we were not able to reproduce the results we had previously established without it.

patrickvonplaten commented 10 months ago

Thanks for re-running the training script @KyleMylonakisProtopia ! And in your training experiments, using FA2 doesn't give sensible results where as not using FA2 for training does? Also, it seems like both work correctly in inference no?

=> So could it be that there is then a bug with FA2 for training only?

KyleMylonakisProtopia commented 10 months ago

Have we looked at how the gradients of attention vs. flash attention 2 are backpropagating?

cckao commented 10 months ago

This example of cckao is a great example of b): #27050 (comment)

But I don't think it's a bug and simply due to different CUDA kernels being used. Note how similar the generations are. You would find similar differences just by running the same algorithm on a different hardware. If there would be a bug with the attention mask, the differences would be much starker.

The difference actually becomes quite noticeable after the first differing token, due to the autoregressive nature of these models. If the difference is due to different CUDA kernels and cannot be fixed, that really limits the application of FA2 to pretrained models.
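How a tiny numerical difference turns into a visibly different continuation can be shown with a toy model. A minimal pure-Python sketch, where `toy_next_token_logits` is a hypothetical stand-in for a language model and decoding is greedy rather than sampled as in the CodeLlama example: when two candidates are nearly tied, a perturbation of 1e-3 in one logit flips the argmax, and every later token is conditioned on that different choice.

```python
def toy_next_token_logits(tokens):
    """Hypothetical 3-token model: after token 0, tokens 1 and 2 are nearly tied;
    afterwards the model keeps repeating whatever it produced last."""
    last = tokens[-1]
    if last == 0:
        return [0.0, 1.0, 1.0001]
    if last == 1:
        return [0.0, 2.0, 0.0]
    return [0.0, 0.0, 2.0]

def greedy_decode(prompt, steps, logit_noise=0.0):
    """Greedy decoding; logit_noise is added to token 1's logit to mimic a small kernel-level difference."""
    tokens = list(prompt)
    for _ in range(steps):
        logits = toy_next_token_logits(tokens)
        logits[1] += logit_noise
        tokens.append(max(range(len(logits)), key=lambda i: logits[i]))
    return tokens

print(greedy_decode([0], steps=4))                    # [0, 2, 2, 2, 2]
print(greedy_decode([0], steps=4, logit_noise=1e-3))  # [0, 1, 1, 1, 1]
```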

github-actions[bot] commented 9 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

KyleMylonakisProtopia commented 9 months ago

@patrickvonplaten This issue is still relevant. Have you been able to look at the gradients?