OpenLLMAI / OpenRLHF

An Easy-to-use, Scalable and High-performance RLHF Framework (70B+ PPO Full Tuning & Iterative DPO & LoRA & Mixtral)
https://openrlhf.readthedocs.io/
Apache License 2.0

Fixed error due to 'margin' variable type being list in rm_trainer.py #247

StwayneXG closed 3 months ago

StwayneXG commented 3 months ago

While training the reward model using train_rm_llama.sh, an error occurred:

    self.save_logs_and_checkpoints(args, global_step, step_bar, logs_dict)
  File "***/OpenRLHF/openrlhf/trainer/rm_trainer.py", line 180, in save_logs_and_checkpoints
    self.evaluate(self.eval_dataloader, global_step)
  File "***/OpenRLHF/openrlhf/trainer/rm_trainer.py", line 203, in evaluate
    margin = margin.to(torch.cuda.current_device())
             ^^^^^^^^^
AttributeError: 'list' object has no attribute 'to'

Here margin arrives as a Python list, not a tensor, so it has no .to() method. This should be fixed by changing the line to margin = torch.tensor(margin).to(torch.cuda.current_device()), as the fit(self, args) method above already does.
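The failure and the fix can be reproduced outside the trainer. The following is a minimal sketch, not the code in rm_trainer.py: the helper name margin_to_device and the sample values are hypothetical, it falls back to the CPU when CUDA is unavailable so it can run anywhere, and it uses torch.as_tensor rather than the torch.tensor call proposed above so that a value that is already a tensor passes through without a copy warning.

```python
import torch


def margin_to_device(margin):
    """Hypothetical helper: accept a list or a tensor, return a tensor on the active device."""
    if torch.cuda.is_available():
        device = torch.device("cuda", torch.cuda.current_device())
    else:
        # Assumption for this sketch only: fall back to CPU so it runs without a GPU.
        device = torch.device("cpu")
    # as_tensor converts a Python list and leaves an existing tensor as-is.
    return torch.as_tensor(margin).to(device)


# A batch of margins that arrives as a plain list (illustrative values).
margin = [0.5, 1.0, 0.0]

# Reproduce the reported error: a list has no .to() method.
try:
    margin.to("cpu")
except AttributeError as exc:
    error_message = str(exc)

# Apply the fix: convert to a tensor first, then move it to the device.
fixed = margin_to_device(margin)
```

Calling the helper on a value that is already a tensor returns an equal tensor, so the same line would work whether the dataloader yields a list or a tensor.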