GanjinZero / RRHF

[NIPS2023] RRHF & Wombat

RRHF only works on LLaMA models. #8

Closed: Taekyoon closed this issue 1 year ago

Taekyoon commented 1 year ago

I tried to train GPT-2 and GPT-J models, but training failed because of OOM errors. Even when I tried a 1.3B GPT-2 model, it failed for the same reason.

For the GPT-NeoX model, training fails with this error:

RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got CUDABFloat16Type instead (while checking arguments for embedding)

Something is going wrong with the tensor type at the embedding layer, which didn't happen with LLaMA, GPT-2, or GPT-J.
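
For reference, the failure itself is easy to reproduce in isolation: nn.Embedding only accepts integer indices, so anything upstream that casts input_ids to bf16 triggers this same error. A minimal standalone snippet (not from the RRHF code) that shows the mechanism:

    # Minimal reproduction outside RRHF: nn.Embedding requires Long/Int indices,
    # so an upstream cast of input_ids to bf16 produces this same RuntimeError.
    import torch
    import torch.nn as nn

    emb = nn.Embedding(10, 4)
    ids = torch.tensor([[1, 2, 3]])
    emb(ids)                      # fine: indices are torch.long
    emb(ids.to(torch.bfloat16))   # RuntimeError: Expected tensor for argument #1 'indices' ...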

Do you have any ideas to solve these problems?

GanjinZero commented 1 year ago

What is your batch size for training (query count and reply count)? We suggest using a query batch size of 1. How did you modify the code when training GPT-NeoX?

Taekyoon commented 1 year ago

I used the same parameters as in the paper (batch size 1 per device, 8 gradient accumulation steps), and I didn't modify any code in this project. I could reduce the number of gradient accumulation steps, but that doesn't make sense, because in the GPT-2 case the parameter count (1.3B) was much smaller than LLaMA's.

GanjinZero commented 1 year ago

What is your GPU device?

Taekyoon commented 1 year ago

I use the same spec as mentioned in the paper (8x A100-SXM 80GB).

GanjinZero commented 1 year ago

I will try to reproduce your bug, and give you feedback later.

GanjinZero commented 1 year ago

I hit the error "Exception: Could not find the transformer layer class to wrap in the model." while using gpt2-xl. I think you need to modify the fsdp_transformer_layer_cls_to_wrap setting in the script. I will check how to fix this later.

Yuanhy1997 commented 1 year ago

For GPT-2, it would be "GPT2Block". For other models, you can pass the name of the layer class to the fsdp_transformer_layer_cls_to_wrap arg. You can find the names in the corresponding modeling_xxx.py in the huggingface/transformers repo.
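
If you are unsure which name to pass, one quick way to list the candidate module class names is a small standalone snippet like the sketch below (it only uses the public transformers AutoModelForCausalLM API; it is not part of this repo):

    # Print the module class names in a checkpoint; the decoder block name is what
    # --fsdp_transformer_layer_cls_to_wrap expects (e.g. GPT2Block, GPTNeoXLayer).
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("gpt2")  # any checkpoint works here
    print(sorted({type(m).__name__ for m in model.modules()}))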

Taekyoon commented 1 year ago

I've already tried adding "GPT2Block" like this below :)

    --fsdp_transformer_layer_cls_to_wrap 'GPT2Block' # same thing for 'GPTNeoXLayer' \

I also tried these two model architectures with the Alpaca repo in the same environment, and they worked. Maybe there is an issue with the RRHFTrainer class. Can you look into this implementation? I think this trainer class is the only difference from the Alpaca project.

GanjinZero commented 1 year ago

This is my bash script:

export MODEL_PATH='gpt2-xl'
export SAVE_PATH='./outputs/debuggptj'
export MASTER_ADDR="localhost"
export MASTER_PORT="22"
export GLOO_SOCKET_IFNAME="lo"
export NCCL_SOCKET_IFNAME="lo"
export WANDB_DISABLED=true
wandb offline

cd ./RRHF-main
python3 -m torch.distributed.launch --master_addr ${MASTER_ADDR} --master_port ${MASTER_PORT} --nproc_per_node=8 --use_env train.py \
    --model_name_or_path $MODEL_PATH \
    --data_path ./chatgpt_train.json \
    --bf16 True \
    --output_dir $SAVE_PATH \
    --num_train_epochs 3 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 8 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 100 \
    --save_total_limit 40 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --fsdp "full_shard auto_wrap" \
    --fsdp_transformer_layer_cls_to_wrap 'GPT2Block' \
    --tf32 True --model_max_length 192 --rrhf_weight 1

{'loss': 3.9652, 'learning_rate': 3.0769230769230774e-07, 'epoch': 0.0}
{'loss': 4.039, 'learning_rate': 6.153846153846155e-07, 'epoch': 0.0}
{'loss': 4.8444, 'learning_rate': 9.230769230769232e-07, 'epoch': 0.0}
{'loss': 5.0971, 'learning_rate': 1.230769230769231e-06, 'epoch': 0.01}

Taekyoon commented 1 year ago

I tried the same script, minus some of the environment variables.

export MODEL_PATH=$1
export SAVE_PATH=$2
export DATA_PATH=$3
export MASTER_ADDR="localhost"
export MASTER_PORT="32123"

torchrun --master_port ${MASTER_PORT} --nproc_per_node=8 train_alpaca_prompt.py \
    --model_name_or_path $MODEL_PATH \
    --data_path $DATA_PATH \
    --bf16 True \
    --output_dir $SAVE_PATH \
    --num_train_epochs 3 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 8 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 100 \
    --save_total_limit 40 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --fsdp "full_shard auto_wrap" \
    --fsdp_transformer_layer_cls_to_wrap 'GPTNeoXLayer' \
    --tf32 True --model_max_length 512 --rrhf_weight 1

At the beginning it worked, as you showed, but after a few steps it crashed.

Yuanhy1997 commented 1 year ago

Can you please attach more details of the error?

GanjinZero commented 1 year ago

For gpt2-xl, with model_max_length=192, we use 20GB per GPU. With model_max_length=512, we use 50GB per GPU.

Taekyoon commented 1 year ago

Please give me 2-3 hours; I'll provide more details.

GanjinZero commented 1 year ago

For GPT-NeoX, you may need to use a smaller model_max_length, since our effective batch size equals per_device_train_batch_size * query_count, which can be larger than Alpaca's batch size of 4.
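
As a rough illustration (the query_count value below is an assumed example, not a repo default):

    # Back-of-the-envelope batch arithmetic: RRHF forwards all ranked responses of a
    # query at once, so memory scales with query_count, unlike Alpaca's fixed batch of 4.
    per_device_train_batch_size = 1
    query_count = 6                   # assumed number of ranked responses per query
    model_max_length = 512

    rrhf_tokens = per_device_train_batch_size * query_count * model_max_length
    alpaca_tokens = 4 * model_max_length   # Alpaca's per-device batch size of 4
    print(rrhf_tokens, alpaca_tokens)      # 3072 vs. 2048 tokens per forward pass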

Taekyoon commented 1 year ago

To give further information, I've tried to train these three models.

Taekyoon commented 1 year ago

I think the GPT-2 model is okay! It turned out to be a sequence-length issue. However, GPT-NeoX still has a problem even when I reduce model_max_length. The error message is shown below:

Traceback (most recent call last):
  File "/home/irteam/work/projects/news_gpt_research/RRHF/train_alpaca_prompt.py", line 332, in <module>
    train()
  File "/home/irteam/work/projects/news_gpt_research/RRHF/train_alpaca_prompt.py", line 326, in train
    trainer.train()
  File "/home/irteam/work/projects/news_gpt_research/RRHF/transformers/src/transformers/trainer.py", line 1662, in train
    return inner_training_loop(
  File "/home/irteam/work/projects/news_gpt_research/RRHF/transformers/src/transformers/trainer.py", line 1927, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/home/irteam/work/projects/news_gpt_research/RRHF/transformers/src/transformers/trainer.py", line 2699, in training_step
    loss = self.compute_loss(model, inputs)
  File "/home/irteam/work/projects/news_gpt_research/RRHF/train_alpaca_prompt.py", line 282, in compute_loss
    logits = model(input_ids=inputs.get('input_ids'), attention_mask=inputs.get('attention_mask'))[0] # (batch * cand) * L * V
  File "/home/irteam/.conda/envs/koalpaca_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/irteam/.conda/envs/koalpaca_env/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 748, in forward
    output = self._fsdp_wrapped_module(*args, **kwargs)
  File "/home/irteam/.conda/envs/koalpaca_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/irteam/work/projects/news_gpt_research/RRHF/transformers/src/transformers/models/gpt_neox/modeling_gpt_neox.py", line 662, in forward
    outputs = self.gpt_neox(
  File "/home/irteam/.conda/envs/koalpaca_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/irteam/work/projects/news_gpt_research/RRHF/transformers/src/transformers/models/gpt_neox/modeling_gpt_neox.py", line 518, in forward
    inputs_embeds = self.embed_in(input_ids)
  File "/home/irteam/.conda/envs/koalpaca_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/irteam/.conda/envs/koalpaca_env/lib/python3.10/site-packages/torch/nn/modules/sparse.py", line 162, in forward
    return F.embedding(
  File "/home/irteam/.conda/envs/koalpaca_env/lib/python3.10/site-packages/torch/nn/functional.py", line 2210, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got CUDABFloat16Type instead (while checking arguments for embedding)

(The same traceback is printed by the other ranks as well.)

I think this could be an issue with the GPT-NeoX model layers, so I will look into it. I'll reopen this issue if something comes up.
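
In case it helps anyone who hits this later, one defensive tweak worth trying is to force the indices back to torch.long inside compute_loss right before the forward call. This is only a sketch under the assumption that the bf16 cast happens somewhere between the data collator and the model call; if FSDP itself recasts the inputs inside its own forward, a different fix would be needed. The inputs and model variables come from compute_loss:

    # Hypothetical guard (not from the repo): keep token indices integer-typed before
    # the forward pass, since nn.Embedding rejects bf16 indices.
    import torch

    input_ids = inputs.get('input_ids')
    attention_mask = inputs.get('attention_mask')
    if input_ids is not None and input_ids.dtype not in (torch.long, torch.int):
        input_ids = input_ids.long()
    logits = model(input_ids=input_ids, attention_mask=attention_mask)[0]  # (batch * cand) * L * V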

Thank you for your help! :)