williamSYSU / TextGAN-PyTorch

TextGAN is a PyTorch framework for text generation models based on Generative Adversarial Networks (GANs).
MIT License

How can I restore a saved model? #9

Closed: grmarco closed this issue 5 years ago.

williamSYSU commented 5 years ago

Currently, you can only load a pre-trained Generator or Discriminator, from pretrained_gen_path and pretrained_dis_path respectively (see config.py). The simplest way to load a Generator saved in save/**/models is to replace the state dict at pretrained_gen_path with the one from save/**/models, i.e. overwrite the file at pretrained_gen_path with the saved file.

If you want to load a checkpoint from adversarial training, you can update the function self._save() in instructor.py so that it also saves the Discriminator. You can then treat the state dicts saved during adversarial training exactly like those saved after pre-training, and load them through the same two paths.
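A minimal PyTorch sketch of both steps, assuming plain nn.Module models. TinyGen, TinyDis, save_checkpoint and load_pretrained are hypothetical names, not TextGAN's API, and the gen_ADV_/dis_ADV_ file-name pattern is only an illustration. The saver writes the Discriminator's state dict alongside the Generator's, and the loader restores either file into a freshly constructed model, the same way a file at pretrained_gen_path or pretrained_dis_path would be loaded.

```python
import os
import tempfile

import torch
import torch.nn as nn


# Hypothetical stand-ins for TextGAN's Generator and Discriminator;
# any nn.Module is saved and restored the same way.
class TinyGen(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(8, 16, batch_first=True)
        self.out = nn.Linear(16, 10)


class TinyDis(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 1)


def save_checkpoint(gen, dis, save_dir, epoch):
    """Hypothetical extended _save(): write BOTH state dicts, not only the Generator's."""
    gen_path = os.path.join(save_dir, f"gen_ADV_{epoch:05d}.pt")
    dis_path = os.path.join(save_dir, f"dis_ADV_{epoch:05d}.pt")
    torch.save(gen.state_dict(), gen_path)
    torch.save(dis.state_dict(), dis_path)
    return gen_path, dis_path


def load_pretrained(model, path):
    """Load a saved state dict into a freshly built model (weights are replaced in place)."""
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state)
    return model


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        gen, dis = TinyGen(), TinyDis()
        gen_path, dis_path = save_checkpoint(gen, dis, d, epoch=100)
        # A new model starts with different random weights until the checkpoint is loaded.
        restored = load_pretrained(TinyGen(), gen_path)
        same = all(torch.equal(a, b) for a, b in
                   zip(gen.state_dict().values(), restored.state_dict().values()))
        print(same)  # True
```

Because only state dicts are saved, the loading side must rebuild the models with the same architecture and hyperparameters before calling load_state_dict; a mismatch in layer sizes raises an error instead of silently loading.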