Closed by VibhuAg 11 months ago
I encountered the same problem. Did you solve it? Looking forward to your reply!
Sorry for the confusion here: you need to use the 'state' key in the saved dict. So:
model.load_state_dict(torch.load('<PATH_TO_OUTPUT_DIR>/LATEST/policy.pt')['state'])
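For anyone hitting the same missing-key errors, here is a rough sketch of why they show up. It uses a plain dict as a stand-in for whatever torch.load returns for policy.pt, so it runs without a checkpoint. Only the 'state' key comes from this thread; the 'step_idx' entry, the parameter names, and both helper functions are hypothetical and just for illustration.

```python
from collections.abc import Mapping


def extract_state_dict(checkpoint, key="state"):
    """Return the weights nested under `key`, or the object itself if it is not wrapped."""
    if (
        isinstance(checkpoint, Mapping)
        and key in checkpoint
        and isinstance(checkpoint[key], Mapping)
    ):
        return checkpoint[key]
    return checkpoint


def diff_keys(expected_keys, state_dict):
    """Return (missing, unexpected) key lists, similar in spirit to load_state_dict errors."""
    expected, found = set(expected_keys), set(state_dict)
    return sorted(expected - found), sorted(found - expected)


# Stand-in for torch.load('<PATH_TO_OUTPUT_DIR>/LATEST/policy.pt').
# 'state' is the key named in this thread; 'step_idx' is a hypothetical extra entry.
checkpoint = {
    "step_idx": 100,
    "state": {"layer.weight": [0.1], "layer.bias": [0.0]},
}

# Hypothetical parameter names the model expects.
expected = ["layer.weight", "layer.bias"]

# Passing the whole checkpoint: every model parameter looks missing.
missing_raw, unexpected_raw = diff_keys(expected, checkpoint)

# Passing only the nested weights: the keys line up.
missing_ok, unexpected_ok = diff_keys(expected, extract_state_dict(checkpoint))

print("whole checkpoint:", missing_raw, unexpected_raw)
print("checkpoint['state']:", missing_ok, unexpected_ok)
```

With a real checkpoint, printing the top-level keys of the loaded object before calling load_state_dict is a quick way to confirm whether the weights are nested like this.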
I have seen issue #13 and am trying to load the model weights to perform inference. So far I have trained Pythia-2.8B with SFT and then DPO, as suggested in the README. I am trying to follow the instructions given in issue #13 to load the model weights after training, but I cannot seem to find them. In the output folder for the training run, I see three files: policy.pt, optimizer.pt, and scheduler.pt. I am trying to load the weights as follows:
But I run into numerous missing key errors. What am I doing wrong?