CarperAI / trlx

A repo for distributed training of language models with Reinforcement Learning via Human Feedback (RLHF)
MIT License

Fix model state_dict retrieval in ZeRO-3 #576

Closed Jingru closed 1 year ago

Jingru commented 1 year ago

When the accelerator retrieves a model's state_dict with ZeRO-3 enabled, it calls the model's zero_gather_16bit_weights_on_model_save method, so the model passed in must be the accelerate-wrapped one.

So we should not pass base_model to the state_dict call; we should pass the wrapped model with unwrap=True instead.

https://github.com/huggingface/accelerate/blob/7843286f2e1c50735d259fbc0084a7f1c85e00e3/src/accelerate/accelerator.py#L3083
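
To illustrate the failure mode, here is a minimal sketch that uses plain-Python stand-ins rather than the real libraries. `BaseModel`, `WrappedEngine`, and `get_state_dict_zero3` are hypothetical names, and the function is only a simplified imitation of what I understand the ZeRO-3 branch at the linked line to do (the `_zero3_consolidated_16bit_state_dict` call is my assumption about that branch, not something stated above).

```python
class BaseModel:
    """Hypothetical stand-in for the unwrapped base_model."""

    def state_dict(self):
        # Under ZeRO-3 each rank holds only a shard of the parameters,
        # so a direct state_dict() call would not return full weights.
        return {"weight": "local shard only"}


class WrappedEngine:
    """Hypothetical stand-in for the accelerate-wrapped model."""

    def __init__(self, module):
        self.module = module

    def zero_gather_16bit_weights_on_model_save(self):
        # Assumed to mirror the config flag that allows gathering on save.
        return True

    def _zero3_consolidated_16bit_state_dict(self):
        # Assumed to return the weights gathered across all ranks.
        return {"weight": "full gathered weight"}


def get_state_dict_zero3(model):
    """Simplified imitation of the ZeRO-3 branch (an assumption, not the real code)."""
    if model.zero_gather_16bit_weights_on_model_save():
        return model._zero3_consolidated_16bit_state_dict()
    raise ValueError("gathering 16-bit weights on save is disabled")


base_model = BaseModel()
wrapped_model = WrappedEngine(base_model)

# Passing the wrapped model works: the engine-level methods exist.
full_state_dict = get_state_dict_zero3(wrapped_model)

# Passing base_model fails: it has no engine-level methods.
try:
    get_state_dict_zero3(base_model)
    base_model_failed = False
except AttributeError:
    base_model_failed = True
```

In real code the equivalent call would be `accelerator.get_state_dict(model, unwrap=True)`, where `model` is the wrapped object rather than base_model.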