eriklindernoren / PyTorch-GAN

PyTorch implementations of Generative Adversarial Networks.
MIT License

load and save checkpoint logic error #158

Open · John155 opened this issue 3 years ago

John155 commented 3 years ago

Checkpoint-saving code, for example in the MUNIT GAN:

parser.add_argument("--epoch", type=int, default=0, help="epoch to start training from")
parser.add_argument("--checkpoint_interval", type=int, default=1, help="interval between saving model checkpoints")
opt = parser.parse_args()

for epoch in range(opt.epoch, opt.n_epochs):
    for i, batch in enumerate(dataloader):
        ...  # training step for this batch

    # Checked once per epoch; 0 % checkpoint_interval == 0, so epoch 0 gets a checkpoint
    if opt.checkpoint_interval != -1 and epoch % opt.checkpoint_interval == 0:
        torch.save(Enc1.state_dict(), "saved_models/%s/Enc1_%d.pth" % (opt.dataset_name, epoch))


The Problem
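The first checkpoint is saved with epoch 0, but the restore logic cannot load from epoch 0: 0 is also the default value of --epoch, the load branch only runs when opt.epoch != 0, and so epoch 0 always falls through to weight initialization.
Checkpoint-loading code, for example in the MUNIT GAN: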

if opt.epoch != 0:
    # Load pretrained models
    ...
else:
    # Initialize weights
    ...

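To make the mismatch concrete, here is a small pure-Python simulation of the two conditions quoted above, with training and file I/O stripped out. The values n_epochs=3 and checkpoint_interval=1 are chosen only for illustration, and saved_epochs and loads_checkpoint are hypothetical helper names, not part of PyTorch-GAN.

```python
def saved_epochs(start_epoch, n_epochs, checkpoint_interval):
    """Epoch numbers for which the save condition writes a checkpoint."""
    return [
        epoch
        for epoch in range(start_epoch, n_epochs)
        if checkpoint_interval != -1 and epoch % checkpoint_interval == 0
    ]


def loads_checkpoint(epoch_arg):
    """Mirrors `if opt.epoch != 0`: True means the load branch runs."""
    return epoch_arg != 0


saved = saved_epochs(start_epoch=0, n_epochs=3, checkpoint_interval=1)
print("checkpoints written for epochs:", saved)
for e in saved:
    action = "load checkpoint" if loads_checkpoint(e) else "initialize weights"
    print(f"--epoch {e}: {action}")
```

The checkpoint for epoch 0 is written to disk, yet passing --epoch 0 takes the initialization branch, so that file is never restored.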
nixczhou commented 3 years ago

I think you can change the if statement to allow loading from epoch 0.
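One way to follow this suggestion, sketched under assumptions: give --epoch a default of -1 (a hypothetical sentinel meaning "train from scratch") so that 0 becomes a loadable checkpoint. Because the checkpoint labelled N is written after epoch N has finished, this sketch also resumes at N + 1 rather than retraining epoch N. plan_start is a hypothetical helper, the dataset name is a placeholder, and the path format is copied from the save code in the issue. A real script would call torch.load and load_state_dict on the returned path; the sketch only returns it.

```python
import argparse


def plan_start(opt):
    """Return (checkpoint path or None, first epoch to train).

    Hypothetical convention: opt.epoch == -1 means "train from scratch";
    any value >= 0 names the checkpoint saved at the end of that epoch.
    """
    if opt.epoch >= 0:
        # Load pretrained models; the file for epoch N was written after
        # epoch N finished, so training resumes at N + 1.
        path = "saved_models/%s/Enc1_%d.pth" % (opt.dataset_name, opt.epoch)
        return path, opt.epoch + 1
    # Initialize weights and start from the first epoch
    return None, 0


parser = argparse.ArgumentParser()
parser.add_argument("--epoch", type=int, default=-1,
                    help="checkpoint epoch to resume from (-1 = train from scratch)")
parser.add_argument("--dataset_name", type=str, default="edges2shoes")  # placeholder
parser.add_argument("--n_epochs", type=int, default=3)

opt = parser.parse_args(["--epoch", "0"])
path, start = plan_start(opt)
print(path, list(range(start, opt.n_epochs)))
```

With this convention, --epoch 0 resumes from the first checkpoint instead of silently re-initializing. The trade-off is that --epoch changes meaning from "epoch to start training from" to "checkpoint to resume from", so its help text would need to change with it.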