eric-mitchell / direct-preference-optimization

Reference implementation for DPO (Direct Preference Optimization)
Apache License 2.0

How to load trained model for inference? #49

Closed: VibhuAg closed this issue 11 months ago

VibhuAg commented 1 year ago

I have seen issue #13 and am trying to load the trained model weights to perform inference. So far I have trained Pythia-2.8B with SFT and then DPO, as suggested in the README. The output folder for the training run contains three files: policy.pt, optimizer.pt, and scheduler.pt. Following the instructions in issue #13, I am trying to load the weights as follows:

import torch
import transformers

# Load the base architecture, then overwrite its weights with the trained policy
model = transformers.AutoModelForCausalLM.from_pretrained('EleutherAI/pythia-2.8b')
model.load_state_dict(torch.load('<PATH_TO_OUTPUT_DIR>/LATEST/policy.pt'))

But this fails with numerous missing-key errors. What am I doing wrong?

Leonnnnnn929 commented 12 months ago

I encountered the same problem. Did you solve it? Looking forward to your reply!

eric-mitchell commented 11 months ago

Sorry for the confusion here. You need to use the 'state' key in the saved dict, since policy.pt stores a wrapper dict rather than a raw state_dict. So:

model.load_state_dict(torch.load('<PATH_TO_OUTPUT_DIR>/LATEST/policy.pt')['state'])
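Putting it together, a full inference pass might look like the sketch below. It assumes the checkpoint was written by this repo's trainer (so the weights live under the 'state' key); the checkpoint path, prompt, and generation settings are placeholders.

import torch
import transformers

model_name = 'EleutherAI/pythia-2.8b'
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
model = transformers.AutoModelForCausalLM.from_pretrained(model_name)

# policy.pt is a wrapper dict, not a raw state_dict; the weights sit under 'state'
checkpoint = torch.load('<PATH_TO_OUTPUT_DIR>/LATEST/policy.pt', map_location='cpu')
model.load_state_dict(checkpoint['state'])
model.eval()

inputs = tokenizer('The capital of France is', return_tensors='pt')
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Loading with map_location='cpu' first sidesteps device-mismatch errors if training ran on a different GPU layout; move the model afterwards with model.to(device) if you want GPU inference.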