huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Transformer models are not deterministic when using Flash Attention 2 #31787

Open YunfanZhang42 opened 4 days ago

YunfanZhang42 commented 4 days ago

### System Info

### Who can help?

@ArthurZucker @stevhliu

### Information

### Tasks

### Reproduction

1. Install Flash Attention 2. I am using version 2.5.6, but this should not matter.
2. Run `export CUBLAS_WORKSPACE_CONFIG=:4096:8` as required by the PyTorch Reproducibility Guide.
3. Run the following script:

```python
import random
import torch
import numpy as np
from transformers import BartForConditionalGeneration, BartTokenizer


def test_consistency(attn_implementation="flash_attention_2"):
    # Load the model and tokenizer
    tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
    model = BartForConditionalGeneration.from_pretrained(
        "facebook/bart-base",
        torch_dtype=torch.bfloat16,
        attn_implementation=attn_implementation,
        device_map="cuda:0",
    )

    # Define the prompt
    prompt = "My favourite condiment is"

    # Tokenize the input and send to the appropriate device
    model_inputs = tokenizer([prompt], return_tensors="pt").to("cuda:0")

    # Disable gradient calculations
    with torch.no_grad():
        # Store the first output to compare with subsequent outputs
        first_output = None
        consistent = True

        # Perform inference 1000 times
        for i in range(1000):
            # Get the model's output (logits)
            outputs = model(**model_inputs)

            # Get logits of the last token from the output
            logits = outputs.logits[:, -1, :]

            # If it's the first run, store the output logits
            if first_output is None:
                first_output = logits
            else:
                # Compare current output with the first output
                if not torch.equal(first_output, logits):
                    consistent = False
                    break

        # Output whether all runs produced the same logits
        print(f"Using attention implementation {attn_implementation}, Consistent: {consistent}")


if __name__ == "__main__":
    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)
    torch.use_deterministic_algorithms(True, warn_only=True)
    torch.backends.cudnn.benchmark = False
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    test_consistency("flash_attention_2")
    test_consistency("sdpa")
    test_consistency("eager")
```

### Expected behavior

Since we have closely followed the PyTorch reproducibility guide, we expect to see:

```
Using attention implementation flash_attention_2, Consistent: True
Using attention implementation sdpa, Consistent: True
Using attention implementation eager, Consistent: True
```

In reality, we get:

```
Using attention implementation flash_attention_2, Consistent: False
Using attention implementation sdpa, Consistent: True
Using attention implementation eager, Consistent: True
```


Based on this experiment, it seems Flash Attention 2 is not deterministic in the forward pass, and according to https://github.com/Dao-AILab/flash-attention/issues/414, Flash Attention 2 is not deterministic in the backward pass either, so this also affects training.

It's worth noting that PyTorch's `sdpa` implementation may also select a non-deterministic execution path depending on the input dimensions, but it throws the following warning when `torch.use_deterministic_algorithms` is set to `True`, so it does not fail silently:

```
UserWarning: Memory Efficient attention defaults to a non-deterministic algorithm. To explicitly enable determinism call torch.use_deterministic_algorithms(True, warn_only=False). (Triggered internally at ../aten/src/ATen/native/transformers/cuda/attention_backward.cu:449.)
```
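
For context, this warning is raised from the backward pass of the memory-efficient SDPA kernel. A minimal, standalone sketch (independent of transformers; the shapes are arbitrary and a CUDA device is assumed) that should surface the same warning might look like this. `torch.backends.cuda.sdp_kernel` is used here to force the memory-efficient backend; newer PyTorch versions may prefer `torch.nn.attention.sdpa_kernel`:

```python
import torch
import torch.nn.functional as F

# warn_only=True emits the UserWarning above; warn_only=False raises a RuntimeError instead.
torch.use_deterministic_algorithms(True, warn_only=True)

q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
k = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)

# Force the memory-efficient backend so SDPA cannot silently pick another kernel.
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
    out = F.scaled_dot_product_attention(q, k, v)

# The determinism check lives in the backward kernel, so a backward pass is needed.
out.sum().backward()
```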



Therefore, I think we should throw a similar error when `attn_implementation` is set to `flash_attention_2` and `torch.use_deterministic_algorithms` is set to `True`. At the very least, this behavior should be documented in the Transformers docs.
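
For illustration, a guard along these lines could be added where the Flash Attention 2 code path is entered. This is only a sketch; the helper name and its placement are hypothetical and not existing transformers API:

```python
import warnings
import torch

def _warn_if_determinism_requested():
    # Hypothetical helper: warn when the user has requested deterministic algorithms
    # but the Flash Attention 2 path cannot guarantee bitwise-reproducible results.
    if torch.are_deterministic_algorithms_enabled():
        warnings.warn(
            "attn_implementation='flash_attention_2' does not guarantee deterministic "
            "results (notably in the backward pass). Use 'eager' or 'sdpa', or pass "
            "deterministic=True to the Flash Attention kernels, if reproducibility is required.",
            UserWarning,
        )
```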

Finally, thank you for your contributions to the deep learning and open-source community. Please let me know how I can contribute here.
Varma0604 commented 4 days ago

You're running a script to test the consistency of different attention implementations using PyTorch and Flash Attention 2. While the sdpa and eager implementations work as expected, flash_attention_2 is giving inconsistent results despite following the PyTorch reproducibility guidelines.

Here’s what you can do next:

1.  Double-check your settings: Ensure all flags and settings for determinism are correctly applied. Small misconfigurations can cause issues.
2.  Review documentation: Check the Flash Attention 2 documentation for any notes on determinism and if there are specific configurations needed.
3.  Update your tools: Make sure you’re using the latest versions of PyTorch and Flash Attention 2, as updates may resolve your issue.
4.  Community support: Post your issue on PyTorch forums or Flash Attention 2’s GitHub repository. The community might have encountered and solved similar issues.
5.  Contribute: If you find that Flash Attention 2 is inherently non-deterministic, consider suggesting a change or adding a warning in the documentation or code to help others.
LysandreJik commented 4 days ago

Is this AI-generated @Varma0604 ?


Thanks for your report @YunfanZhang42! @ArthurZucker will be able to help once he's back from his holiday (next week). Thank you for your patience!

YunfanZhang42 commented 3 days ago

@LysandreJik Thank you for your reply. No need to rush here as this is not an urgent issue. And yes, I think @Varma0604 is spam posting using LLMs.

@LysandreJik I did a few experiments on decoder-only models, covering both the forward and backward passes, and the results are interesting. Here are the steps to reproduce:

1. Install Flash Attention 2. I am using version 2.5.6.
2. Run `export CUBLAS_WORKSPACE_CONFIG=:4096:8` as required by the PyTorch Reproducibility Guide.
3. Run the following script:

```python
import random
import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer


def test_consistency(model_name="mistralai/Mistral-7B-v0.1", attn_implementation="flash_attention_2"):
    # Load the model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        attn_implementation=attn_implementation,
        device_map="cuda:0",
    )

    # Define the prompt
    prompt = "My favourite condiment is"

    # Tokenize the input and send to the appropriate device
    model_inputs = tokenizer([prompt], return_tensors="pt").to("cuda:0")

    # Perform multiple runs and check for consistency
    output_consistency = True
    gradient_consistency = True
    first_output = None
    first_gradients = None

    # Run the model 10 times
    for _ in range(10):
        # Ensure model is in training mode
        model.train()

        # Clear previous gradients
        model.zero_grad()

        # Get the model's output (logits)
        outputs = model(**model_inputs)

        # Use the sum of logits as a simple loss
        loss = outputs.logits.sum()

        # Compute gradients
        loss.backward()

        # Extract gradients and store them on CPU to avoid OOM
        current_gradients = {name: param.grad.cpu() for name, param in model.named_parameters() if param.grad is not None}

        # If it's the first run, store the output logits
        if first_output is None:
            first_output = outputs.logits
        else:
            # Compare current output with the first output
            if not torch.equal(first_output, outputs.logits):
                output_consistency = False

        # If it's the first run, store the gradients
        if first_gradients is None:
            first_gradients = current_gradients
        else:
            # Compare current gradients with the first gradients
            for name, grad in current_gradients.items():
                if not torch.equal(grad, first_gradients[name]):
                    gradient_consistency = False

    # Print the results
    print(f"Model: {model_name}, Attention Implementation: {attn_implementation}, Output Consistency: {output_consistency}, Gradient Consistency: {gradient_consistency}")


if __name__ == "__main__":
    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)
    torch.use_deterministic_algorithms(True, warn_only=True)
    torch.backends.cudnn.benchmark = False
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    test_consistency(model_name="gpt2", attn_implementation="flash_attention_2")
    # GPT2 does not support the sdpa implementation
    # test_consistency(model_name="gpt2", attn_implementation="sdpa")
    test_consistency(model_name="gpt2", attn_implementation="eager")

    test_consistency(model_name="mistralai/Mistral-7B-v0.1", attn_implementation="flash_attention_2")
    test_consistency(model_name="mistralai/Mistral-7B-v0.1", attn_implementation="sdpa")
    test_consistency(model_name="mistralai/Mistral-7B-v0.1", attn_implementation="eager")

    test_consistency(model_name="meta-llama/Meta-Llama-3-8B", attn_implementation="flash_attention_2")
    test_consistency(model_name="meta-llama/Meta-Llama-3-8B", attn_implementation="sdpa")
    test_consistency(model_name="meta-llama/Meta-Llama-3-8B", attn_implementation="eager")
```
What I get is:

```
Model: gpt2, Attention Implementation: flash_attention_2, Output Consistency: False, Gradient Consistency: False
Model: gpt2, Attention Implementation: eager, Output Consistency: False, Gradient Consistency: False
Model: mistralai/Mistral-7B-v0.1, Attention Implementation: flash_attention_2, Output Consistency: True, Gradient Consistency: True
Model: mistralai/Mistral-7B-v0.1, Attention Implementation: sdpa, Output Consistency: True, Gradient Consistency: True
Model: mistralai/Mistral-7B-v0.1, Attention Implementation: eager, Output Consistency: True, Gradient Consistency: True
Model: meta-llama/Meta-Llama-3-8B, Attention Implementation: flash_attention_2, Output Consistency: True, Gradient Consistency: True
Model: meta-llama/Meta-Llama-3-8B, Attention Implementation: sdpa, Output Consistency: True, Gradient Consistency: True
Model: meta-llama/Meta-Llama-3-8B, Attention Implementation: eager, Output Consistency: True, Gradient Consistency: True
```



So GPT-2 is not deterministic under either Flash Attention 2 or the default (eager) attention implementation, while Mistral and Llama 3 demonstrate deterministic behavior regardless of which attention implementation is used.

I also checked the [Flash Attention 2 documentation](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#how-to-use-flashattention), and it seems that since Flash Attention v2.5.0 the forward pass should always be deterministic, while the backward pass can be made deterministic by passing `deterministic=True`, which we currently do not do for either [Mistral](https://github.com/huggingface/transformers/blob/048f599f3506e57e0a595b455d9d2834c8d45023/src/transformers/models/mistral/modeling_mistral.py#L481) or [Llama](https://github.com/huggingface/transformers/blob/048f599f3506e57e0a595b455d9d2834c8d45023/src/transformers/models/llama/modeling_llama.py#L517), as far as I can tell.
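
For reference, here is roughly how that flag is passed when calling flash-attn directly. The shapes are illustrative (flash-attn expects `(batch, seqlen, nheads, headdim)` tensors in fp16/bf16 on CUDA), and transformers would presumably need to plumb an equivalent option through its Flash Attention call sites:

```python
import torch
from flash_attn import flash_attn_func

q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
k = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
v = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)

# deterministic=True (supported in the 2.5.x releases used here) selects the slower
# but deterministic backward kernel; per the flash-attn docs the forward pass is
# already deterministic.
out = flash_attn_func(q, k, v, causal=True, deterministic=True)
out.sum().backward()
```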

So my suspicions are that either (1) BART and GPT-2 are FP32 checkpoints, and running them in BF16 makes numerical stability issues more pronounced, or (2) transformers' attention implementation for certain models might be problematic and trigger unexpected behavior under some conditions.
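
One quick way to probe suspicion (1) would be to rerun the GPT-2 test with the checkpoint kept in its native float32 and check whether the eager-mode inconsistency disappears. A hypothetical float32 variant of the loading step from the script above (flash_attention_2 only supports fp16/bf16, so this isolates the eager case only):

```python
import torch
from transformers import AutoModelForCausalLM

# Hypothetical float32 variant of the loading step, to check whether the GPT-2
# non-determinism is tied to running an FP32 checkpoint in BF16.
model = AutoModelForCausalLM.from_pretrained(
    "gpt2",
    torch_dtype=torch.float32,
    attn_implementation="eager",
    device_map="cuda:0",
)
```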

Let me know how I can help further, and thanks a lot!