sunzeyeah / RLHF

Implementation of Chinese ChatGPT
282 stars 36 forks source link

基于ChatGLM2的RLHF训练问题 #23

Open UltraZeroyH opened 10 months ago

UltraZeroyH commented 10 months ago

[2023-08-12 01:22:11,409] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed info: version=0.10.0, git-hash=unknown, git-branch=unknown Traceback (most recent call last): File "/root/inpc_projects/ChatGLM-6B/rlhf_again/src/train_rlhf.py", line 373, in main() File "/root/inpc_projects/ChatGLM-6B/rlhf_again/src/train_rlhf.py", line 237, in main rlhf_engine = DeepSpeedRLHFEngine( File "/root/inpc_projects/ChatGLM-6B/rlhf_again/src/models/rlhf_engine.py", line 146, in init self.ref = self._init_ref( File "/root/inpc_projects/ChatGLM-6B/rlhf_again/src/models/rlhf_engine.py", line 245, in _init_ref refengine, * = deepspeed.initialize(model=ref_model, File "/root/anaconda3/envs/rlhf/lib/python3.10/site-packages/deepspeed/init.py", line 157, in initialize config_class = DeepSpeedConfig(config, mpu) File "/root/anaconda3/envs/rlhf/lib/python3.10/site-packages/deepspeed/runtime/config.py", line 769, in init self._configure_train_batch_size() File "/root/anaconda3/envs/rlhf/lib/python3.10/site-packages/deepspeed/runtime/config.py", line 942, in _configure_train_batch_size self._batch_assertion() File "/root/anaconda3/envs/rlhf/lib/python3.10/site-packages/deepspeed/runtime/config.py", line 890, in _batch_assertion assert train_batch == micro_batch grad_acc self.world_size, ( AssertionError: Check batch related parameters. train_batch_size is not equal to micro_batch_per_gpu gradient_acc_step world_size 1024 != 4 1 8

在使用ChatGLM2作为sft和reward模型,在A100*8的环境上训练的时候,在第三阶段train_rlhf时出现如上报错,尝试了很多方法都没有解决,deepspeed版本是0.10.0,奇怪的点是当--actor_zero_stage是2的时候,能够成功装载actor模型,但是装载reference的时候仍然会报这个错,想请问一下作者有什么建议吗?

sunzeyeah commented 10 months ago

这个原因应该是系统认为在运行deepspeed.initialize()之前world_size一直都是1,所以ds_config['train_batch_size']不需要乘上world_size。只能在运行deepspeed.initialize()之前,才把ds_config['train_batch_size']改为乘上world_size

RL部分的代码还没来得及修复这个问题,具体可以参见pretrain_wo_trainer.py 第220-221行pretrain_wo_trainer.py 第292行