Open ringohoffman opened 10 months ago
Hey, I think this is related to flash attention version, could you have a look at #26697?
We are currently using flash-attn==2.3.2. There was a minor version release of flash attention literally yesterday.
The problem persists with flash-attn==2.3.3.
Are you able to reproduce on your end with the supplied script?
cc @younesbelkada if you can have a look
hi @KyleMylonakisProtopia ! I think that difference is expected, I am not sure if flash-attn guarantees full reproducibility for gradient computation, note also that some slight differences in logits are expected between FA-2 and non FA-2 models.
The code demonstrates non-trivial differences in the loss prior to even the first backwards call. Flash attention and flash attention 2 are supposed to be exact algorithms for computing attention.
From the Flash attention 2 paper "To speed up attention on hardware accelerators such as GPU, [5] proposes an algorithm to reduce the memory reads/writes while maintaining the same output (without approximation)." That seems pretty unambiguous to me.
The slight differences from whatever parallelization differences are happening should not be manifesting at the third significant digit on the first loss call. This points to some other kind of issue.
Flash attention and flash attention 2 are supposed to be exact algorithms for computing attention.
yes, but in the script above you are comparing vanilla attention vs FA-2 no?
That sentence compares Flash attention (and implicitly Flash attention 2) to "vanilla" attention. That is what our script is showing.
ah correct yes you are right, sorry for the confusion, I'll have a deeper look !
I also encountered the same problem at inference. Environment: transformers==4.34.0, flash-attn==2.3.3, torch==2.0.1+cu117.
import numpy as np
import torch
from transformers import CodeLlamaTokenizer, LlamaForCausalLM, TextStreamer

# Fix every RNG so both runs sample from the same random stream.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
prompt = """<s>[INST]Tell me the story about a dog.[/INST]"""
d_model = "/path/to/CodeLlama-13b-Instruct-hf"
tokenizer = CodeLlamaTokenizer.from_pretrained(d_model)
model = LlamaForCausalLM.from_pretrained(d_model, device_map="auto", torch_dtype=torch.bfloat16)
tokenized = tokenizer(prompt, return_tensors="pt", truncation=False).to("cuda")
generated_ids = model.generate(**tokenized, max_new_tokens=1024, do_sample=True, streamer=TextStreamer(tokenizer, skip_prompt=True))
use-flash-attention-2=False:
Once upon a time, there was a dog named Max. Max was a lovable golden retriever who loved nothing more than to go for walks with his owner, Sarah. One day, while they were out on a walk,
use-flash-attention-2=True:
Once upon a time, there was a dog named Max. Max was a lovable golden retriever who loved nothing more than to go for walks with his owner, Sarah. One day, while they were out on their usual stroll,
Here is my minimal reproducible script:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaModel, _make_causal_mask
device = torch.device("cuda")
dtype = torch.float16
config_ori = LlamaConfig(
    hidden_size=1024,
    intermediate_size=128,
    num_hidden_layers=1,
    num_attention_heads=8,
    max_position_embeddings=16,
    _flash_attn_2_enabled=False
)
config_new = LlamaConfig(
    hidden_size=1024,
    intermediate_size=128,
    num_hidden_layers=1,
    num_attention_heads=8,
    max_position_embeddings=16,
    _flash_attn_2_enabled=True
)
model_ori = LlamaModel(config_ori)
model_new = LlamaModel(config_new)
model_new.load_state_dict(model_ori.state_dict())
model_ori.to(dtype).to(device)
model_new.to(dtype).to(device)
attn_ori = model_ori.layers[0].self_attn
attn_new = model_new.layers[0].self_attn
bsz, hs, seqlen = 2, config_ori.hidden_size, 4
inputs_embeds = torch.randn((bsz, seqlen, hs), dtype=dtype, device=device)
padding_mask = torch.full((bsz, seqlen), 1, dtype=torch.long, device=device)
# or pad a part
# padding_mask[0, 2:] = 0
out_ori = model_ori(attention_mask=padding_mask, inputs_embeds=inputs_embeds, use_cache=False)['last_hidden_state']
out_new = model_new(attention_mask=padding_mask, inputs_embeds=inputs_embeds, use_cache=False)['last_hidden_state']
out_ori.sum(), out_new.sum(), (out_ori - out_new).mean().item(), (out_ori - out_new).abs().max().item(), (out_ori - out_new).abs().mean().item()
I noticed that the numerical difference mainly comes from the padding_mask. If the padding_mask is None, we only use the causal mask and the difference is small. However, if we set the padding_mask, the difference cannot be ignored. If we run pytest from the official flash-attn repo, the diff.abs().max().item() is always small:
The diff comes from the attention module. A more fine-grained code:
bsz, hs, seqlen = 2, config_ori.hidden_size, 4
hidden = torch.rand((bsz, seqlen, hs), dtype=dtype, device=device)
padding_mask = torch.full((bsz, seqlen), 1, dtype=torch.long, device=device)
# padding_mask[0, 2:] = 0
past_key_values_length = 0
key_value_length = seqlen + past_key_values_length
position_ids = torch.arange(past_key_values_length, key_value_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0)
if padding_mask is not None:
    attention_mask_ori = model_ori.attn_mask_converter.to_4d(
        padding_mask, seqlen, key_value_length, dtype=hidden.dtype
    )
else:
    attention_mask_ori = model_ori.attn_mask_converter.to_causal_4d(
        bsz, seqlen, key_value_length, dtype=hidden.dtype, device=hidden.device
    )
out_ori, _, _ = attn_ori.forward(
    hidden, attention_mask=attention_mask_ori, position_ids=position_ids,
)
out_new, _, _ = attn_new.forward(
    hidden, attention_mask=padding_mask, position_ids=position_ids
)
out_ori.sum(), out_new.sum(), (out_ori - out_new).mean().item(), (out_ori - out_new).abs().max().item(), (out_ori - out_new).abs().mean().item()
UPDATE: It seems the diff lies in the padded part in the final attn weights? So maybe this should not affect the final training loss and the inference results?
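If the difference really lives only in the padded positions, a comparison restricted to real tokens should come out small. Below is a minimal, dependency-free sketch of such a check; `masked_max_abs_diff` is a hypothetical helper (not part of transformers or flash-attn) and the numbers are made-up toy values. With tensors, something like `(out_ori - out_new).abs()[padding_mask.bool()].max()` should express the same idea.

```python
def masked_max_abs_diff(out_a, out_b, padding_mask):
    """Largest |a - b| over positions where padding_mask is 1.

    out_a, out_b: nested lists shaped [batch][seq][hidden]
    padding_mask: nested lists shaped [batch][seq], 1 = real token, 0 = pad
    """
    worst = 0.0
    for seq_a, seq_b, mask_row in zip(out_a, out_b, padding_mask):
        for vec_a, vec_b, keep in zip(seq_a, seq_b, mask_row):
            if not keep:
                continue  # padded position: skip it entirely
            for a, b in zip(vec_a, vec_b):
                worst = max(worst, abs(a - b))
    return worst


# Toy values for illustration only: the second position is padding.
out_vanilla = [[[0.50, -0.25], [0.75, 0.125]]]
out_flash = [[[0.50, -0.25], [0.0, 0.0]]]
mask = [[1, 0]]

print(masked_max_abs_diff(out_vanilla, out_flash, mask))      # 0.0
print(masked_max_abs_diff(out_vanilla, out_flash, [[1, 1]]))  # 0.75
```

The second call treats the padded slot as a real token, which is what an unmasked `.abs().max()` does.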
my env:
transformers version: 4.35.0.dev0 (from commit aa4198a at 2023.10.27 main branch)
hope this helps!
Thanks for the deep dive @wizyoung! This thread already shows differences in the loss and the inference results, so something is afoot.
cc @younesbelkada If I remember correctly, when we debugged the flash attention tests we found out that the attention mask was not properly taken into account, and the attention weights for pad tokens were nonzero in vanilla and zero for flash attention. This came from the way we create our attention mask, which adds two inf values, creating overflows. We should be able to easily fix! cc @patrickvonplaten as we talked about this
I think maybe this is not the actual cause, as two inf values will not cause much numerical difference after softmax. After applying your fix above, the output of the padded part still differs. The results indicate that the padding mask does not take effect in computing attention weights.
The problem should come from the pad_input after computing flash attn results.
Update: I ran a quick test on my work projects. In the baseline scenario, I trained and tested everything without using flash attention. For Experiment 1 (Exp1), I trained and tested while using flash attention. The evaluation process involved periodically switching to the test dataset, enabling use_cache=True, and performing batch inference. I noticed that the evaluation metrics in Exp1 were around 20% lower compared to the baseline. However, when I loaded the checkpoint from Exp1 without flash attention, the results were nearly identical to the baseline. This outcome matches my expectations because the discrepancies are mainly caused by padding, which is disregarded during the loss backward process and does not affect convergence. Nevertheless, I'm puzzled about why this would impact inference, as I believe that once the EOS token is predicted in the generation process, the process should be finished.
Thanks a lot @wizyoung for the deep dive!
@ArthurZucker indeed we noticed some discrepancies with respect to pad tokens, and I think at that time our conclusion was the same as the one stated by @wizyoung:
"UPDATE: It seems the diff lies in the padded part in the final attn weights? So maybe this should not affect the final training loss and the inference results?"
The difference clearly resides on the padding tokens.
With FA-2:
(Pdb) self.o_proj(attn_output)
tensor([[[ 0.6187, -0.9595, -0.2783, ..., 0.1057, -0.5645, -0.3220],
[ 0.4392, -0.5137, -0.5078, ..., 0.0863, -0.3232, 0.1931],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 0.1334, 0.1556, -0.5737, ..., -0.1802, 0.2262, -0.6035],
[-0.2883, -0.1821, -0.5303, ..., 0.2157, 0.0258, -0.0304],
[-0.4187, -0.1300, -0.2747, ..., 0.3828, 0.0053, -0.3252],
[-0.1055, 0.0997, -0.1527, ..., 0.3984, -0.1208, -0.1553]]],
device='cuda:0', dtype=torch.float16, grad_fn=<UnsafeViewBackward0>)
Without FA-2:
tensor([[[ 0.6187, -0.9595, -0.2783, ..., 0.1057, -0.5645, -0.3220],
[ 0.4392, -0.5137, -0.5078, ..., 0.0862, -0.3232, 0.1930],
[ 0.4172, -0.4719, -0.4473, ..., -0.1212, -0.3323, 0.0089],
[ 0.5713, -0.4893, -0.4084, ..., -0.0648, -0.3967, -0.0724]],
[[ 0.1334, 0.1556, -0.5737, ..., -0.1802, 0.2262, -0.6035],
[-0.2883, -0.1821, -0.5303, ..., 0.2156, 0.0258, -0.0306],
[-0.4187, -0.1299, -0.2747, ..., 0.3828, 0.0053, -0.3252],
[-0.1055, 0.0997, -0.1527, ..., 0.3987, -0.1210, -0.1554]]],
device='cuda:0', dtype=torch.float16, grad_fn=<UnsafeViewBackward0>)
As you can see, the hidden states that correspond to the masked indices of the attention mask:
(Pdb) attention_mask
tensor([[1, 1, 0, 0],
[1, 1, 1, 1]], device='cuda:0')
are zeroed out for FA2, whereas they are not for non-FA2 models. I also tried #27114. Will investigate more
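One plausible way to end up with exact zeros at padded positions is an "unpad, compute, re-pad" flow: drop the padded tokens, run the kernel on real tokens only, then scatter the results back into a zero-initialised buffer. The sketch below illustrates that idea only. It assumes the FA-2 path behaves this way, and `unpad` / `pad` are hypothetical pure-Python stand-ins, not the flash-attn helpers themselves.

```python
def unpad(hidden, mask):
    """Collect real-token vectors and remember their (batch, time) positions."""
    kept, positions = [], []
    for b, (seq, mask_row) in enumerate(zip(hidden, mask)):
        for t, (vec, keep) in enumerate(zip(seq, mask_row)):
            if keep:
                kept.append(vec)
                positions.append((b, t))
    return kept, positions


def pad(kept, positions, batch, seqlen, dim):
    """Scatter vectors back into a zero-initialised [batch][seq][dim] buffer."""
    out = [[[0.0] * dim for _ in range(seqlen)] for _ in range(batch)]
    for vec, (b, t) in zip(kept, positions):
        out[b][t] = list(vec)
    return out


hidden = [[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]]
mask = [[1, 1, 0]]  # last position is padding

kept, positions = unpad(hidden, mask)
processed = [[2 * x for x in vec] for vec in kept]  # stand-in for the attention kernel
restored = pad(processed, positions, batch=1, seqlen=3, dim=2)
print(restored)  # [[[2.0, 4.0], [6.0, 8.0], [0.0, 0.0]]]
```

The padded slot never passes through the "kernel", so it keeps the buffer's initial value of zero, whereas a dense implementation computes some value there.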
Hi everyone we had a deeper look with @ArthurZucker and here are our findings:
1- #27114 fixes another issue we have with all attention modules in transformers when combining attention masks together, which sometimes leads to undesired inf values inside these masks.
2- To resolve the issue shown in @wizyoung's snippet, adding the following inside the attention module:
    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+   if attention_mask is not None:
+       sliced_attention_mask = attention_mask[:, 0, -1, :]
+       attention_mask_2d = (1.0 * ~sliced_attention_mask.bool()).to(attn_output.dtype)
+       attn_output = attn_output * attention_mask_2d.unsqueeze(-1)
    if self.config.pretraining_tp > 1:
        attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
        o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
        attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
    else:
        attn_output = self.o_proj(attn_output)
    if not output_attentions:
        attn_weights = None
    return attn_output, attn_weights, past_key_value
Fixes the issue as this snippet correctly zeroes-out all the hidden states that are related to padding tokens. I am not sure this leads to any impact for generation. Given also that the slicing + cast operations can add some considerable overhead in the attention module (as it has to be done for every layer) I am not sure we should upstream these changes in transformers core.
However the issue mentioned by @KyleMylonakisProtopia still persists (I am able to repro even with the fix), which needs further investigation
Thanks for the continued look!
I think the reason for the discrepancy between FA-2 and non-FA-2 here comes solely from the fact that we're comparing padded output tensors and/or included padded hidden states vectors in our results. Padded hidden states vectors are useless/moot and should never influence a loss or be compared.
Let's explain a bit:
Padded hidden states vectors are vectors that correspond to a sequence index i that is not attended to, meaning attention_mask[i] is 0. This corresponds to the outer-most left tokens here: https://github.com/huggingface/transformers/issues/27050#issue-1960195010 since we use left-padding, or all tokens after 2: here: https://github.com/huggingface/transformers/issues/27050#issuecomment-1782529853 when doing right padding.
One should never take padded hidden states vectors into account! One should never compare padded hidden states vectors to each other because they are moot / useless and should never be used. This means when comparing the loss here: https://github.com/huggingface/transformers/issues/27050#issue-1960195010, one has to set the labels at the indexes that correspond to padding tokens to -100 to make sure they don't influence the loss. See this issue as well. Similarly it doesn't make much sense to do this fix here: https://github.com/huggingface/transformers/issues/27050#issuecomment-1795089451 and to compare the outputs of padded tokens because they are useless anyways.
What is going on here?!
Let's look at a tiny code example that explains the behavior of non-FA2 code.
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
attention_mask = torch.tensor([[0, 1, 1]]) # left padding
print(_prepare_4d_causal_attention_mask(attention_mask, (1, 3), attention_mask.float(), 0))
We get
tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38],
[-3.4028e+38, 0.0000e+00, -3.4028e+38],
[-3.4028e+38, 0.0000e+00, 0.0000e+00]]]])
as expected. We see the causal mask and in addition we see that the first column has high negative values.
Now let's run a softmax on the attention output corresponding to Softmax(QK^T), assuming that QK^T is 1s only.
print(_prepare_4d_causal_attention_mask(attention_mask, (1, 3), attention_mask.float(), 0).softmax(-2))
tensor([[[[0.3333, 0.0000, 0.0000],
[0.3333, 0.5000, 0.0000],
[0.3333, 0.5000, 1.0000]]]])
As we can see we put equal weight on all input tokens for the output of the padded hidden states vector. This means the output of the padded hidden states vector is very much not 0.
FA-2 on the other hand just doesn't compute these outputs at all or forces them to be 0 which creates the difference.
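The same effect can be reproduced without any library. The sketch below is a stdlib-only illustration, not the transformers code path: it adds the mask to all-ones scores and normalises over the key axis (the last one), so each printed row is the weight vector of one query. `MASK_VALUE` approximates the float32 minimum shown in the mask above.

```python
import math

MASK_VALUE = -3.4028e38  # roughly the float32 minimum seen in the mask above


def softmax(row):
    """Numerically stable softmax over one row of scores."""
    peak = max(row)
    exps = [math.exp(x - peak) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


# Additive mask for attention_mask = [0, 1, 1] (left padding), one row per query.
mask = [
    [MASK_VALUE, MASK_VALUE, MASK_VALUE],  # query 0 is the pad token: fully masked
    [MASK_VALUE, 0.0, MASK_VALUE],
    [MASK_VALUE, 0.0, 0.0],
]
scores = [[1.0, 1.0, 1.0] for _ in range(3)]  # pretend QK^T is all ones

for score_row, mask_row in zip(scores, mask):
    masked = [s + m for s, m in zip(score_row, mask_row)]
    print([round(p, 4) for p in softmax(masked)])
# [0.3333, 0.3333, 0.3333]
# [0.0, 1.0, 0.0]
# [0.0, 0.5, 0.5]
```

Because every entry of the pad query's row is shifted by the same huge negative value, the softmax sees equal inputs and returns a uniform distribution, so the weighted sum of value vectors for that row is generally nonzero.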
Summary
Long story short, let's make sure to not compare outputs of padded hidden states. These states are moot no matter what and should not be used for anything.
It would be great to re-run the little experiment here but making sure that -100 is provided for padded out tokens.
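To make the -100 convention concrete, here is a small stdlib-only sketch of a loss that skips ignored labels. `mean_nll` and `mask_labels` are hypothetical helpers written for illustration; they mimic the idea behind an `ignore_index` of -100 and deliberately leave out the one-token label shift that causal LM losses apply.

```python
import math

IGNORE_INDEX = -100  # label value that the loss skips


def mean_nll(log_probs, labels, ignore_index=IGNORE_INDEX):
    """Average negative log-likelihood over positions whose label is not ignored.

    log_probs: one list of per-class log-probabilities for each position
    labels:    one target class id (or ignore_index) for each position
    """
    total, count = 0.0, 0
    for row, label in zip(log_probs, labels):
        if label == ignore_index:
            continue  # padded position: contributes nothing
        total -= row[label]
        count += 1
    return total / count if count else float("nan")


def mask_labels(input_ids, attention_mask, ignore_index=IGNORE_INDEX):
    """Copy input_ids, replacing every padded position with ignore_index."""
    return [tok if keep else ignore_index for tok, keep in zip(input_ids, attention_mask)]


input_ids = [2, 0, 1]       # first token is left padding in this toy example
attention_mask = [0, 1, 1]
labels = mask_labels(input_ids, attention_mask)
print(labels)  # [-100, 0, 1]

log_probs = [
    [math.log(0.1), math.log(0.2), math.log(0.7)],    # pad position, arbitrary values
    [math.log(0.5), math.log(0.25), math.log(0.25)],
    [math.log(0.25), math.log(0.5), math.log(0.25)],
]
print(mean_nll(log_probs, labels))  # equals log(2): only the two real tokens count
```

Whatever values sit in the pad row of `log_probs`, the result is unchanged, which is the property needed when comparing FA-2 against vanilla attention.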
@patrickvonplaten Would you recommend us using Flash Attention 2 then over the default attention until this bug fix lands?
I don't think there is a bug at all tbh. Padding tokens are expected to differ between FA2 and vanilla attention. Even when only comparing non-padding tokens there will be minor differences due to the different CUDA kernels being used (but they should not be as big as shown here: https://github.com/huggingface/transformers/issues/27050#issue-1960195010)
Generally, I always recommend using FA2 if you can use it
I agree. The padded part should not affect the training loss or the inference result. In my experiments, training with FA2 but testing with vanilla attention makes no difference at all. But the creepy thing is, training and testing with FA2 yields poor results (though the weights are ok if I switch to vanilla attention at test time). I see many issues that also report test result discrepancies when using model.generate. Just a guess, maybe we should conduct a more in-depth investigation into the post-processing in model.generate?
So I understand what you are saying and agree, the padded tokens and hidden states should not be used at any point. However I disagree with your conclusion that no bug is necessarily present.
The example provided at the top of this thread does not have padding. If padding is being added and being used anywhere, that is happening in the Huggingface code. Moreover, the loss function we are reporting is the loss function by the Huggingface LLama2 model, again not something that we are writing. If there is a mistake in what we are doing, then we should be able to call out a specific line number in the script at the top of the page where a mistake is made, but I am really having a hard time finding one there. Otherwise whatever is causing the discrepancy would be part of either the Huggingface code, or the code distributed by Meta and hosted on Huggingface.
Hi @patrickvonplaten, thanks for the detailed explanation :) I agree that the attention output of using FA or not is different.
However, as we know that we are doing a linear projection for the attention output output = linear_proj(attn_output), which is essentially a matmul, output = matmul(attn_output, weight). So the output is indeed affected by the moot part.
Trying to narrow down the problem:
It would be great to test: a) Training. @KyleMylonakisProtopia To check whether FA2 vs. no-FA2 influences training, we need to make sure to set the labels of padded tokens to -100, as follows.
import argparse
import torch
import torch.backends.cudnn
import transformers
from transformers.models import llama

def main() -> None:
    torch.backends.cudnn.deterministic = True
    parser = argparse.ArgumentParser()
    parser.add_argument("--use-flash-attention-2", action="store_true")
    args = parser.parse_args()
    use_flash_attention_2 = args.use_flash_attention_2
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        "/models/huggingface/meta-llama/llama-2-7b-chat-hf", local_files_only=True, use_safetensors=True, device_map=torch.device("cuda")
    )
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    text = "Hello world!"
    tokenized_text = tokenizer(text)
    tokenized_text = {key: torch.tensor(value).unsqueeze(dim=0).to(torch.device("cuda")) for key, value in tokenized_text.items()}
    tokenized_text["labels"] = tokenized_text["input_ids"].clone()
+   tokenized_text["labels"] = torch.where(tokenized_text["attention_mask"] == 0, -100, tokenized_text["labels"])  # make sure to not apply loss on padded tokens
    torch.manual_seed(0)
    model = llama.LlamaForCausalLM.from_pretrained(
        "/models/huggingface/meta-llama/llama-2-7b-chat-hf",
        local_files_only=True,
        use_safetensors=True,
        device_map=torch.device("cuda"),
        use_flash_attention_2=use_flash_attention_2,
        torch_dtype=torch.bfloat16,
    )
    assert isinstance(model, llama.LlamaForCausalLM)
    # Freeze everything, then train only the first decoder layer.
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    model.model.layers[0].train()
    for param in model.model.layers[0].parameters():
        param.requires_grad = True
    optim = torch.optim.AdamW(model.parameters())
    torch.manual_seed(0)
    for i in range(10):
        output = model(**tokenized_text)
        loss = output["loss"]
        if i in (0, 9):
            print(loss)
        loss.backward()
        optim.step()
        optim.zero_grad()

if __name__ == "__main__":
    main()
In addition, it won't be enough to look at the loss curve and say there is a bug if they differ. They will surely differ since the backward method of FA2 is very different to no FA2. We need to actually train Llama on a bit of data and see how quickly the models learn depending on whether FA2 or no FA2 is implemented.
b) It would be great to have a fully reproducible code snippet where we're clearly seeing different results between FA2 and no FA2 for generate. This should be easy to do. Just run both models with the same seed and generate and find an example where they significantly differ or where one is clearly better than the other one.
However, as we know that we are doing a linear projection for the attention output output = linear_proj(attn_output), which is essentially a matmul, output = matmul(attn_output, weight). So the output is indeed affected by the moot part.
This is not really correct because linear_proj is "seq-len"-independent. Imagine that the Softmax(QK^T) output is as follows:
attn_output = [vec1, vec2, vec3, pad_vec, pad_vec, vec4]
Now doing:
output = matmul([vec1, vec2, vec3, pad_vec, pad_vec, vec4], weight)
will give you again:
[new_vec1, new_vec2, new_vec3, new_pad_vec, new_pad_vec, new_vec4]
whereby, importantly, new_pad_vec did not influence the computation of new_vec1 at all. The linear projection is only applied over the feature dimension, not the seq dimension, hence you could also do the following and get the same results:
for i in range(6):
    new_vec[i] = matmul(vec[i], weight)
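The per-position argument above can be checked with a tiny stdlib-only sketch. `project` and `project_sequence` are hypothetical helpers standing in for the output projection; the weight and vectors are made-up toy values.

```python
def project(vec, weight):
    """Multiply one hidden vector (length d_in) by a d_in x d_out weight matrix."""
    d_out = len(weight[0])
    return [sum(v * weight[k][j] for k, v in enumerate(vec)) for j in range(d_out)]


def project_sequence(seq, weight):
    """Apply the same projection to every position independently."""
    return [project(vec, weight) for vec in seq]


weight = [[1.0, 2.0], [3.0, 4.0]]
vec1, vec2 = [1.0, 0.0], [0.0, 1.0]

# Same real tokens, two different contents for the padded slot in the middle.
with_zero_pad = project_sequence([vec1, [0.0, 0.0], vec2], weight)
with_noisy_pad = project_sequence([vec1, [9.0, -7.0], vec2], weight)

print(with_zero_pad[0], with_noisy_pad[0])  # [1.0, 2.0] [1.0, 2.0]
print(with_zero_pad[2], with_noisy_pad[2])  # [3.0, 4.0] [3.0, 4.0]
print(with_zero_pad[1], with_noisy_pad[1])  # only the pad position differs
```

Changing the padded vector changes only its own projected output; the projections of the real tokens are identical in both runs.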
To make some progress here it would be really really great if someone could provide a reproducible code snippet of either a) or b)
This example of cckao is a great example of b): https://github.com/huggingface/transformers/issues/27050#issuecomment-1780424209
But I don't think it's a bug; it is simply due to different CUDA kernels being used. Note how similar the generations are. You would find similar differences just by running the same algorithm on different hardware. If there were a bug with the attention mask, the differences would be much starker.
Just ran with the additional line of code you suggested and unfortunately there was no change in the behavior. The discrepancy remains exactly as it was.
You mean the implementation of the backwards for FA2 is very different to the implementation of the method without FA2. The implementations with and without FA2 are both exact in the sense they are not performing any numerical approximations of the derivative. The sources of error would be truncation error and the non-associativity of floating-point arithmetic. Now it could be that this very small error accumulates rapidly due to lack of stability. If that were the case, then decreasing the lr, say to 5e-5, and running out to 50 iterations should diminish the discrepancy. However, when I do that I see even starker differences at iteration 50.
python flash_attn_non_determinism.py
tensor(5.6589, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(2.5236, device='cuda:0', grad_fn=<NllLossBackward0>)
python flash_attn_non_determinism.py --use-flash-attention-2
tensor(5.6612, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.4144, device='cuda:0', grad_fn=<NllLossBackward0>)
We know the models are training differently with and without FA2 already because that's why we made this ticket in the first place: we were not able to reproduce the same results that we had previously established without FA2 after enabling it.
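For reference, the kind of floating-point effect mentioned above (results depending on the order of operations) looks like this in plain Python. This is only an illustration of the mechanism in double precision; it says nothing about how large the effect is inside either attention kernel. `sum_left_to_right` is a hypothetical helper that adds strictly in sequence.

```python
def sum_left_to_right(values):
    """Add values strictly in the given order, one rounding step per addition."""
    total = 0.0
    for v in values:
        total += v
    return total


# Regrouping the same three numbers changes the last bit of the result.
left = (0.1 + 0.2) + 0.3
right = 0.1 + (0.2 + 0.3)
print(left == right)  # False
print(left, right)    # 0.6000000000000001 0.6

# Reordering a reduction can change the result by much more than one bit.
values = [1e16, 1.0, -1e16, 1.0]
print(sum_left_to_right(values))          # 1.0
print(sum_left_to_right(sorted(values)))  # 0.0 (the exact sum is 2.0)
```

Differences of this kind are typically tiny per operation, which is why a gap at the third significant digit on the first loss call reads as something other than reduction order.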
Thanks for re-running the training script @KyleMylonakisProtopia ! And in your training experiments, using FA2 doesn't give sensible results whereas not using FA2 for training does? Also, it seems like both work correctly in inference, no?
=> So could it be that there is then a bug with FA2 for training only?
Have we looked at how the gradients of attention vs. flash attention 2 are backpropagating?
This example of cckao is a great example of b): #27050 (comment)
But I don't think it's a bug and simply due to different CUDA kernels being used. Note how similar the generations are. You would find similar differences just by running the same algorithm on a different hardware. If there would be a bug with the attention mask, the differences would be much starker.
The difference actually becomes quite noticeable after the first different token due to the nature of autoregressive models. If the difference is due to different CUDA kernels and we cannot fix it, that really limits the application of FA2 to pretrained models.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
@patrickvonplaten This issue is still relevant. Have you been able to look at the gradients?
System Info
transformers version: 4.34.1
Who can help?
@ArthurZucker and @younesbelkada
Information
Tasks
examples folder (such as GLUE/SQuAD, ...)
Reproduction
We notice LlamaFlashAttention2._flash_attention_forward returns a different attn_output than LlamaAttention computes.
flash_attn_non_determinism.py:
Expected behavior
I am not expecting the magnitude of the difference between the 2 implementations. A difference of 0.1267 compared to 0.3542 seems very large.