I'm trying to run the gpt2 model. Although the error mentioned in #38 seems to have been fixed in the latest mindrlhf version, I now get a different error, detailed below:
```
Traceback (most recent call last):
  File "~/mindrlhf/train.py", line 103, in <module>
    run_rlhf(args)
  File "~/mindrlhf/train.py", line 89, in run_rlhf
    trainer = PPOTrainer(ppo_config=ppo_config, sft_model_config=sft_model_config, ref_model_config=ref_model_config,
  File "~/mindrlhf/mindrlhf/trainer/ppo_trainer.py", line 89, in __init__
    policy_model = CausalLMHydraWithValueHead(sft_model_config, self.ppo_config)
  File "~/mindrlhf/mindrlhf/models/ppo_models.py", line 116, in __init__
    self.lm_head.pipeline_stage = model_config.parallel_config.pipeline_stage - 1
  File "~/venv/lib/python3.9/site-packages/mindspore/nn/cell.py", line 387, in __getattr__
    raise AttributeError("The '{}' object has no attribute '{}'.".format(type(self).__name__, name))
AttributeError: The 'CausalLMHydraWithValueHead' object has no attribute 'lm_head'.
```
I run it with the following command:
```
mpirun -np 1 python train.py --sft_model_path model_configs/gpt2_config/run_gpt2_124m.yaml --critic_model_path model_configs/gpt2_config/run_gpt2_124m.yaml --reward_model_path model_configs/gpt2_config/run_gpt2_1_3b.yaml --dataset_dir TLDR_data/train/tldr_train_prompts.mindrecord --device_target GPU
```
Could you please advise how to fix it? Thanks
mindrlhf version: b48c9ae
mindspore version: 2.2.0
mindformers version: 0.8.0
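For anyone looking into this: my reading of the traceback is that `lm_head` is never assigned on the `CausalLMHydraWithValueHead` instance before line 116 of `ppo_models.py` tries to set its `pipeline_stage`, so the lookup falls through to `Cell.__getattr__`, which raises. Below is a minimal, framework-free sketch of that pattern. It assumes `FakeCell` as a stand-in for `mindspore.nn.Cell` and a hypothetical `create_lm_head` flag in place of whatever code path skips creating the head for the gpt2 config; it is not the actual mindrlhf code.

```python
class FakeCell:
    """Hypothetical stand-in for mindspore.nn.Cell (not the real class)."""

    def __getattr__(self, name):
        # Python only calls __getattr__ when normal attribute lookup fails;
        # the message mirrors the one raised from mindspore/nn/cell.py above.
        raise AttributeError(
            "The '{}' object has no attribute '{}'.".format(type(self).__name__, name)
        )


class CausalLMHydraWithValueHead(FakeCell):
    def __init__(self, pipeline_stage, create_lm_head):
        # Hypothetical flag: models the head being created on some code paths only.
        if create_lm_head:
            self.lm_head = FakeCell()
        # Mirrors line 116 of ppo_models.py: fails if lm_head was never assigned.
        self.lm_head.pipeline_stage = pipeline_stage - 1


ok = CausalLMHydraWithValueHead(pipeline_stage=1, create_lm_head=True)
print(ok.lm_head.pipeline_stage)

try:
    CausalLMHydraWithValueHead(pipeline_stage=1, create_lm_head=False)
except AttributeError as err:
    print(err)
```

If this reading is right, the fix would be on the model-construction side (making sure the gpt2 path creates `lm_head`, or guarding the `pipeline_stage` assignment), but I have not confirmed which one the maintainers intend.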