Closed: lewtun closed this issue 12 months ago
Note that Llama 2 works fine with ZeRO-2:
```diff
diff --git a/examples/scripts/sft_trainer.py b/examples/scripts/sft_trainer.py
index 4dac022..217f97b 100644
--- a/examples/scripts/sft_trainer.py
+++ b/examples/scripts/sft_trainer.py
@@ -91,6 +91,7 @@ model = AutoModelForCausalLM.from_pretrained(
     trust_remote_code=script_args.trust_remote_code,
     torch_dtype=torch_dtype,
     use_auth_token=script_args.use_auth_token,
+    use_flash_attention_2=True
 )

 # Step 2: Load the dataset
@@ -110,6 +111,7 @@ training_args = TrainingArguments(
     save_total_limit=script_args.save_total_limit,
     push_to_hub=script_args.push_to_hub,
     hub_model_id=script_args.hub_model_id,
+    bf16=True
 )
```
```shell
ACCELERATE_LOG_LEVEL=info TRANSFORMERS_VERBOSITY=info accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero2.yaml examples/scripts/sft_trainer.py --batch_size 1 --model_name meta-llama/Llama-2-7b-hf --gradient_accumulation_steps 1
```
Will look into it, thanks a lot for reporting @lewtun !!
Hi @lewtun, I managed to repro the issue; the fix was quite straightforward: https://github.com/huggingface/transformers/pull/26852 Regarding the first issue, note that FA-2 + attention bias is not supported through the Flash Attention interface. However, https://github.com/Dao-AILab/flash-attention/pull/540 should add ALiBi support for FA-2, so let's wait for that PR to be merged and then add FA-2 support for models that use the ALiBi attention bias.
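For context on what FA-2 would need to support: ALiBi adds a static, head-specific linear bias to the attention logits instead of rotary/positional embeddings. A minimal pure-Python sketch of that bias, using the slope formula from the ALiBi paper (assumes the number of heads is a power of two; this is for illustration only, not the transformers implementation):

```python
def alibi_slopes(num_heads):
    # Each head gets a slope from a geometric sequence starting at
    # 2 ** (-8 / num_heads), as in the ALiBi paper.
    start = 2 ** (-8.0 / num_heads)
    return [start ** (i + 1) for i in range(num_heads)]

def alibi_bias(num_heads, seq_len):
    # Bias added to the attention logits: for each head with slope m,
    # bias[i][j] = -m * (i - j) for key positions j <= query position i
    # (causal mask handles j > i; we leave those entries at 0 here).
    return [
        [[-m * (i - j) if j <= i else 0.0 for j in range(seq_len)]
         for i in range(seq_len)]
        for m in alibi_slopes(num_heads)
    ]
```

The bias depends only on the query/key distance, which is why the linked flash-attention PR can fold it into the kernel as a per-head slope rather than a full bias matrix.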
Great work, looking forward to the merge @younesbelkada .
Also: a) I guess ALiBi support isn't an issue for Falcon specifically, since it uses RoPE?
b) Is this issue related? https://github.com/huggingface/transformers/issues/26829
Now that Flash Attention 2 is natively supported in `transformers` for Llama / Falcon models, I tried to run the `sft_trainer.py` example and am running into various errors (reproduced below). I am initialising the models by adding the `use_flash_attention_2=True` flag in the `from_pretrained()` method as follows:

I have also tried both `packing=True` and `packing=False`, but the errors remain in both cases.

Errors and steps to reproduce provided below.
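Since packing made no difference either way, here is a simplified sketch of what packing does (concatenate tokenized examples and slice the stream into fixed-length blocks), to show it only changes how batches are built, not the attention path. The `eos_id` separator and the dropping of the trailing partial block are assumptions for illustration, not TRL's actual implementation:

```python
def pack(tokenized_examples, block_size, eos_id=0):
    # Concatenate all examples into one token stream, separated by EOS,
    # then cut the stream into fixed-length blocks; any trailing
    # partial block is dropped.
    stream = []
    for ids in tokenized_examples:
        stream.extend(ids)
        stream.append(eos_id)
    return [stream[i:i + block_size]
            for i in range(0, len(stream) - block_size + 1, block_size)]
```

For example, `pack([[1, 2, 3], [4, 5]], 3)` yields `[[1, 2, 3], [0, 4, 5]]`: example boundaries can fall mid-block, which is why both packed and unpacked runs hit the same attention-layer errors.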
`tiiuae/falcon-rw-1b`

Gives:

`tiiuae/falcon-7b`

Gives: