CarperAI / trlx

A repo for distributed training of language models with Reinforcement Learning via Human Feedback (RLHF)
MIT License

Question about trainer.save_pretrained #412

Open c-box opened 1 year ago

c-box commented 1 year ago

🚀 The feature, motivation, and pitch

Here is the reply for #365 :

Assume your checkpoint output is best_checkpoint/pytorch_model/mp_rank_00_model_states.pt. You can try something like this:

import torch
# Import the model architecture used during training and load the weights
from trlx.models.modeling_ppo import AutoModelForCausalLMWithHydraValueHead

model = AutoModelForCausalLMWithHydraValueHead.from_pretrained("...")
model.load_state_dict(torch.load("best_checkpoint/pytorch_model/mp_rank_00_model_states.pt")["module"])
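The `["module"]` indexing above mirrors the layout DeepSpeed uses when it saves model states: the weights sit under the "module" key of the checkpoint dict. A self-contained sketch of that pattern with a toy torch module (no trlx needed; the file name is illustrative):

```python
import os
import tempfile

import torch

# Build a small model and save it in a DeepSpeed-like layout,
# with the state dict nested under the "module" key.
net = torch.nn.Linear(4, 2)
ckpt = {"module": net.state_dict()}

with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "mp_rank_00_model_states.pt")
    torch.save(ckpt, path)

    # Restore into a fresh instance of the same architecture.
    fresh = torch.nn.Linear(4, 2)
    fresh.load_state_dict(torch.load(path)["module"])

print(torch.equal(net.weight, fresh.weight))  # True
```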

Alternatively, you can save your model directly in the Hugging Face format; see https://github.com/CarperAI/trlx#save-the-resulting-model-to-a-hugging-face-pretrained-language-model-ready-to-upload-to-the-hub.

I have another question related to this issue. When executing:

trainer = trlx.train(config=config, reward_fn=lambda samples, **kwargs: [float(int(sample)) for sample in samples])
trainer.save_pretrained('/path/to/output/folder/')

does the trainer store the last checkpoint or the best checkpoint? I suspect it is the last checkpoint. If so, how can I save the best checkpoint so that I can load it using:

AutoModelForCausalLM.from_pretrained(path)


maxreciprocate commented 1 year ago

@c-box Hey, you are correct: it is the last checkpoint that gets saved when executing that piece of code. However, your request was recently addressed in https://github.com/CarperAI/trlx/pull/429. Now, by setting trainer.save_optimizer: False, you will be able to load the best checkpoint with just AutoModelForCausalLM.from_pretrained(path), where path will be trainer.checkpoint_dir/best_checkpoint and the best checkpoint will be saved automatically.
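For reference, once a checkpoint is in Hugging Face format, the load is a plain `from_pretrained` call. A minimal round-trip sketch using a tiny randomly initialised GPT-2 so it needs no download; in practice the directory would be trainer.checkpoint_dir/best_checkpoint (the temporary directory here is just a stand-in):

```python
import tempfile

from transformers import AutoModelForCausalLM, GPT2Config, GPT2LMHeadModel

# Tiny GPT-2 with random weights, built locally (no network access needed).
config = GPT2Config(n_layer=2, n_head=2, n_embd=64, vocab_size=100)
model = GPT2LMHeadModel(config)

with tempfile.TemporaryDirectory() as path:
    # Save in Hugging Face format (config.json + weights),
    # then reload with the generic Auto class.
    model.save_pretrained(path)
    reloaded = AutoModelForCausalLM.from_pretrained(path)

print(type(reloaded).__name__)  # GPT2LMHeadModel
```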