Open: rationalism opened this issue 11 months ago
The error states very clearly that you can only use Flash Attention with bf16 and fp16. This is not an Accelerate issue; you must use one of those for mixed precision.
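For reference, a minimal sketch of requesting mixed precision with Accelerate (the "bf16" value can be swapped for "fp16"; the script name is just the one from the reproduction below):

from accelerate import Accelerator

# Request bf16 autocast for the forward/backward pass so the FlashAttention
# kernels receive half-precision tensors instead of fp32.
accelerator = Accelerator(mixed_precision="bf16")

# Or, roughly equivalently, at launch time:
#   accelerate launch --mixed_precision bf16 run_clm_no_trainer.py ...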
@muellerzr Transformer Engine currently includes support for Flash Attention (see the TE documentation), but the Accelerate Transformer Engine integration doesn't include TE's attention classes:
https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/transformer_engine.py
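For context, that helper only swaps the basic layers; a simplified sketch of the idea (not the exact Accelerate implementation) looks like this:

import torch.nn as nn
import transformer_engine.pytorch as te

def convert_to_te(model: nn.Module) -> None:
    # Recursively replace nn.Linear / nn.LayerNorm with their TE counterparts.
    # Attention modules (e.g. te.DotProductAttention, which is where TE's
    # FlashAttention support lives) are never touched by this conversion.
    for name, child in model.named_children():
        if isinstance(child, nn.Linear):
            te_linear = te.Linear(child.in_features, child.out_features, bias=child.bias is not None)
            te_linear.weight.data.copy_(child.weight.data)
            if child.bias is not None:
                te_linear.bias.data.copy_(child.bias.data)
            setattr(model, name, te_linear)
        elif isinstance(child, nn.LayerNorm):
            te_ln = te.LayerNorm(child.normalized_shape[0], eps=child.eps)
            te_ln.weight.data.copy_(child.weight.data)
            te_ln.bias.data.copy_(child.bias.data)
            setattr(model, name, te_ln)
        else:
            convert_to_te(child)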
We'll see if we can drop in a replacement, but with TE this is very model-specific, and more often than not these layers are not drop-in replacements given how transformers models are written, so I'm doubtful.
If you find it does work, feel free to open a PR :)
System Info
Information
Tasks
no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
Reproduction
Run the run_clm_no_trainer.py script with the provided config on a model that supports Flash Attention, and turn Flash Attention on with the flag use_flash_attention_2=True. E.g. here is a call I used:
accelerate launch run_clm_no_trainer.py --train_file blog_train.json --validation_file blog_val.json --model_name_or_path facebook/opt-1.3b --per_device_train_batch_size 12 --per_device_eval_batch_size 12 --learning_rate 1e-5 --num_train_epochs 1 --gradient_accumulation_steps 4 --output_dir blog_opt --block_size 1024 --checkpointing_steps 400000 --trust_remote_code true --report_to wandb --eval_steps 100 --with_tracking
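For completeness: use_flash_attention_2=True corresponds to a from_pretrained call roughly like the sketch below (model name taken from the command above). Unless the model is loaded in half precision or trained with bf16/fp16 mixed precision, FlashAttention receives fp32 tensors and fails as shown under "Expected behavior".

import torch
from transformers import AutoModelForCausalLM

# Illustrative sketch: enabling FlashAttention-2 when loading OPT. Loading in
# bf16 (or enabling bf16/fp16 mixed precision at training time) is what keeps
# the flash-attn kernels from seeing fp32 tensors.
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-1.3b",
    use_flash_attention_2=True,
    torch_dtype=torch.bfloat16,
)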
Expected behavior
Training should run. Instead, it crashes with this error:
Traceback (most recent call last):
  File "/home/alyssa/lm_fun/run_clm.py", line 769, in <module>
    main()
  File "/home/alyssa/lm_fun/run_clm.py", line 675, in main
    outputs = model(**batch)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1818, in forward
    loss = self.module(*inputs, **kwargs)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/transformers/models/opt/modeling_opt.py", line 1129, in forward
    outputs = self.model.decoder(
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/transformers/models/opt/modeling_opt.py", line 885, in forward
    layer_outputs = self._gradient_checkpointing_func(
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/_compile.py", line 24, in inner
    return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
    return fn(*args, **kwargs)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 451, in checkpoint
    return CheckpointFunction.apply(function, preserve, *args)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/autograd/function.py", line 539, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 230, in forward
    outputs = run_function(*args)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/transformers/models/opt/modeling_opt.py", line 536, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/transformers/models/opt/modeling_opt.py", line 371, in forward
    attn_output = self._flash_attention_forward(
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/transformers/models/opt/modeling_opt.py", line 416, in _flash_attention_forward
    attn_output_unpad = flash_attn_varlen_func(
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 906, in flash_attn_varlen_func
    return FlashAttnVarlenFunc.apply(
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/autograd/function.py", line 539, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 496, in forward
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 79, in _flash_attn_varlen_forward
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
RuntimeError: FlashAttention only support fp16 and bf16 data type
[2023-11-14 18:06:43,085] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 559871 closing signal SIGTERM
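To make the dtype requirement concrete, here is a minimal, self-contained sketch (it assumes an Ampere-or-newer CUDA GPU and the flash-attn package; flash_attn_func is the simpler fixed-length cousin of the varlen call in the traceback):

import torch
from flash_attn import flash_attn_func

# (batch, seqlen, num_heads, head_dim) query/key/value tensors on the GPU.
q = torch.randn(1, 128, 16, 64, device="cuda", dtype=torch.float32)
k, v = torch.randn_like(q), torch.randn_like(q)

try:
    flash_attn_func(q, k, v, causal=True)  # fp32 inputs -> same RuntimeError as above
except RuntimeError as err:
    print(err)

out = flash_attn_func(q.bfloat16(), k.bfloat16(), v.bfloat16(), causal=True)  # bf16 works
print(out.dtype)  # torch.bfloat16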