mindspore-lab / mindrlhf

Apache License 2.0

GPU memory is not enough when training LLaMA2-7B PPO #61

Closed dhcode-cpp closed 7 months ago

dhcode-cpp commented 8 months ago

PPO training on 8 x Ascend 910A 32GB runs out of device memory, so I switched to the two-stage approach.

However, data generation fails. Is there a working example of LLaMA2-7B PPO on 8 x 910A 32GB?

Modified scripts/run_distribute_two_stages.sh:

ROOT_PATH=`pwd`
DATA_DIR=$1
export RANK_TABLE_FILE=$2
RANK_START=$3
LOCAL_DEVICE_NUM=$4
SFT_MODEL_PATH=$5
REWARD_MODEL_PATH=$6
SAVE_DATA_PATH=$7

for((i=0;i<${LOCAL_DEVICE_NUM};i++));
do
    rm ${ROOT_PATH}/device$i/ -rf
    mkdir ${ROOT_PATH}/device$i
    cd ${ROOT_PATH}/device$i || exit
    export RANK_ID=$((i+RANK_START))
    export DEVICE_ID=$i
    python3 ${ROOT_PATH}/make_experience.py   --dataset_dir ${DATA_DIR} --sft_model_path ${SFT_MODEL_PATH} \
  --critic_model_path ${REWARD_MODEL_PATH} --reward_model_path ${REWARD_MODEL_PATH} --save_data_file ${SAVE_DATA_PATH} &
done
wait

for((i=0;i<${LOCAL_DEVICE_NUM};i++));
do
    rm ${ROOT_PATH}/device$i/ -rf
    mkdir ${ROOT_PATH}/device$i
    cd ${ROOT_PATH}/device$i || exit
    export RANK_ID=$((i+RANK_START))
    export DEVICE_ID=$i
    python3 ${ROOT_PATH}/ppo_train.py --dataset_dir ${SAVE_DATA_PATH} --sft_model_path ${SFT_MODEL_PATH} \
  --critic_model_path ${REWARD_MODEL_PATH} --reward_model_path ${REWARD_MODEL_PATH}  &
done
wait
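As an aside, the two per-device loops differ only in the Python entry point and its data argument, so they could share one helper. A minimal sketch (placeholder paths; `echo` stands in for exporting RANK_ID/DEVICE_ID and backgrounding `python3`, so the launch lines can be inspected):

```shell
#!/bin/sh
# Sketch only: one launcher used for both stages.
launch_stage() {
    entry=$1    # e.g. make_experience.py or ppo_train.py
    data=$2     # input data path for this stage
    i=0
    while [ "$i" -lt "$LOCAL_DEVICE_NUM" ]; do
        # Real script: export RANK_ID/DEVICE_ID and launch python3 in the background.
        echo "rank=$((i + RANK_START)) device=$i $entry $data"
        i=$((i + 1))
    done
}

RANK_START=0        # placeholder values
LOCAL_DEVICE_NUM=2
launch_stage make_experience.py /tmp/in.mindrecord
launch_stage ppo_train.py /tmp/ppo.mindrecord
```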

Run it with:

bash ./scripts/run_distribute_two_stages.sh \
/home/ma-user/work/mindrlhf/data/CValues-Comparison/test_1k_ppo.mindrecord \
/home/ma-user/work/mindrlhf/jobstart_hccl.json 0 8 \
/home/ma-user/work/mindrlhf/model_configs/llama2_config/llama2_7b.yaml \
/home/ma-user/work/mindrlhf/model_configs/llama2_config/llama2_7b_rm.yaml  \
/home/ma-user/work/mindrlhf/data/CValues-Comparison/test_1k_two_stage_ppo.mindrecord
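The script takes seven positional arguments, so a wrong count or order only surfaces deep inside the Python entry points. A hedged sketch of a preflight check that could sit at the top of run_distribute_two_stages.sh (argument names as used above; only the first two arguments are files that must already exist):

```shell
#!/bin/sh
# Fail fast if the seven positional arguments are missing or point nowhere.
check_args() {
    if [ "$#" -ne 7 ]; then
        echo "usage: run_distribute_two_stages.sh DATA_DIR RANK_TABLE_FILE RANK_START LOCAL_DEVICE_NUM SFT_MODEL_PATH REWARD_MODEL_PATH SAVE_DATA_PATH" >&2
        return 1
    fi
    # $1 is the input mindrecord, $2 the HCCL rank table: both must exist.
    for f in "$1" "$2"; do
        [ -e "$f" ] || { echo "missing file: $f" >&2; return 1; }
    done
    return 0
}
```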

The PPO data file is generated, but saving it fails (the output of the parallel ranks is interleaved; each rank shows):

Total number of samples: 1
Traceback (most recent call last):
  File "/home/ma-user/work/mindrlhf/make_experience.py", line 99, in <module>
    run_rlhf(args)
  File "/home/ma-user/work/mindrlhf/make_experience.py", line 92, in run_rlhf
    trainer.make_experience(num_rollouts=ppo_config.num_rollouts)
  File "/home/ma-user/work/mindrlhf/mindrlhf/trainer/ppo_trainer.py", line 382, in make_experience
    self.save_ppoelement(self.ppo_config.save_data_file)
  File "/home/ma-user/work/mindrlhf/mindrlhf/trainer/ppo_trainer.py", line 175, in save_ppoelement
    writer.commit()
  File "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/mindspore/mindrecord/filewriter.py", line 454, in commit
    ret = self._generator.build()
  File "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/mindspore/mindrecord/shardindexgenerator.py", line 54, in build
RuntimeError: Exception thrown from dataset pipeline. Refer to 'Dataset Pipeline Error Message'. 

When PPO training starts, it fails with:

Traceback (most recent call last):
  File "/home/ma-user/work/mindrlhf/ppo_train.py", line 103, in <module>
    run_rlhf(args)
  File "/home/ma-user/work/mindrlhf/ppo_train.py", line 89, in run_rlhf
    trainer = PPOTrainer(ppo_config=ppo_config, sft_model_config=sft_model_config, ref_model_config=ref_model_config,
  File "/home/ma-user/work/mindrlhf/mindrlhf/trainer/ppo_trainer.py", line 85, in __init__
    self.prompt_iterator = self.prompt_dataloader.create_tuple_iterator()
  File "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/mindspore/dataset/engine/validators.py", line 988, in new_method
    return method(self, *args, **kwargs)
  File "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/mindspore/dataset/engine/datasets.py", line 1579, in create_tuple_iterator
    return TupleIterator(self, columns, num_epochs, output_numpy, do_copy)
  File "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/mindspore/dataset/engine/iterators.py", line 290, in __init__
    super().__init__(dataset, num_epochs, output_numpy, do_copy)
  File "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/mindspore/dataset/engine/iterators.py", line 90, in __init__
    consumer.Init(self.ir_tree, init_step, dataset_size)
RuntimeError: Exception thrown from dataset pipeline. Refer to 'Dataset Pipeline Error Message'. 

------------------------------------------------------------------
- Dataset Pipeline Error Message: 
------------------------------------------------------------------
[ERROR] Invalid column, column name: prompt_ids does not exist.

------------------------------------------------------------------
- C++ Call Stack: (For framework developers) 
------------------------------------------------------------------
mindspore/ccsrc/minddata/dataset/engine/datasetops/project_op.cc(100).
ChessQian commented 8 months ago

In two-stage training, the real input data path is save_data_file, not dataset_dir.
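If ppo_train.py reads its input through the same --save_data_file flag that make_experience.py uses to write it (a hypothetical reading of this comment; check ppo_train.py's argument parser to confirm), the second-stage launch line would change along these lines. The sketch only builds and prints the command string, with placeholder paths:

```shell
#!/bin/sh
# Hypothetical fix per the comment above: the generated mindrecord goes to
# --save_data_file instead of --dataset_dir. All paths here are placeholders.
ROOT_PATH=/home/ma-user/work/mindrlhf
SAVE_DATA_PATH=/tmp/test_1k_two_stage_ppo.mindrecord
SFT_MODEL_PATH=/tmp/llama2_7b.yaml
REWARD_MODEL_PATH=/tmp/llama2_7b_rm.yaml

cmd="python3 ${ROOT_PATH}/ppo_train.py \
--save_data_file ${SAVE_DATA_PATH} \
--sft_model_path ${SFT_MODEL_PATH} \
--critic_model_path ${REWARD_MODEL_PATH} \
--reward_model_path ${REWARD_MODEL_PATH}"
echo "$cmd"
```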