Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

GAN example: Only one backward() call? #594

Closed · alainjungo closed this issue 4 years ago

alainjungo commented 4 years ago

In the PyTorch DCGAN tutorial https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html there are two backward() calls for the discriminator. How do you support this with your structure, where backward() gets called after the training step?

Best, Alain

williamFalcon commented 4 years ago

good catch. we don't actually support this atm.

A solution here is to allow a dictionary of options for each optimizer, which allows an arbitrary number of step calls. Something like:

def configure_optimizers(self,...):
    opt_G = {'optimizer': Adam(...), 'frequency': 2, 'lr_scheduler': LRScheduler(...)}
    opt_D = {'optimizer': Adam(...), 'frequency': 1, 'lr_scheduler': LRScheduler(...)}

    return opt_G, opt_D

Here opt_G would be stepped twice back to back, and opt_D once after

@jeffling @neggert

But I'm not sure if this is a clean user experience

jeffling commented 4 years ago

@williamFalcon that API would work for us.

@alainjungo: Some workarounds, none of them ideal:

  1. Skip every other generator step; you'll have to double your iterations (see the sketch after this list)
  2. Double the learning rate on D (not algorithmically the same, but it can have a similar effect)
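
A rough sketch of option 1, assuming the multi-optimizer training_step signature with an optimizer_idx argument; the loss helpers and the idea of returning None to skip an update are assumptions for illustration, not a documented API:

def training_step(self, batch, batch_idx, optimizer_idx):
    if optimizer_idx == 0:  # generator: update only on every other batch
        if batch_idx % 2 == 0:
            return {'loss': self.generator_loss(batch)}  # hypothetical helper
        return None  # skip this generator update
    if optimizer_idx == 1:  # discriminator: update on every batch
        return {'loss': self.discriminator_loss(batch)}  # hypothetical helper
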
kwanUm commented 4 years ago

You can always return one loss from the training step that captures both losses. In the DCGAN example, this would look like errD = errD_real + errD_fake; errD.backward();.

Not sure if I'm correct here, but this seems equivalent and matches PTL paradigms.
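
A minimal sketch of what this could look like in a LightningModule (the self.generator / self.discriminator modules, the latent_dim hparam, and the optimizer_idx convention are assumptions for illustration). Since the gradient of a sum is the sum of the gradients, one backward() on errD_real + errD_fake produces the same discriminator gradients as the tutorial's two separate backward() calls:

import torch
import torch.nn.functional as F

def training_step(self, batch, batch_idx, optimizer_idx):
    real, _ = batch
    if optimizer_idx == 1:  # discriminator update
        # loss on real images
        pred_real = self.discriminator(real)
        errD_real = F.binary_cross_entropy(pred_real, torch.ones_like(pred_real))
        # loss on generated images; detach so no generator gradients are built
        z = torch.randn(real.size(0), self.hparams.latent_dim, device=real.device)
        pred_fake = self.discriminator(self.generator(z).detach())
        errD_fake = F.binary_cross_entropy(pred_fake, torch.zeros_like(pred_fake))
        # one combined loss -> Lightning runs a single backward()
        return {'loss': errD_real + errD_fake}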

asafmanor commented 4 years ago

I have implemented an API that allows returning optimizers, lr_schedulers, and optimizer_frequencies; based on the batch_idx, it then determines the current optimizer to use.
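
A minimal sketch of the selection logic I have in mind, assuming the frequencies are simply cycled over in order (the names here are illustrative, not the final API):

def optimizer_index(batch_idx, frequencies):
    """Map a batch index to the optimizer that should step on it."""
    cycle = sum(frequencies)   # length of one full round, e.g. 2 + 1 = 3
    pos = batch_idx % cycle    # position inside the current round
    for idx, freq in enumerate(frequencies):
        if pos < freq:
            return idx
        pos -= freq

# With frequencies (2, 1): batches 0, 1 -> optimizer 0 (G), batch 2 -> optimizer 1 (D), ...
assert [optimizer_index(i, (2, 1)) for i in range(6)] == [0, 0, 1, 0, 0, 1]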

If there is an agreement on this API, I'll proceed to testing, documenting and submitting a PR.

Another option would be to allow returning a tuple of dictionaries as @williamFalcon suggested. That would be a minor change for me, and I am willing to do that if it is agreed upon.

asafmanor commented 4 years ago

I am adding here that the official implementation of WGAN-GP (for example) needs this feature in order to converge. This is a very fundamental feature in GAN training.
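
For reference, a hedged sketch of the WGAN-GP update schedule in plain PyTorch style (not the official code; critic_step / generator_step, opt_D / opt_G, and dataloader are hypothetical stand-ins for the user's own training pieces):

n_critic = 5  # critic updates per generator update (value used in the paper)
for i, batch in enumerate(dataloader):
    loss_D = critic_step(batch)
    opt_D.zero_grad()
    loss_D.backward()
    opt_D.step()
    if (i + 1) % n_critic == 0:  # generator steps once every n_critic batches
        loss_G = generator_step(batch)
        opt_G.zero_grad()
        loss_G.backward()
        opt_G.step()

With the proposed API this schedule would be expressed as {'frequency': 5} for the critic and {'frequency': 1} for the generator.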

Borda commented 4 years ago

I like the @williamFalcon proposal, which seems clear to me... @asafmanor mind sending a PR or describing what API you have in mind? cc: @PyTorchLightning/core-contributors any comments on this?

asafmanor commented 4 years ago

I'll implement the @williamFalcon API and send a detailed PR over the weekend 👍