richarddwang / electra_pytorch

Pretrain and finetune ELECTRA with fastai and huggingface. (Results of the paper replicated!)

How do I extract and save the discriminator from the checkpoint? #20

Closed PhilipMay closed 3 years ago

PhilipMay commented 3 years ago

Hi @richarddwang , my question is: How do I extract and save the discriminator from the checkpoint?

I can load it with `model = torch.load("path_to_checkpoint")`

but after that I have to split it into the generator and the discriminator and save them in the `.bin` format.

Could you maybe provide some hints?

Many thanks, Philip
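The split described above could be sketched as follows. This is only a sketch: the `"generator."` and `"discriminator."` key prefixes are an assumption about how the combined model's `state_dict` is laid out, and need to be checked against the actual checkpoint keys.

```python
from collections import OrderedDict

def split_state_dict(state_dict, prefix):
    """Keep only the entries whose keys start with `prefix` and strip
    that prefix, so the result can be loaded into a standalone module."""
    return OrderedDict(
        (key[len(prefix):], value)
        for key, value in state_dict.items()
        if key.startswith(prefix)
    )

# Hypothetical usage (prefixes are an assumption about the checkpoint):
#   import torch
#   ckpt = torch.load("path_to_checkpoint", map_location="cpu")
#   disc_sd = split_state_dict(ckpt, "discriminator.")
#   torch.save(disc_sd, "pytorch_model.bin")
```

Printing `list(ckpt.keys())[:10]` first is the easiest way to confirm what the real prefixes are.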

PhilipMay commented 3 years ago

I think I have to create a Learner again and then use the load method of the object - right? Is there a more elegant way to just load the model and not the Learner?

richarddwang commented 3 years ago

I have no good solution here.

In my case, when loading a pretrained ELECTRA checkpoint to create a GLUE finetuning model for example, I do `ckpt = torch.load("path_to_checkpoint")` and then just take the part of its `state_dict` according to the names from `model.named_parameters()`.
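The approach described above (picking weights out of the full checkpoint by the target model's parameter names) might look roughly like this. The `prefix` argument is hypothetical and depends on how the checkpoint keys are nested relative to the target model:

```python
def extract_matching_weights(ckpt_state_dict, target_names, prefix=""):
    """Select from a full checkpoint only the entries whose (de-prefixed)
    names appear in the target model's parameter names."""
    out = {}
    for name in target_names:
        key = prefix + name          # e.g. "discriminator." + "electra.weight"
        if key in ckpt_state_dict:
            out[name] = ckpt_state_dict[key]
    return out

# Hypothetical usage with a real model:
#   target_names = [n for n, _ in model.named_parameters()]
#   sub = extract_matching_weights(ckpt, target_names, prefix="discriminator.")
#   model.load_state_dict(sub, strict=False)
```

Passing `strict=False` lets `load_state_dict` tolerate entries that exist only on one side, which is useful while you are still working out the exact key layout.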

richarddwang commented 3 years ago

Feel free to tag me if there are other questions.