lm-sys / FastChat

An open platform for training, serving, and evaluating large language models. Release repo for Vicuna and Chatbot Arena.
Apache License 2.0

How to fine tune vicuna-7b with A40 #1296

Open yqh984638220 opened 1 year ago

yqh984638220 commented 1 year ago

How can I fine-tune vicuna-7b with an A40?

gabinguo commented 1 year ago
FlashAttention backward for head dim > 64 requires A100 or H100
GPUs as the implementation needs a large amount of shared memory.

This might be related. I got this error while running train_mem.py on A40s.

jhu10 commented 1 year ago

Same problem. I'm running train_mem.py with the following args:

torchrun --nproc_per_node=2 --master_port=20001 /data/ljn/Vicuna-13B/model/FastChat/fastchat/train/train_mem.py \
    --model_name_or_path /data/ljn/Vicuna-13B/model/FastChat/data/hf \
    --data_path /data/ljn/Vicuna-13B/model/FastChat/playground/data/leecode_new.json \
    --bf16 True \
    --output_dir /data/ljn/Vicuna-13B/model/FastChat/output \
    --num_train_epochs 3 \
    --per_device_train_batch_size 2 \
    --per_device_eval_batch_size 2 \
    --gradient_accumulation_steps 16 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 1200 \
    --save_total_limit 10 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --fsdp "full_shard auto_wrap" \
    --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
    --tf32 True \
    --model_max_length 2048 \
    --gradient_checkpointing True \
    --lazy_preprocess True

Error:

Loading checkpoint shards: 100%|██████████| 3/3 [03:19<00:00, 66.41s/it]
wandb: (1) Create a W&B account
wandb: (2) Use an existing W&B account
wandb: (3) Don't visualize my results
wandb: Enter your choice: 3
wandb: You chose "Don't visualize my results"
wandb: Tracking run with wandb version 0.15.3
wandb: W&B syncing is set to offline in this directory.
wandb: Run wandb online or set WANDB_MODE=online to enable cloud syncing.
  0%|          | 0/537 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/data/ljn/Vicuna-13B/model/FastChat/fastchat/train/train_mem.py", line 13, in <module>
    train()
  File "/data/ljn/Vicuna-13B/model/FastChat/fastchat/train/train.py", line 263, in train
    trainer.train()
  File "/data/ljn/Vicuna-13B/lib/python3.9/site-packages/transformers/trainer.py", line 1662, in train
    return inner_training_loop(
  File "/data/ljn/Vicuna-13B/lib/python3.9/site-packages/transformers/trainer.py", line 1927, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/data/ljn/Vicuna-13B/lib/python3.9/site-packages/transformers/trainer.py", line 2717, in training_step
    loss.backward()
  File "/data/ljn/Vicuna-13B/lib/python3.9/site-packages/torch/_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "/data/ljn/Vicuna-13B/lib/python3.9/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/data/ljn/Vicuna-13B/lib/python3.9/site-packages/torch/autograd/function.py", line 274, in apply
    return user_fn(self, *args)
  File "/data/ljn/Vicuna-13B/lib/python3.9/site-packages/torch/utils/checkpoint.py", line 157, in backward
    torch.autograd.backward(outputs_with_grad, args_with_grad)
  File "/data/ljn/Vicuna-13B/lib/python3.9/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/data/ljn/Vicuna-13B/lib/python3.9/site-packages/torch/autograd/function.py", line 274, in apply
    return user_fn(self, *args)
  File "/data/ljn/Vicuna-13B/lib/python3.9/site-packages/flash_attn/flash_attn_interface.py", line 75, in backward
    _flash_attn_backward(
  File "/data/ljn/Vicuna-13B/lib/python3.9/site-packages/flash_attn/flash_attn_interface.py", line 42, in _flash_attn_backward
    _, _, _, softmax_d = flash_attn_cuda.bwd(
RuntimeError: FlashAttention backward for head dim > 64 requires A100 or H100 GPUs as the implementation needs a large amount of shared memory.

I'm using 2 × A40 on my machine. Are there any suggestions to solve this?

Ted8000 commented 1 year ago

Maybe you can remove the flash_attn() call.

(screenshot)
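For context: LLaMA/Vicuna-7B uses a hidden size of 4096 split across 32 attention heads, so the head dim is 128, which is what trips the "head dim > 64 requires A100 or H100" check in the FlashAttention backward kernel. A minimal sketch of acting on this suggestion, assuming the repo layout from the torchrun command above and placeholder paths: launch fastchat/train/train.py instead of fastchat/train/train_mem.py, since train_mem.py is the wrapper that monkey-patches LLaMA attention with FlashAttention, and keep the rest of the flags unchanged.

# same launch as above, only the entry point changes (train.py does not apply the FlashAttention monkey patch);
# the /path/to/... values are placeholders -- reuse the full flag list from the command earlier in the thread
torchrun --nproc_per_node=2 --master_port=20001 fastchat/train/train.py \
    --model_name_or_path /path/to/vicuna-7b \
    --data_path /path/to/train.json \
    --bf16 True \
    --output_dir ./output \
    --model_max_length 2048 \
    --gradient_checkpointing True

Without FlashAttention the attention memory footprint grows, so per_device_train_batch_size or model_max_length may need to come down to fit on 48 GB cards.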
prateeky2806 commented 1 year ago

I am facing a similar issue. Has anyone figured out how to solve this?

Hzzhang-nlp commented 1 year ago

I seem to be running into the same problem. (screenshot)

Hzzhang-nlp commented 1 year ago

How can this be solved? (screenshot)

aduan commented 1 year ago

@Hzzhang-nlp If your CUDA version is 12.x, you can install the PyTorch nightly build for CUDA 12.1 and then build flash-attention from source:

pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121

git clone https://github.com/HazyResearch/flash-attention.git
cd flash-attention
python setup.py install
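Once the build finishes, a quick sanity check that the extension imports (a sketch; the exact version string depends on the commit you built):

python -c "import flash_attn; print(flash_attn.__version__)"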
Len-Li commented 1 year ago

It works for me, but I still get OOM when fine-tuning on a single A6000.

surak commented 1 year ago

We can't do much about this, I fear. If the GPU doesn't have enough memory, it doesn't have enough memory...