Open YunfanZhang42 opened 4 days ago
You’re running a script to test the consistency of different attention implementations using PyTorch and Flash Attention 2. While the sdpa and eager implementations work as expected, flash_attention_2 is giving inconsistent results despite following the PyTorch reproducibility guidelines.
Here’s what you can do next:
1. Double-check your settings: Ensure all flags and settings for determinism are correctly applied. Small misconfigurations can cause issues.
2. Review documentation: Check the Flash Attention 2 documentation for any notes on determinism and if there are specific configurations needed.
3. Update your tools: Make sure you’re using the latest versions of PyTorch and Flash Attention 2, as updates may resolve your issue.
4. Community support: Post your issue on PyTorch forums or Flash Attention 2’s GitHub repository. The community might have encountered and solved similar issues.
5. Contribute: If you find that Flash Attention 2 is inherently non-deterministic, consider suggesting a change or adding a warning in the documentation or code to help others.
Is this AI-generated, @Varma0604?
Thanks for your report @YunfanZhang42! @ArthurZucker will be able to help once he's back from his holiday (next week). Thank you for your patience!
@LysandreJik Thank you for your reply. No need to rush here as this is not an urgent issue. And yes, I think @Varma0604 is spam posting using LLMs.
I did a few experiments on decoder-only models, covering both the forward and the backward pass, and the results are interesting. Here are the steps to reproduce:
export CUBLAS_WORKSPACE_CONFIG=:4096:8
as required by the PyTorch Reproducibility Guide.
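The same setting can also be applied from inside Python before anything touches cuBLAS; here is a minimal sketch of that alternative (the shell export above is what I actually used):

import os

# Must be set before the first cuBLAS call (in practice, before any CUDA work),
# otherwise the workspace configuration has no effect.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"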
import torch
import random
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
def test_consistency(model_name="mistralai/Mistral-7B-v0.1", attn_implementation="flash_attention_2"):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, attn_implementation=attn_implementation, device_map="cuda:0")
    # Define the prompt
    prompt = "My favourite condiment is"
    # Tokenize the input and send to appropriate device
    model_inputs = tokenizer([prompt], return_tensors="pt").to("cuda:0")
    # Perform multiple runs and check for consistency
    output_consistency = True
    gradient_consistency = True
    first_output = None
    first_gradients = None
    # Run the model 10 times
    for _ in range(10):
        # Ensure model is in training mode
        model.train()
        # Clear previous gradients
        model.zero_grad()
        # Get the model's output (logits)
        outputs = model(**model_inputs)
        # Use the sum of logits as a simple loss
        loss = outputs.logits.sum()
        # Compute gradients
        loss.backward()
        # Extract gradients and store them on CPU to avoid OOM
        current_gradients = {name: param.grad.cpu() for name, param in model.named_parameters() if param.grad is not None}
        # If it's the first run, store the output logits
        if first_output is None:
            first_output = outputs.logits
        else:
            # Compare current output with the first output
            if not torch.equal(first_output, outputs.logits):
                output_consistency = False
        # If it's the first run, store the gradients
        if first_gradients is None:
            first_gradients = current_gradients
        else:
            # Compare current gradients with the first gradients
            for name, grad in current_gradients.items():
                if not torch.equal(grad, first_gradients[name]):
                    gradient_consistency = False
    # Print the results
    print(f"Model: {model_name}, Attention Implementation: {attn_implementation}, Output Consistency: {output_consistency}, Gradient Consistency: {gradient_consistency}")
if __name__ == "__main__":
    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)
    torch.use_deterministic_algorithms(True, warn_only=True)
    torch.backends.cudnn.benchmark = False
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    test_consistency(model_name="gpt2", attn_implementation="flash_attention_2")
    # GPT2 does not support sdpa implementations
    # test_consistency(model_name="gpt2", attn_implementation="sdpa")
    test_consistency(model_name="gpt2", attn_implementation="eager")
    test_consistency(model_name="mistralai/Mistral-7B-v0.1", attn_implementation="flash_attention_2")
    test_consistency(model_name="mistralai/Mistral-7B-v0.1", attn_implementation="sdpa")
    test_consistency(model_name="mistralai/Mistral-7B-v0.1", attn_implementation="eager")
    test_consistency(model_name="meta-llama/Meta-Llama-3-8B", attn_implementation="flash_attention_2")
    test_consistency(model_name="meta-llama/Meta-Llama-3-8B", attn_implementation="sdpa")
    test_consistency(model_name="meta-llama/Meta-Llama-3-8B", attn_implementation="eager")
What I get is
Model: gpt2, Attention Implementation: flash_attention_2, Output Consistency: False, Gradient Consistency: False
Model: gpt2, Attention Implementation: eager, Output Consistency: False, Gradient Consistency: False
Model: mistralai/Mistral-7B-v0.1, Attention Implementation: flash_attention_2, Output Consistency: True, Gradient Consistency: True
Model: mistralai/Mistral-7B-v0.1, Attention Implementation: sdpa, Output Consistency: True, Gradient Consistency: True
Model: mistralai/Mistral-7B-v0.1, Attention Implementation: eager, Output Consistency: True, Gradient Consistency: True
Model: meta-llama/Meta-Llama-3-8B, Attention Implementation: flash_attention_2, Output Consistency: True, Gradient Consistency: True
Model: meta-llama/Meta-Llama-3-8B, Attention Implementation: sdpa, Output Consistency: True, Gradient Consistency: True
Model: meta-llama/Meta-Llama-3-8B, Attention Implementation: eager, Output Consistency: True, Gradient Consistency: True
So GPT-2 is non-deterministic under both Flash Attention 2 and the default (eager) attention implementation, while Mistral and Llama 3 demonstrate deterministic behavior regardless of which attention implementation is used.
I also checked the [Flash Attention 2 documentation](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#how-to-use-flashattention), and it seems that since Flash Attention v2.5.0 the forward pass should always be deterministic, while the backward pass can be made deterministic by passing `deterministic=True`, which we currently do not do for either [Mistral](https://github.com/huggingface/transformers/blob/048f599f3506e57e0a595b455d9d2834c8d45023/src/transformers/models/mistral/modeling_mistral.py#L481) or [Llama](https://github.com/huggingface/transformers/blob/048f599f3506e57e0a595b455d9d2834c8d45023/src/transformers/models/llama/modeling_llama.py#L517), as far as I can tell.
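To make that concrete, here is a minimal sketch of my own (not transformers code) that exercises the flag directly, assuming flash-attn >= 2.5.0 where `flash_attn_func` accepts `deterministic=True`:

# Check backward determinism of the raw kernel when deterministic=True is passed.
# Assumes flash-attn >= 2.5.0 and a CUDA GPU; tensors are (batch, seqlen, nheads, headdim).
import torch
from flash_attn import flash_attn_func

q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
k = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
v = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)

grads = []
for _ in range(2):
    q.grad = None
    out = flash_attn_func(q, k, v, causal=True, deterministic=True)
    out.sum().backward()
    grads.append(q.grad.clone())

print("Backward deterministic:", torch.equal(grads[0], grads[1]))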
So my suspicion is that either (1) BART and GPT-2 are FP32 models and we are running them in BF16, which could make numerical-stability issues more pronounced, or (2) the transformers attention implementation for certain models is problematic and can trigger unexpected behavior under some conditions.
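One way to probe suspicion (1) would be to repeat the GPT-2 check in full precision; the sketch below is a hypothetical variant I have not run (flash_attention_2 itself only supports FP16/BF16, so the FP32 probe applies to the eager path):

# Hypothetical variant of test_consistency that runs the forward pass in float32.
# If the inconsistency persists in full precision, suspicion (1) becomes less likely;
# if it disappears, BF16 numerics look like the culprit.
def test_consistency_fp32(model_name="gpt2", attn_implementation="eager", runs=10):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float32,
        attn_implementation=attn_implementation,
        device_map="cuda:0",
    )
    model_inputs = tokenizer(["My favourite condiment is"], return_tensors="pt").to("cuda:0")
    model.train()  # mirror the setup of the test above
    first_logits = None
    consistent = True
    for _ in range(runs):
        model.zero_grad()
        logits = model(**model_inputs).logits
        if first_logits is None:
            first_logits = logits
        elif not torch.equal(first_logits, logits):
            consistent = False
    print(f"Model: {model_name}, Attention: {attn_implementation}, dtype: fp32, Output Consistency: {consistent}")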
Let me know how I can help further, and thanks a lot!
System Info
transformers version: 4.41.2

Who can help?
@ArthurZucker @stevhliu
Information
Tasks
examples folder (such as GLUE/SQuAD, ...)

Reproduction
export CUBLAS_WORKSPACE_CONFIG=:4096:8
as required by the PyTorch Reproducibility Guide.

def test_consistency(attn_implementation="flash_attention_2"):
    # Load the model and tokenizer

if __name__ == "__main__":
    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)
    torch.use_deterministic_algorithms(True, warn_only=True)
    torch.backends.cudnn.benchmark = False
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
Using attention implementation flash_attention_2, Consistent: True
Using attention implementation sdpa, Consistent: True
Using attention implementation eager, Consistent: True

Using attention implementation flash_attention_2, Consistent: False
Using attention implementation sdpa, Consistent: True
Using attention implementation eager, Consistent: True
UserWarning: Memory Efficient attention defaults to a non-deterministic algorithm. To explicitly enable determinism call torch.use_deterministic_algorithms(True, warn_only=False). (Triggered internally at ../aten/src/ATen/native/transformers/cuda/attention_backward.cu:449.)
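For what it's worth, the memory-efficient kernel mentioned in this warning can be ruled in or out by pinning SDPA to specific backends; a minimal sketch, assuming PyTorch 2.x where `torch.backends.cuda.sdp_kernel` is available:

import torch
from torch.backends.cuda import sdp_kernel

# Disable the memory-efficient backend so only the flash and math kernels are used;
# rerunning the sdpa test inside this context shows whether that kernel is the source
# of the warning above.
with sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
    test_consistency(attn_implementation="sdpa")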