yqh984638220 opened this issue 1 year ago
FlashAttention backward for head dim > 64 requires A100 or H100
GPUs as the implementation needs a large amount of shared memory.
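For context on why this error shows up with LLaMA/Vicuna at all: the check is on the per-head dimension, which for these models is 128. Below is a minimal sketch, assuming a Hugging Face config is available for the checkpoint being trained (the model id is only illustrative; any local path to the weights works the same way).

# Why the check triggers: head_dim = hidden_size / num_attention_heads.
# For LLaMA/Vicuna-13B this is 5120 / 40 = 128 (> 64), so the FlashAttention
# backward kernel needs the larger shared memory only available on A100/H100.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("lmsys/vicuna-13b-v1.3")  # illustrative model id
head_dim = cfg.hidden_size // cfg.num_attention_heads
print(head_dim)  # 128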
This might be related. I got this error while using train_mem.py with A40s.
Same problem. I'm running train_mem.py with the following args:

torchrun --nproc_per_node=2 --master_port=20001 /data/ljn/Vicuna-13B/model/FastChat/fastchat/train/train_mem.py \
    --model_name_or_path /data/ljn/Vicuna-13B/model/FastChat/data/hf \
    --data_path /data/ljn/Vicuna-13B/model/FastChat/playground/data/leecode_new.json \
    --bf16 True \
    --output_dir /data/ljn/Vicuna-13B/model/FastChat/output \
    --num_train_epochs 3 \
    --per_device_train_batch_size 2 \
    --per_device_eval_batch_size 2 \
    --gradient_accumulation_steps 16 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 1200 \
    --save_total_limit 10 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --fsdp "full_shard auto_wrap" \
    --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
    --tf32 True \
    --model_max_length 2048 \
    --gradient_checkpointing True \
    --lazy_preprocess True
Error:
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████| 3/3 [03:19<00:00, 66.41s/it]
wandb: (1) Create a W&B account
wandb: (2) Use an existing W&B account
wandb: (3) Don't visualize my results
wandb: Enter your choice: 3
wandb: You chose "Don't visualize my results"
wandb: Tracking run with wandb version 0.15.3
wandb: W&B syncing is set to offline in this directory.
wandb: Run wandb online or set WANDB_MODE=online to enable cloud syncing.
0%| | 0/537 [00:00<?, ?it/s]Traceback (most recent call last):
File "/data/ljn/Vicuna-13B/model/FastChat/fastchat/train/train_mem.py", line 13, in
I used 2 × A40 GPUs on my machine. Are there any suggestions to solve this?
Maybe you can remove flash_attn().
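A minimal sketch of what that removal could look like, assuming FastChat's train_mem.py applies the flash-attention monkey patch roughly as below (the import path and replace_llama_attn_with_flash_attn name are taken from the FastChat repo and may differ across versions):

# train_mem.py (simplified sketch; module names assumed from the FastChat repo).
# The monkey patch must be applied before transformers is imported; skipping it
# avoids the FlashAttention backward kernel entirely, at the cost of higher memory use.

# from fastchat.train.llama_flash_attn_monkey_patch import (
#     replace_llama_attn_with_flash_attn,
# )
# replace_llama_attn_with_flash_attn()   # <- comment this out on A40/A6000-class GPUs

from fastchat.train.train import train

if __name__ == "__main__":
    train()

Equivalently, launching fastchat/train/train.py instead of train_mem.py should skip the patch without editing any code.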
I am facing a similar issue. Has anyone figured out how to solve this?
I seem to be running into the same problem.
How can this be solved?
@Hzzhang-nlp If your CUDA is 12.x, you can install the PyTorch nightly build for CUDA 12.1 and build flash-attention from source:
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
git clone https://github.com/HazyResearch/flash-attention.git
cd flash-attention
python setup.py install
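If it helps, a quick post-install sanity check might look like the following (a sketch; the attribute names are assumed to exist in recent torch and flash-attn releases):

# Verify the nightly torch build targets CUDA 12.1 and that flash-attn imports cleanly.
import torch
import flash_attn

print(torch.__version__, torch.version.cuda)   # expect a 2.x dev build and "12.1"
print(flash_attn.__version__)
print(torch.cuda.get_device_capability())      # A100 is (8, 0), H100 is (9, 0)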
It works for me, but I still hit OOM when fine-tuning on a single A6000.
We can't do much about this, I'm afraid. If the GPU doesn't have enough memory, it doesn't have enough memory...
How can I fine-tune vicuna-7b on an A40?