huggingface / accelerate

🚀 A simple way to launch, train, and use PyTorch models on almost any device and distributed configuration, automatic mixed precision (including fp8), and easy-to-configure FSDP and DeepSpeed support
https://huggingface.co/docs/accelerate
Apache License 2.0

Accelerate integration with Transformer Engine crashes when using FlashAttention #2154

Open · rationalism opened this issue 11 months ago

rationalism commented 11 months ago

System Info

Accelerate: 0.24.1
OS: Ubuntu 22.04
Python: 3.10
NumPy: 1.26.0
Torch: 2.1.0

Accelerate configuration:

compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  gradient_accumulation_steps: 4
  gradient_clipping: 1.0
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: true
  zero3_save_16bit_model: false
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: fp8
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
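
Note that mixed_precision: fp8 is what routes training through Accelerate's Transformer Engine integration. As a rough sketch of the equivalent programmatic setup, ignoring the DeepSpeed plumbing and assuming model and optimizer already exist:

    from accelerate import Accelerator

    # With fp8 mixed precision, accelerator.prepare() converts supported layers
    # (nn.Linear, nn.LayerNorm) to Transformer Engine modules under the hood.
    accelerator = Accelerator(mixed_precision="fp8", gradient_accumulation_steps=4)
    model, optimizer = accelerator.prepare(model, optimizer)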

Reproduction

Run the run_clm_no_trainer.py script with the provided config on a model that supports Flash Attention, and turn Flash Attention on with the flag use_flash_attention_2=True:

    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        from_tf=bool(".ckpt" in args.model_name_or_path),
        config=config,
        torch_dtype=torch.bfloat16,
        trust_remote_code=args.trust_remote_code,
        use_flash_attention_2=True
    )
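
As an aside, on newer transformers releases the use_flash_attention_2 flag is deprecated in favor of attn_implementation; a sketch of the equivalent call:

    # Equivalent spelling on recent transformers versions, where
    # use_flash_attention_2 is deprecated:
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        from_tf=bool(".ckpt" in args.model_name_or_path),
        config=config,
        torch_dtype=torch.bfloat16,
        trust_remote_code=args.trust_remote_code,
        attn_implementation="flash_attention_2",
    )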

E.g., here is the call I used:

    accelerate launch run_clm_no_trainer.py \
        --train_file blog_train.json \
        --validation_file blog_val.json \
        --model_name_or_path facebook/opt-1.3b \
        --per_device_train_batch_size 12 \
        --per_device_eval_batch_size 12 \
        --learning_rate 1e-5 \
        --num_train_epochs 1 \
        --gradient_accumulation_steps 4 \
        --output_dir blog_opt \
        --block_size 1024 \
        --checkpointing_steps 400000 \
        --trust_remote_code true \
        --report_to wandb \
        --eval_steps 100 \
        --with_tracking

Expected behavior

Training should run. Instead, it crashes with this error:

    Traceback (most recent call last):
      File "/home/alyssa/lm_fun/run_clm.py", line 769, in <module>
        main()
      File "/home/alyssa/lm_fun/run_clm.py", line 675, in main
        outputs = model(**batch)
      File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
        return forward_call(*args, **kwargs)
      File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
        ret_val = func(*args, **kwargs)
      File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1818, in forward
        loss = self.module(*inputs, **kwargs)
      File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
        result = forward_call(*args, **kwargs)
      File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/transformers/models/opt/modeling_opt.py", line 1129, in forward
        outputs = self.model.decoder(
      File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
        result = forward_call(*args, **kwargs)
      File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/transformers/models/opt/modeling_opt.py", line 885, in forward
        layer_outputs = self._gradient_checkpointing_func(
      File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/_compile.py", line 24, in inner
        return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
      File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
        return fn(*args, **kwargs)
      File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
        return fn(*args, **kwargs)
      File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 451, in checkpoint
        return CheckpointFunction.apply(function, preserve, *args)
      File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/autograd/function.py", line 539, in apply
        return super().apply(*args, **kwargs)  # type: ignore[misc]
      File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 230, in forward
        outputs = run_function(*args)
      File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
        result = forward_call(*args, **kwargs)
      File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/transformers/models/opt/modeling_opt.py", line 536, in forward
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
      File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
        result = forward_call(*args, **kwargs)
      File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/transformers/models/opt/modeling_opt.py", line 371, in forward
        attn_output = self._flash_attention_forward(
      File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/transformers/models/opt/modeling_opt.py", line 416, in _flash_attention_forward
        attn_output_unpad = flash_attn_varlen_func(
      File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 906, in flash_attn_varlen_func
        return FlashAttnVarlenFunc.apply(
      File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/autograd/function.py", line 539, in apply
        return super().apply(*args, **kwargs)  # type: ignore[misc]
      File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 496, in forward
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
      File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 79, in _flash_attn_varlen_forward
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
    RuntimeError: FlashAttention only support fp16 and bf16 data type
    [2023-11-14 18:06:43,085] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 559871 closing signal SIGTERM

muellerzr commented 11 months ago

The error states very clearly that you can only use FlashAttention with bf16 and fp16. This is not an Accelerate issue; you must use one of those dtypes for mixed precision.
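
If fp8 is not a hard requirement, the workaround implied here is to switch the config above to a supported dtype:

    # in the accelerate config, replace
    mixed_precision: fp8
    # with
    mixed_precision: bf16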

rationalism commented 11 months ago

@muellerzr Transformer Engine currently includes support for Flash Attention; see the documentation here:

https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/pytorch.html?highlight=flash#transformer_engine.pytorch.DotProductAttention

but the Accelerate Transformer Engine integration doesn't include TE's attention classes:

https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/transformer_engine.py
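
For reference, the linked file only swaps nn.Linear and nn.LayerNorm for their TE equivalents. A simplified sketch of that kind of recursive swap (not the actual implementation; see the linked file for the real logic):

    import torch
    import torch.nn as nn
    import transformer_engine.pytorch as te

    def swap_linear_layers(model: nn.Module) -> None:
        # Recursively replace nn.Linear with te.Linear so the matmuls
        # can run in fp8; attention modules are left untouched.
        for name, module in model.named_children():
            if isinstance(module, nn.Linear):
                te_linear = te.Linear(
                    module.in_features, module.out_features,
                    bias=module.bias is not None,
                )
                with torch.no_grad():
                    te_linear.weight.copy_(module.weight)
                    if module.bias is not None:
                        te_linear.bias.copy_(module.bias)
                setattr(model, name, te_linear)
            else:
                swap_linear_layers(module)

Wiring in TE's attention would additionally mean replacing the model's attention modules with te.DotProductAttention, whose interface looks roughly like this (constructor and forward arguments vary across TE versions, so treat this as a sketch):

    # TE's fused attention can dispatch to FlashAttention internally.
    # Default tensor layout is (seq, batch, heads, head_dim), fp16/bf16 only.
    attn = te.DotProductAttention(num_attention_heads=16, kv_channels=64)
    q = torch.randn(1024, 2, 16, 64, device="cuda", dtype=torch.bfloat16)
    out = attn(q, q, q)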

muellerzr commented 11 months ago

Will see if we can drop in a replacement, but with TE it’s very model-specific, and more often than not these layers are not drop-in compatible with how transformers models have been written, so very doubtful.

muellerzr commented 11 months ago

If you find it does work, feel free to open a PR :)