lm-sys / FastChat

An open platform for training, serving, and evaluating large language models. Release repo for Vicuna and Chatbot Arena.
Apache License 2.0

Training Vicuna-13B with 8xA100(40GB) #1294

Closed MoaazZaki closed 1 year ago

MoaazZaki commented 1 year ago

Hi!

Quick question: is it possible to train Vicuna-13B with 8xA100 (40 GB), i.e. 320 GB of GPU memory in total?

If yes, will it need any special setup different from this one?

torchrun --nproc_per_node=8 --master_port=20001 fastchat/train/train_mem.py \
    --model_name_or_path ~/model_weights/llama-13b  \
    --data_path ~/datasets/sharegpt_20230422_clean_lang_split_identity.json \
    --bf16 True \
    --output_dir output_13b \
    --num_train_epochs 3 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 32 \
    --gradient_accumulation_steps 4 \
    --evaluation_strategy "steps" \
    --eval_steps 1500 \
    --save_strategy "steps" \
    --save_steps 1500 \
    --save_total_limit 8 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.04 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --fsdp "full_shard auto_wrap offload" \
    --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
    --tf32 True \
    --model_max_length 2048 \
    --gradient_checkpointing True \
    --lazy_preprocess True
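
For rough intuition, here is my own back-of-envelope estimate (not from the FastChat docs) of why full fine-tuning of a 13B model only fits on 8x40GB when FSDP shards everything and offloads the optimizer state to CPU, which is what the --fsdp "full_shard auto_wrap offload" flag above requests:

    # Back-of-envelope training-state estimate for a 13B-parameter model
    # (approximate; activation memory is ignored because --gradient_checkpointing
    # removes most of it).
    params = 13e9

    weights_bf16 = params * 2        # bf16 weights             ~26 GB
    grads_bf16   = params * 2        # bf16 gradients           ~26 GB
    adam_fp32    = params * 4 * 3    # fp32 master + Adam m, v  ~156 GB

    total_gb = (weights_bf16 + grads_bf16 + adam_fp32) / 1e9
    print(f"~{total_gb:.0f} GB of training state")                      # ~208 GB
    print(f"~{total_gb / 8:.0f} GB per GPU when fully sharded over 8")  # ~26 GB
    # ~26 GB of state per 40 GB GPU before activations, so the CPU offload
    # and gradient checkpointing are what make it fit.
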
jonny64 commented 1 year ago

Hi, when I run training as per the docs:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /home/user/FastChat/fastchat/train/train_mem.py:4 in <module>                                    │
│                                                                                                  │
│    1 # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.          │
│    2                                                                                             │
│    3 # Need to call this before importing transformers.                                          │
│ ❱  4 from fastchat.train.llama_flash_attn_monkey_patch import (                                  │
│    5 │   replace_llama_attn_with_flash_attn,                                                     │
│    6 )                                                                                           │
│    7                                                                                             │
│                                                                                                  │
│ /home/user/.local/lib/python3.10/site-packages/fastchat/train/llama_flash_attn_monkey_patch.py:1 │
│ 1 in <module>                                                                                    │
│                                                                                                  │
│     8                                                                                            │
│     9 from einops import rearrange                                                               │
│    10                                                                                            │
│ ❱  11 from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func             │
│    12 from flash_attn.bert_padding import unpad_input, pad_input                                 │
│    13                                                                                            │
│    14                                                                                            │
│                                                                                                  │
│ /home/user/.local/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py:5 in <module>  │
│                                                                                                  │
│     2 import torch.nn as nn                                                                      │
│     3 import torch.nn.functional as F                                                            │
│     4                                                                                            │
│ ❱   5 import flash_attn_cuda                                                                     │
│     6                                                                                            │
│     7                                                                                            │
│     8 def _get_block_size(device, head_dim, is_dropout):                                         │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
ImportError: /home/user/.local/lib/python3.10/site-packages/flash_attn_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN3c104impl8GPUTrace13gpuTraceStateE
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 79544) of binary: /usr/bin/python3
Fatal Python error: Segmentation fault

I tried to reinstall/recompile flash_attn, same thing.

What pytorch and flash_attn versions are you using? Could you freeze them in requirements.txt?
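
For what it's worth, that undefined C10 symbol usually means the flash_attn_cuda extension was compiled against a different torch than the one imported at runtime (a reinstall that reuses a cached wheel hits the same error). A quick way to see what is actually in the environment (my own sketch, not an official diagnostic):

    # Print the versions that matter for the ABI mismatch above.
    import importlib.metadata as md
    import torch

    print("torch      :", torch.__version__, "(CUDA", torch.version.cuda, ")")
    for pkg in ("flash-attn", "triton", "transformers"):
        try:
            print(f"{pkg:<11}:", md.version(pkg))
        except md.PackageNotFoundError:
            print(f"{pkg:<11}: not installed")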

kswanjitsu commented 1 year ago

I am also getting this error when training with SkyPilot.

A-runaaaa commented 1 year ago

(image attached) I am also getting this error.

merrymercy commented 1 year ago

This script should work on 8xA100 (40GB). https://github.com/lm-sys/FastChat/blob/main/scripts/train_vicuna_13b.sh

tokestermw commented 1 year ago

As a quick fix, I think these versions work:

pip install flash-attn==1.0.3.post0 triton==2.0.0.dev20221202
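
A quick sanity check after installing those versions (my own sketch, not an official test) is to reproduce the failing import directly; if the extension still does not match the installed torch, this raises the same undefined-symbol ImportError as above:

    # Reproduce the exact imports that fail in the tracebacks above; if they
    # succeed, FastChat's llama_flash_attn_monkey_patch should import cleanly too.
    from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
    from flash_attn.bert_padding import pad_input, unpad_input
    print("flash-attn imports OK")
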
gary9630 commented 1 year ago

Hi @merrymercy,

I've also encountered a problem similar to @A-runaaaa's when fine-tuning the Vicuna-13B model.

The error happens when the model finishes training and tries to save the weights:

WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 75870 closing signal SIGTERM
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 75871 closing signal SIGTERM
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 75872 closing signal SIGTERM
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 75873 closing signal SIGTERM
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 75874 closing signal SIGTERM
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 75875 closing signal SIGTERM
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 75877 closing signal SIGTERM
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: -9) local_rank: 6 (pid: 75876) of binary:

....

torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
===========================================================
fastchat/train/train_mem.py FAILED
-----------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
-----------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2023-06-13_15:00:48
  host      : xxxxx
  rank      : 6 (local_rank: 6)
  exitcode  : -9 (pid: 75876)
  error_file: <N/A>
  traceback : Signal 9 (SIGKILL) received by PID 75876
===========================================================

After digging into it for a while, I found that when saving the model, memory consumption keeps growing until all the memory on the machine is exhausted and the process is killed.

(image: host memory usage climbing during model saving)

The question I want to ask is: how much memory is enough to save the fine-tuned Vicuna-13B model? I have 669 GB of memory in my case, but it turned out not to be enough. How much memory did you use to successfully fine-tune and save the Vicuna-13B model?

Again, thank you for all your hard work. I really appreciate that!
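
For anyone hitting the same thing, here is a minimal sketch of the usual mitigation (my own sketch, not necessarily what FastChat's trainer does): have FSDP gather the full state dict shard by shard, offload it to CPU, and materialize it on rank 0 only, so a single process holds one full copy in host RAM instead of every rank holding its own:

    # Sketch: save an FSDP-wrapped model without every rank building a full
    # copy of the weights in CPU RAM. `model`, `trainer`, and `output_dir`
    # are placeholders for whatever the training script has in scope.
    import torch.distributed as dist
    from torch.distributed.fsdp import (
        FullyShardedDataParallel as FSDP,
        StateDictType,
        FullStateDictConfig,
    )

    save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
        cpu_state_dict = model.state_dict()  # full dict is populated only on rank 0

    if dist.get_rank() == 0:
        # Hand the CPU state dict to the HF Trainer (or model.save_pretrained)
        # for the actual write to disk.
        trainer._save(output_dir, state_dict=cpu_state_dict)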

chengming1108 commented 1 year ago

I have the same question. How did you fix it? Thanks for the reply.

gary9630 commented 1 year ago

@chengming1108

I extended the memory from 669 GB to 1.31 TB and the model now saves successfully.

@merrymercy

Considering that the fine-tuned Vicuna-13B model is only around 45 GB on disk, the memory consumption during the saving phase is unreasonably high compared to the training and serving phases. Is there a reasonable explanation for this?
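
Not a definitive answer, but a rough accounting (my own numbers; the real picture depends on dtypes and on what extra copies the HF Trainer makes): if the full state dict gets gathered into host RAM on every one of the 8 ranks, on top of the optimizer state that FSDP has already offloaded to CPU, the bill grows far past the size of the saved checkpoint:

    # Rough host-RAM accounting during checkpointing (all numbers approximate).
    params = 13e9
    ranks  = 8

    full_copy = params * 4 / 1e9    # ~52 GB for one full fp32 copy of the weights
    adam_cpu  = params * 12 / 1e9   # ~156 GB of fp32 Adam state offloaded to CPU

    print(f"every rank gathers a copy: ~{ranks * full_copy + adam_cpu:.0f} GB")  # ~572 GB
    print(f"rank 0 only              : ~{full_copy + adam_cpu:.0f} GB")          # ~208 GB

That would make 669 GB marginal once any extra CPU copy made during the write is added, and would explain why gathering on rank 0 only with offload_to_cpu, or simply adding more RAM, gets past it.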