rosinality / style-based-gan-pytorch

Implementation of A Style-Based Generator Architecture for Generative Adversarial Networks in PyTorch

Continue training from checkpoint for same resolution #113

Closed: flyfisher123 closed this issue 3 years ago

flyfisher123 commented 3 years ago

Thanks a lot for sharing your code. It's been a great help for me.

Is there a way to load a saved .model and continue training at the same size? For example, train_step-7.model is saved after training at 256x256 completes. Can I load it and continue training at 256x256 instead of restarting from train_step-6.model? I would like to do this for several reasons:

1) My machine is prone to crashes, and I don't want to restart from the next lower resolution every time. I could simply save all weights every few thousand steps instead of only when the phase changes.
2) I could continue training at my highest resolution if I realise that a few more iterations might have been beneficial.

rosinality commented 3 years ago

Replace this (https://github.com/rosinality/style-based-gan-pytorch/blob/master/train.py#L241) with this (https://github.com/rosinality/style-based-gan-pytorch/blob/master/train.py#L100), and also save used_sample. Then you can load a checkpoint and set used_sample to the saved value.
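A minimal sketch of that save/resume round trip, assuming the checkpoint is a plain dict of model states plus two progressive-growing counters: the growth-stage index (called `step` here, assumed to match the training loop's variable) and `used_sample`. `save_checkpoint` and `load_checkpoint` are hypothetical helpers, and `pickle` stands in for `torch.save`/`torch.load` so the sketch runs without PyTorch; in the real script the entries would be the `state_dict()`s of the generator, discriminator, g_running and both optimizers. The point is that `used_sample` travels with the weights, so the fade-in resumes where it stopped instead of restarting at zero.

```python
import os
import pickle
import tempfile


def save_checkpoint(path, models, step, used_sample):
    """Bundle model states with the counters needed to resume mid-phase.

    Hypothetical helper: pickle stands in for torch.save here.
    """
    state = dict(models)                # e.g. generator, discriminator, g_running, optimizers
    state["step"] = step                # growth stage the run was in
    state["used_sample"] = used_sample  # images seen in this phase; alpha is derived from it
    with open(path, "wb") as f:
        pickle.dump(state, f)


def load_checkpoint(path):
    """Return (model_states, step, used_sample). Hypothetical helper."""
    with open(path, "rb") as f:
        state = pickle.load(f)
    step = state.pop("step")
    used_sample = state.pop("used_sample")
    return state, step, used_sample


if __name__ == "__main__":
    # Hypothetical file name and values, for illustration only.
    path = os.path.join(tempfile.gettempdir(), "resume_demo.ckpt")
    save_checkpoint(path, {"g_running": {"w": 0.5}}, step=6, used_sample=480_000)
    models, step, used_sample = load_checkpoint(path)
    print(step, used_sample, sorted(models))
    os.remove(path)
```

On resume, load every model state (including g_running) into its network, then set `step` and `used_sample` from the checkpoint before entering the training loop.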

flyfisher123 commented 3 years ago

Thanks for your help. I've tried that, but then the losses for both G and D are very high, and the images produced look like they were generated by a model for the wrong size.

I can send some images or snippets later; I don't have access right now.

flyfisher123 commented 3 years ago

I just used the saved model and generate.py to create images of different sizes. The saved model is for 256x256, and I use --init_size 256. The samples produced when restarting training are 256x256, but they look exactly like those that generate.py produces from the saved model with --size 128.

rosinality commented 3 years ago

I think it could be because alpha (which is derived from used_sample) is not set appropriately, or because g_running is not loaded.
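Why a wrong `used_sample` produces 256x256 samples that look like 128x128 output: during progressive growing, the newest block's RGB output is blended with an upsampled copy of the previous resolution's output, weighted by alpha. A minimal sketch, assuming alpha is `min(1, (used_sample + 1) / phase)` with a default `phase` of 600,000 and is forced to 1 at the very first stage; that formula and default are assumed to mirror train.py and should be checked against the script. `fade_in_alpha` and `blend` are hypothetical helpers, and images are flattened to lists of floats.

```python
def fade_in_alpha(used_sample, phase=600_000, first_stage=False):
    """Fade-in weight for the newest resolution block (assumed formula)."""
    if first_stage:
        return 1.0  # nothing to fade in from at the initial resolution
    return min(1.0, (used_sample + 1) / phase)


def blend(upsampled_prev_rgb, new_rgb, alpha):
    """Mix the upsampled lower-resolution output with the new block's output."""
    return [(1 - alpha) * a + alpha * b for a, b in zip(upsampled_prev_rgb, new_rgb)]


if __name__ == "__main__":
    prev = [0.2, 0.4]   # stand-in for the upsampled 128x128 output
    new = [1.0, -1.0]   # stand-in for the 256x256 block output
    # used_sample reset to 0 on resume: alpha is tiny, so the output is
    # essentially the upsampled lower-resolution image.
    print(blend(prev, new, fade_in_alpha(0)))
    # used_sample restored from the checkpoint: full weight on the new block.
    print(blend(prev, new, fade_in_alpha(600_000)))
```

With `used_sample` reset to zero, alpha is about 1.7e-6, so the generator's 256x256 output is almost entirely the upsampled 128x128 image, which matches the symptom reported above.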

flyfisher123 commented 3 years ago

Thanks. I'll test your suggestions. Edit: Got it to work, thanks. I had used the wrong used_sample, so alpha was wrong.