lucidrains / performer-pytorch

An implementation of Performer, a linear attention-based transformer, in Pytorch
MIT License
1.07k stars 143 forks

Saving checkpoints during training and loading #57

Closed ylhsieh closed 3 years ago

ylhsieh commented 3 years ago

Hi, thanks for this awesome repo. What is the correct way to save an AutoregressiveWrapper model? Is it torch.save(model.net.state_dict(), 'checkpoint.pt')? And how should I load it back for generation? Thanks.
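For context, the difference between `model.state_dict()` and `model.net.state_dict()` can be sketched with a toy wrapper. The `Wrapper` class below is a hypothetical stand-in that only assumes the wrapper stores the inner model as `.net` (as the question implies); it is not the library's `AutoregressiveWrapper`, and `nn.Linear` stands in for the real network.

```python
import torch
from torch import nn


class Wrapper(nn.Module):
    # Hypothetical stand-in for a wrapper that keeps the model in `.net`.
    def __init__(self, net):
        super().__init__()
        self.net = net

    def forward(self, x):
        return self.net(x)


model = Wrapper(nn.Linear(4, 2))

# The inner module's keys are unprefixed; the wrapper's keys gain a "net." prefix.
inner_keys = list(model.net.state_dict().keys())
outer_keys = list(model.state_dict().keys())
print(inner_keys)  # ['weight', 'bias']
print(outer_keys)  # ['net.weight', 'net.bias']

# Weights saved from model.net must be loaded back into the `.net` attribute
# of a freshly built wrapper, not into the wrapper itself.
fresh = Wrapper(nn.Linear(4, 2))
fresh.net.load_state_dict(model.net.state_dict())
```

Either form should work as long as saving and loading happen at the same level: a state dict taken from `model.net` goes back into `model.net`, and one taken from `model` goes back into `model`.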

asigalov61 commented 3 years ago

@ylhsieh

Try this:

Save:

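# Option 1: model weights only
torch.save(model.state_dict(), '/content/model.pth')

# Option 2: model weights plus optimizer state (useful for resuming training)
checkpoint = {'state_dict': model.state_dict(), 'optimizer': optim.state_dict()}
torch.save(checkpoint, '/content/model_sd_opt.pth')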

Load:

checkpoint = torch.load('/content/model_sd_opt.pth')
model.load_state_dict(checkpoint['state_dict'])

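Loading can be tricky, though, for several reasons (for example, a checkpoint saved on a GPU and loaded on a CPU-only machine, or state-dict keys that do not match the model you load into). So read the PyTorch docs and try a few things.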
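As a rough illustration of the full round trip, here is a minimal sketch that saves and restores both the model and the optimizer. The `save_checkpoint` and `load_checkpoint` helpers, the temporary file path, and the `nn.Linear` stand-in model are all assumptions made for the example; the same pattern should apply to any `nn.Module`.

```python
import os
import tempfile

import torch
from torch import nn


def save_checkpoint(path, model, optim):
    # Store both state dicts in one file so training can be resumed later.
    torch.save({'state_dict': model.state_dict(),
                'optimizer': optim.state_dict()}, path)


def load_checkpoint(path, model, optim=None, device='cpu'):
    # map_location remaps tensors saved on another device (e.g. a GPU).
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    if optim is not None:
        optim.load_state_dict(checkpoint['optimizer'])
    return model


# Stand-in model (assumption): replace with your own wrapped model.
model = nn.Linear(4, 2)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

path = os.path.join(tempfile.mkdtemp(), 'model_sd_opt.pth')
save_checkpoint(path, model, optim)

# Rebuild the same architecture first, then restore the weights into it.
restored = nn.Linear(4, 2)
restored_optim = torch.optim.Adam(restored.parameters(), lr=1e-3)
load_checkpoint(path, restored, restored_optim)
restored.eval()  # switch to inference mode before generation
```

The key point is that `torch.load` only returns the saved dictionary; the weights take effect once `load_state_dict` is called on a model built with the same architecture.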

Hope this is helpful.

ylhsieh commented 3 years ago

Hi, thanks so much for your help! I ended up using DeepSpeed and the following commands:

model_engine.save_checkpoint('checkpoint_dir', f'{epoch}_{step}')
model_engine.load_checkpoint('checkpoint_dir')

asigalov61 commented 3 years ago

@ylhsieh You are welcome. I am glad you found a working solution. Yes, DeepSpeed is a good way to go too, so whatever works :)