thunlp / OpenPrompt

An Open-Source Framework for Prompt-Learning.
https://thunlp.github.io/OpenPrompt/
Apache License 2.0

Save and re-load a trained prompt model #222

[Open] agrimaseth opened this issue 1 year ago

agrimaseth commented 1 year ago

Currently, I am saving the model using the following:

torch.save(prompt_model.state_dict(), PATH)

How can I load this back to test performance on other data? The PyTorch tutorial says to use the following:

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

What would TheModelClass be? An example would be really appreciated.
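A minimal sketch of the usual PyTorch save/reload pattern, under stated assumptions. `torch.save(prompt_model.state_dict(), PATH)` stores only tensors keyed by parameter name, not the class itself, so loading first requires rebuilding an object with the identical structure. Assuming `prompt_model` was built as in the OpenPrompt tutorials, `TheModelClass` would be whichever wrapper created it (for example `PromptForClassification`), constructed again from the same pretrained model, template, and verbalizer settings used during training. To keep the sketch runnable without downloading a pretrained model, `ToyPromptModel` and `save_and_reload` are hypothetical stand-ins, not OpenPrompt APIs.

```python
import os
import tempfile
from typing import Tuple

import torch
from torch import nn


class ToyPromptModel(nn.Module):
    """Hypothetical stand-in for a prompt model such as PromptForClassification:
    a "backbone" plus a small label head."""

    def __init__(self, hidden_size: int, num_classes: int):
        super().__init__()
        self.backbone = nn.Linear(8, hidden_size)        # stands in for the PLM
        self.head = nn.Linear(hidden_size, num_classes)  # stands in for the verbalizer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(torch.relu(self.backbone(x)))


def save_and_reload(path: str) -> Tuple[ToyPromptModel, ToyPromptModel]:
    torch.manual_seed(0)
    trained = ToyPromptModel(hidden_size=16, num_classes=3)
    # state_dict() holds only parameter/buffer tensors keyed by name.
    torch.save(trained.state_dict(), path)

    # Rebuild the SAME architecture with the SAME constructor arguments...
    restored = ToyPromptModel(hidden_size=16, num_classes=3)
    # ...then copy the saved tensors into it. map_location="cpu" lets a
    # GPU-trained checkpoint load on a CPU-only machine.
    restored.load_state_dict(torch.load(path, map_location="cpu"))
    restored.eval()  # switch to inference mode before evaluating
    return trained, restored


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        trained, restored = save_and_reload(os.path.join(tmp, "prompt_model.pt"))
        x = torch.randn(4, 8)
        with torch.no_grad():
            print(torch.allclose(trained(x), restored(x)))  # prints True
```

With a real prompt model the only step that changes is the rebuild: recreate the pretrained model, tokenizer, template, and verbalizer exactly as during training, pass them to the same wrapper class, and then call `load_state_dict` on that object.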

agrimaseth commented 1 year ago

@ShengdingHu, could you please help with this? Loading the model with the commands above doesn't work.
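Since the exact error isn't given, a hedged debugging sketch: `load_state_dict` typically fails because the rebuilt model's parameter names or shapes differ from the checkpoint's, for example when it is rebuilt with a different number of labels or a different template, or when the model was saved while wrapped in `nn.DataParallel` (which prefixes every key with `module.`). The helpers `diff_state_dict` and `strip_prefix` below are hypothetical, not part of PyTorch or OpenPrompt; they compare the checkpoint against the freshly built model before loading.

```python
from typing import Dict, List

import torch
from torch import nn


def diff_state_dict(model: nn.Module, checkpoint: Dict[str, torch.Tensor]) -> Dict[str, List[str]]:
    """Report keys the model expects but the checkpoint lacks ("missing"),
    keys the checkpoint has but the model lacks ("unexpected"), and shared
    keys whose tensor shapes disagree ("shape_mismatch")."""
    expected = model.state_dict()
    missing = sorted(k for k in expected if k not in checkpoint)
    unexpected = sorted(k for k in checkpoint if k not in expected)
    shape_mismatch = sorted(
        k for k in expected
        if k in checkpoint and expected[k].shape != checkpoint[k].shape
    )
    return {"missing": missing, "unexpected": unexpected, "shape_mismatch": shape_mismatch}


def strip_prefix(checkpoint: Dict[str, torch.Tensor], prefix: str = "module.") -> Dict[str, torch.Tensor]:
    """Drop a leading key prefix, such as the "module." added by nn.DataParallel."""
    return {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in checkpoint.items()}


if __name__ == "__main__":
    saved = nn.Sequential(nn.Linear(4, 3))    # e.g. trained with 3 labels
    rebuilt = nn.Sequential(nn.Linear(4, 2))  # mistakenly rebuilt with 2 labels
    print(diff_state_dict(rebuilt, saved.state_dict()))
    # prints {'missing': [], 'unexpected': [], 'shape_mismatch': ['0.bias', '0.weight']}
```

If all three lists come back empty, `load_state_dict` should succeed; any non-empty list points to the part of the rebuild that differs from training.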