huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

Flash Attention 2 errors with Falcon models in `SFTTrainer` #832

Closed: lewtun closed this issue 12 months ago

lewtun commented 1 year ago

Now that Flash Attention 2 is natively supported in transformers for Llama / Falcon models, I tried to run the `sft_trainer.py` example and am running into various errors (reproduced below). I am initialising the models by passing the `use_flash_attention_2=True` flag to the `from_pretrained()` method as follows:

@@ -91,6 +93,7 @@ model = AutoModelForCausalLM.from_pretrained(
     trust_remote_code=script_args.trust_remote_code,
     torch_dtype=torch_dtype,
     use_auth_token=script_args.use_auth_token,
+    use_flash_attention_2=True
 )

I have also tried both `packing=True` and `packing=False`, but the errors remain in both cases.
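For reference, the same flag can be exercised outside the TRL script with a minimal snippet like the one below (a sketch: it assumes transformers >= 4.34 with FA-2 support and a working flash-attn 2.x install; the model name is just the first one from the repro):

```python
# Minimal sketch: load a causal LM with Flash Attention 2 enabled.
# Assumes transformers >= 4.34 and flash-attn 2.x; bf16 is used because
# FA-2 only supports fp16/bf16 inputs.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "tiiuae/falcon-rw-1b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    use_flash_attention_2=True,  # same flag as in the diff above
)
```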

Errors and steps to reproduce are provided below.

tiiuae/falcon-rw-1b

ACCELERATE_LOG_LEVEL=info TRANSFORMERS_VERBOSITY=info accelerate launch --config_file=examples/accelerate_configs/multi_gpu.yaml examples/scripts/sft_trainer.py --batch_size 1 --model_name tiiuae/falcon-rw-1b

Gives:

Traceback (most recent call last):
  File "/fsx/lewis/git/trl/examples/scripts/sft_trainer.py", line 139, in <module>
    trainer.train()
  File "/fsx/lewis/miniconda/envs/h4/lib/python3.10/site-packages/transformers/trainer.py", line 1591, in train
    outputs = block(
  File "/fsx/lewis/miniconda/envs/h4/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    trainer.train()
  File "/fsx/lewis/miniconda/envs/h4/lib/python3.10/site-packages/transformers/trainer.py", line 1591, in train
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/fsx/lewis/miniconda/envs/h4/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 14, in decorate_autocast
    return forward_call(*args, **kwargs)
  File "/fsx/lewis/miniconda/envs/h4/lib/python3.10/site-packages/transformers/models/falcon/modeling_falcon.py", line 794, in forward
    return inner_training_loop(
  File "/fsx/lewis/miniconda/envs/h4/lib/python3.10/site-packages/transformers/trainer.py", line 1892, in _inner_training_loop
    return func(*args, **kwargs)
  File "/fsx/lewis/miniconda/envs/h4/lib/python3.10/site-packages/transformers/models/falcon/modeling_falcon.py", line 1279, in forward
    return inner_training_loop(
  File "/fsx/lewis/miniconda/envs/h4/lib/python3.10/site-packages/transformers/trainer.py", line 1892, in _inner_training_loop
    attn_outputs = self.self_attention(
  File "/fsx/lewis/miniconda/envs/h4/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    tr_loss_step = self.training_step(model, inputs)
  File "/fsx/lewis/miniconda/envs/h4/lib/python3.10/site-packages/transformers/trainer.py", line 2776, in training_step
    transformer_outputs = self.transformer(
  File "/fsx/lewis/miniconda/envs/h4/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/fsx/lewis/miniconda/envs/h4/lib/python3.10/site-packages/transformers/models/falcon/modeling_falcon.py", line 607, in forward
    tr_loss_step = self.training_step(model, inputs)
  File "/fsx/lewis/miniconda/envs/h4/lib/python3.10/site-packages/transformers/trainer.py", line 2776, in training_step
    return forward_call(*args, **kwargs)
  File "/fsx/lewis/miniconda/envs/h4/lib/python3.10/site-packages/transformers/models/falcon/modeling_falcon.py", line 1163, in forward
    loss = self.compute_loss(model, inputs)
      File "/fsx/lewis/miniconda/envs/h4/lib/python3.10/site-packages/transformers/trainer.py", line 2801, in compute_loss
raise ValueError("`alibi` is not supported when `use_flash_attn` is True")
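A quick way to check whether a given Falcon checkpoint relies on alibi (and will therefore hit this error with FA-2) is to inspect its config; a small sketch, assuming the `alibi` field exposed by the native Falcon config:

```python
# Sketch: print whether each Falcon checkpoint uses alibi positional bias,
# which is what triggers the ValueError above when combined with FA-2.
from transformers import AutoConfig

for name in ["tiiuae/falcon-rw-1b", "tiiuae/falcon-7b"]:
    config = AutoConfig.from_pretrained(name)
    print(name, "alibi:", getattr(config, "alibi", False))
```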

tiiuae/falcon-7b

ACCELERATE_LOG_LEVEL=info TRANSFORMERS_VERBOSITY=info accelerate launch --config_file=examples/accelerate_configs/multi_gpu.yaml examples/scripts/sft_trainer.py --batch_size 1 --model_name tiiuae/falcon-7b

Gives:

  File "/fsx/lewis/miniconda/envs/h4/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 47, in _flash_attn_forward
      File "/fsx/lewis/miniconda/envs/h4/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 47, in _flash_attn_forward
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
  File "/fsx/lewis/miniconda/envs/h4/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 47, in _flash_attn_forward
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
    TypeErrorout, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(        : 
TypeErrorout, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(fwd(): incompatible function arguments. The following argument types are supported:
    1. (arg0: torch.Tensor, arg1: torch.Tensor, arg2: torch.Tensor, arg3: Optional[torch.Tensor], arg4: float, arg5: float, arg6: bool, arg7: int, arg8: int, arg9: bool, arg10: Optional[torch.Generator]) -> List[torch.Tensor]
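When reporting these FA-2 failures it also helps to capture the exact package versions in play; a small sketch (assuming flash-attn is importable as `flash_attn`):

```python
# Sketch: print the versions relevant to Flash Attention 2 debugging.
import torch
import transformers
import flash_attn

print("torch:", torch.__version__, "| cuda:", torch.version.cuda)
print("transformers:", transformers.__version__)
print("flash-attn:", flash_attn.__version__)
```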
lewtun commented 1 year ago

Note that Llama 2 works fine with ZeRO-2:

diff --git a/examples/scripts/sft_trainer.py b/examples/scripts/sft_trainer.py
index 4dac022..217f97b 100644
--- a/examples/scripts/sft_trainer.py
+++ b/examples/scripts/sft_trainer.py
@@ -91,6 +91,7 @@ model = AutoModelForCausalLM.from_pretrained(
     trust_remote_code=script_args.trust_remote_code,
     torch_dtype=torch_dtype,
     use_auth_token=script_args.use_auth_token,
+    use_flash_attention_2=True
 )

 # Step 2: Load the dataset
@@ -110,6 +111,7 @@ training_args = TrainingArguments(
     save_total_limit=script_args.save_total_limit,
     push_to_hub=script_args.push_to_hub,
     hub_model_id=script_args.hub_model_id,
+    bf16=True
 )
ACCELERATE_LOG_LEVEL=info TRANSFORMERS_VERBOSITY=info accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero2.yaml examples/scripts/sft_trainer.py --batch_size 1 --model_name meta-llama/Llama-2-7b-hf --gradient_accumulation_steps 1
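The same working combination, written out as plain code rather than a diff and without the accelerate/DeepSpeed launcher (a sketch: the dataset, output directory, and sequence length are placeholders, and it assumes trl ~0.7 with transformers >= 4.34):

```python
# Sketch of the Llama-2 + FA-2 + bf16 configuration from the diff above.
# Dataset, output_dir and max_seq_length are placeholders.
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, TrainingArguments
from trl import SFTTrainer

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.bfloat16,
    use_flash_attention_2=True,
)
dataset = load_dataset("imdb", split="train")  # placeholder text dataset
training_args = TrainingArguments(
    output_dir="./sft-llama2-fa2",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    bf16=True,  # FA-2 needs half precision, matching the diff above
)
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=512,
)
trainer.train()
```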
younesbelkada commented 1 year ago

Will look into it, thanks a lot for reporting @lewtun !!

younesbelkada commented 12 months ago

Hi @lewtun, I managed to repro the issue. The fix for the second error (the `TypeError` with tiiuae/falcon-7b) was quite straightforward: https://github.com/huggingface/transformers/pull/26852. Regarding the first issue, note that FA-2 + attention bias is not supported through the Flash Attention interface. However, https://github.com/Dao-AILab/flash-attention/pull/540 should add alibi support for FA-2, so let's wait for that PR to be merged and then add FA-2 support for models that use the alibi attention bias.

RonanKMcGovern commented 12 months ago

Great work, looking forward to the merge, @younesbelkada.

Also: a) I guess missing alibi support isn't an issue for Falcon specifically, since it uses RoPE?

b) Is this issue related? https://github.com/huggingface/transformers/issues/26829