Hey @Constructor-Sun, I encountered a similar issue using `PPOv2Trainer`. Here is what I think is going on.

In `PPOv2Trainer`, the policy and value models are wrapped in a `PolicyAndValueWrapper` object, and it is that object that is initialized by DeepSpeed. However, in the `save_model` method of `PPOv2Trainer`, we temporarily overwrite the `self.deepspeed` attribute with the policy model, which itself is not a `DeepSpeedEngine` object, and therefore we get the `AttributeError`.
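To make the failure mode concrete, here is a minimal sketch (the wrapper below is a simplified stand-in, not the actual TRL class):

```python
import torch.nn as nn

class PolicyAndValueWrapper(nn.Module):
    """Simplified stand-in for the wrapper PPOv2Trainer builds internally."""
    def __init__(self, policy: nn.Module, value_model: nn.Module):
        super().__init__()
        self.policy = policy
        self.value_model = value_model

wrapper = PolicyAndValueWrapper(nn.Linear(4, 4), nn.Linear(4, 1))

# DeepSpeed initializes `wrapper` and returns a DeepSpeedEngine, which defines
# zero_gather_16bit_weights_on_model_save(). The bare policy module does not,
# so using it where an engine is expected raises an AttributeError:
print(hasattr(wrapper.policy, "zero_gather_16bit_weights_on_model_save"))  # False
```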
I've come up with a fix that seems to work.
1) First, we need to change the `save_model` method so that it does not replace the `self.deepspeed` attribute. This way, the `save_model` method from the HF `Trainer` will use the `DeepSpeedEngine` object to gather weights from all processes for both the value and the policy model.
2) Then, we need to override the `_save` method to extract only the policy model weights (when DeepSpeed is enabled) and call the original `_save` method, passing the modified `state_dict`.
```python
from typing import Optional

from transformers import Trainer
from trl import PPOv2Trainer


class FixZero3CheckpointPPOv2Trainer(PPOv2Trainer):
    def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
        # Swap in the policy so the saved weights/config are the policy's,
        # but leave self.deepspeed untouched so ZeRO-3 shards get gathered.
        backup_model = self.model
        self.model = self.model.policy  # save only the policy
        Trainer.save_model(self, output_dir, _internal_call)
        self.model = backup_model

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        if self.is_deepspeed_enabled:
            # The gathered state_dict comes from PolicyAndValueWrapper, so keep
            # only the policy weights and strip their 'policy.' prefix.
            state_dict = {
                name.removeprefix('policy.'): param
                for name, param in state_dict.items()
                if name.startswith('policy.')
            }
        super()._save(output_dir, state_dict)
```
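To see why filtering on the `'policy.'` prefix recovers the policy weights, here is a minimal sketch using a simplified stand-in wrapper (dummy modules, not the real models):

```python
import torch.nn as nn

class PolicyAndValueWrapper(nn.Module):  # simplified stand-in
    def __init__(self, policy, value_model):
        super().__init__()
        self.policy = policy
        self.value_model = value_model

wrapper = PolicyAndValueWrapper(nn.Linear(4, 4), nn.Linear(4, 1))
print(sorted(wrapper.state_dict()))
# ['policy.bias', 'policy.weight', 'value_model.bias', 'value_model.weight']

# The same filtering as in _save above recovers the policy weights
# under their original (unprefixed) names:
policy_sd = {k.removeprefix('policy.'): v
             for k, v in wrapper.state_dict().items()
             if k.startswith('policy.')}
print(sorted(policy_sd))  # ['bias', 'weight']
```

The subclass itself is a drop-in replacement: construct `FixZero3CheckpointPPOv2Trainer` with the same arguments you currently pass to `PPOv2Trainer`.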
Let me know if this solution works with your setup. I would make a pull request with this fix but it feels a bit hacky. Maybe someone will come up with a more elegant solution :)
Thanks for your help! I also found out that the bug happens because the policy isn't a `DeepSpeedEngine` object. The policy and value model are bundled together in `PolicyAndValueWrapper` and then wrapped into a `DeepSpeedEngine` using `accelerator.prepare()`.
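For reference, a minimal sketch of that preparation step (stand-in modules again; this assumes the script is launched through Accelerate with a DeepSpeed ZeRO-3 config, otherwise `prepare()` returns a different wrapper):

```python
import torch.nn as nn
from accelerate import Accelerator

class PolicyAndValueWrapper(nn.Module):  # simplified stand-in
    def __init__(self, policy, value_model):
        super().__init__()
        self.policy = policy
        self.value_model = value_model

accelerator = Accelerator()
engine = accelerator.prepare(PolicyAndValueWrapper(nn.Linear(4, 4), nn.Linear(4, 1)))
print(type(engine).__name__)  # 'DeepSpeedEngine' when DeepSpeed is active
```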
I have tested your solution, and it works well!
My solution to this is to rewrite `save_model` as:
```python
def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
    backup_model = self.model
    self.model = self.model.policy  # save only the policy

    if self.is_deepspeed_enabled:
        backup_deepspeed = self.deepspeed
        self.deepspeed = self.model

    os.makedirs(output_dir, exist_ok=True)
    self.model.save_pretrained(output_dir)

    self.model = backup_model
    if self.is_deepspeed_enabled:
        self.deepspeed = backup_deepspeed
```
But I didn't check if this was correct.
Thanks a lot, once more!
Reproduction
When I tried to use PPOv2 for my own task, I encountered the following error during checkpoint saving:
Here `MistralForCausalLM` is my policy. I used the following command to run the examples/scripts/ppo/ppo_tldr.py file (in which I only modified the dataset loading method and did not make any other changes):

(Here I set `total_episodes` to 10 just to save time, since the error occurs at the checkpoint-saving stage.) The accelerate config file is:
It seems that this error occurs when I use ZeRO-3 and `bf16` mixed precision. I tried switching to ZeRO-2 but only got an out-of-memory error. Meanwhile, I found out that `zero_gather_16bit_weights_on_model_save` is a method of the `DeepSpeedEngine` class. Could someone help me solve or analyze this issue? I have been struggling with it for several days. 😭😭

Expected behavior

The PPOv2 trainer should be able to save checkpoints when using ZeRO-3.