Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

GAN training with Pytorch Lightning is broken. #591

Closed ayberkydn closed 4 years ago

ayberkydn commented 4 years ago

I was trying to train a DCGAN on my dataset, but it would not work by any means until I detached the training logic from Lightning and ran the code without it. It simply did not work while my training logic lived in a LightningModule. I checked the GAN examples in the docs and the multiple-optimizer documentation. After two days of headaches, source code inspection, and putting numerous print statements into the Lightning source code, I found the culprit.

GAN training with PyTorch Lightning is simply broken. The culprit is that optimizer.zero_grad() is only called after optimizer.step(), so it clears the gradients of either the Generator or the Discriminator, but never both. When, say, the Generator is being updated, loss.backward() populates gradients for all parameters; but when optimizer.zero_grad() runs after the update, only the Generator's gradients are cleared. So when the Discriminator's loss.backward() is called, leftover gradients from the Generator step have already accumulated on the Discriminator's parameters, and everything falls apart. Any kind of GAN training is impossible with this setup. That's why you cannot find any GAN implementations with PyTorch Lightning on the internet.
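For concreteness, here is a minimal standalone PyTorch sketch (toy linear modules and made-up losses, not Lightning code, purely for illustration) of the accumulation described above:

import torch

G = torch.nn.Linear(2, 2)   # stand-in generator
D = torch.nn.Linear(2, 1)   # stand-in discriminator
opt_g = torch.optim.SGD(G.parameters(), lr=0.1)
opt_d = torch.optim.SGD(D.parameters(), lr=0.1)

z = torch.randn(4, 2)
g_loss = D(G(z)).mean()      # generator pass: the graph runs through D as well
g_loss.backward()            # fills .grad on D's parameters too
opt_g.step()
opt_g.zero_grad()            # clears G's gradients only; D still holds stale ones

d_loss = D(G(z).detach()).mean()
d_loss.backward()            # stale gradients from the generator pass are added on top
opt_d.step()
opt_d.zero_grad()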

Possible solutions:

- Putting a warning in the docs or on the console when multiple optimizers are detected, like "if you train a GAN, don't forget to zero all gradients by overriding the optimizer_step method, or just reset the gradients in your training loop before returning the loss dictionary". (Would be weird, honestly.)

(By the way, on_before_zero_grad does not seem to be called anywhere right now. Maybe the issue could alternatively be fixed with a default behavior in that method.)

I just sent a pull request which implements the last option and updates the docs.
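For reference, here is a rough sketch of the optimizer_step workaround mentioned above, i.e. clearing every parameter's gradients after each update. The hook signature varies between Lightning versions, so treat the arguments as an assumption, not the actual PR:

    def optimizer_step(self, epoch_nb, batch_nb, optimizer, optimizer_idx, *args, **kwargs):
        # Step the current optimizer, then clear the gradients of *all*
        # parameters (nn.Module.zero_grad covers both G and D), not just the
        # parameters owned by the optimizer that stepped.
        optimizer.step()
        self.zero_grad()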

williamFalcon commented 4 years ago

sorry for the trouble! good catch though. Thanks for the PR, we’ll merge it once tests pass :)

feel free to submit the PR when ready

ayberkydn commented 4 years ago

Whoops, I had failed to submit the PR properly. I think it is ready now. I appreciate your work very much. Thanks for the great library :)

jeffling commented 4 years ago

Hi, at work we train GANs successfully without any optimizer overrides. We also have the zero_grad before the optimizer step.

My teammate recommends that you put no_grad around your discriminator forward pass; your runtime would also be faster than calling zero_grad in the proposed way. Would that work for you?

dbarnhart commented 4 years ago

Hi, I'm that teammate. Strictly speaking, you shouldn't even need to zero gradients at the start of optimizer_step, or have to override optimizer_step at all. Between with torch.no_grad(): and Tensor.detach(), you can prevent unwanted gradient accumulation.

ayberkydn commented 4 years ago

Hi, teammates. I was getting good results once I added a zero_grad before my training loop, but the problem is that the training loop should not need these tricks for generic training regimes like GANs in the first place. The mentality of Lightning is to just define a loss with respect to your models and leave that kind of boilerplate to the framework. That's why I thought it would be better to submit a PR instead of hacking my training loop or advising the workaround in the docs, as I mentioned in the PR message :)

williamFalcon commented 4 years ago

Can you guys post pseudocode for both versions so I can see what you mean more clearly?

ayberkydn commented 4 years ago
def training_step(self, batch, batch_nb, optimizer_idx):
        self.zero_grad()  # clear leftover gradients on both G and D before this step
        x_data = batch
        batch_size = x_data.shape[0]
        z = torch.randn(batch_size, 100).cuda()
        x_gen = self.G(z)
        real_labels = torch.ones(batch_size).cuda()
        fake_labels = torch.zeros(batch_size).cuda()

        if optimizer_idx == 0:  # discriminator step

            D_real = self.D(x_data)
            D_fake = self.D(x_gen)

            real_loss = self.loss(D_real, real_labels)
            fake_loss = self.loss(D_fake, fake_labels)
            d_loss = real_loss + fake_loss

            return {
                'loss': d_loss,
                'progress_bar': {'d_loss': d_loss}
            }

        if optimizer_idx == 1:  # generator step

            D_fake = self.D(x_gen)
            g_loss = self.loss(D_fake, real_labels)

            return {
                'loss': g_loss,
                'progress_bar': {'g_loss': g_loss}
            }

This code works, but it fails without the self.zero_grad() call.

dbarnhart commented 4 years ago

Yes, it would. Make the following changes to fix it:

Prevent gradients accumulating on the generator during the discriminator pass: D_fake = self.D(x_gen) -> D_fake = self.D(x_gen.detach())

Prevent gradients accumulating on the discriminator during the generator pass: D_fake = self.D(x_gen) ->

with torch.no_grad():
    D_fake = self.D(x_gen)

Hope that helps.

ayberkydn commented 4 years ago

I've tried these and it works. Don't get me wrong, it's a really smart and elegant way to handle things, and it even speeds up the code by ~10%. But I think people shouldn't have to think about which gradients to block and which to allow when training different models, when the problem can be solved by a simple high-level change that covers most cases. I really can't think of a case where one needs to keep any gradients after an optimizer step. Also, your no_grad method would still be relevant after this PR.

williamFalcon commented 4 years ago

Ok, sounds like we actually want to make sure users think about this when training GANs. But totally agree that it’s a bit annoying.

Let’s make these changes:

  1. no change to the code
  2. update docs to help users understand this.

If it turns out that other users run into this as well, we can make it the default.

thank you all for a great discussion!

kwanUm commented 4 years ago

Looking at the gan.py example: does line 121 there (https://fburl.com/p8jnvbpf) need to be updated to include 'with torch.no_grad():' beforehand, as suggested by @dbarnhart?

It seems like currently the discriminator is accumulating gradients from both the generator and the discriminator phases of the training step.

ayberkydn commented 4 years ago

Yes, it is currently incorrect but I think it won't be an issue after #603 is done.

kwanUm commented 4 years ago

I've attempted to include no_grad() in the generator step, but using no_grad() removes the gradient history entirely, which in effect cuts the feedback from the discriminator when training the generator model.

The alternatives I came up with were either wrapping the discriminator call with "for param in self.discriminator.parameters(): param.requires_grad = False", or calling discriminator.zero_grad() before returning the discriminator loss (faster). I've submitted fix #666 in the meantime to fix the GAN example, as it seems people are blocked on training GANs with PTL due to the same issue (see #557).
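For illustration, a minimal sketch of the first alternative, freezing the discriminator's parameters for the generator branch of the training_step posted earlier; the names self.D, self.loss, and the label tensors follow that example and are otherwise assumptions:

    if optimizer_idx == 1:  # generator step
        # Freeze D so its .grad buffers are not populated by this pass;
        # gradients still flow back to the generator *through* D's operations.
        for p in self.D.parameters():
            p.requires_grad_(False)
        D_fake = self.D(x_gen)
        g_loss = self.loss(D_fake, real_labels)
        # Unfreeze D so the next discriminator step can train it again.
        for p in self.D.parameters():
            p.requires_grad_(True)
        return {'loss': g_loss, 'progress_bar': {'g_loss': g_loss}}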

ayberkydn commented 4 years ago

> I've attempted to include no_grad() in the generator step, but using no_grad() removes the gradient history entirely, which in effect cuts the feedback from the discriminator when training the generator model.
>
> The alternatives I came up with were either wrapping the discriminator call with "for param in self.discriminator.parameters(): param.requires_grad = False", or calling discriminator.zero_grad() before returning the discriminator loss (faster). I've submitted fix #666 in the meantime to fix the GAN example, as it seems people are blocked on training GANs with PTL due to the same issue (see #557).

Resetting the gradients manually would break the gradient accumulation option, though.

lmartak commented 4 years ago

> By the way, on_before_zero_grad does not seem to be called anywhere right now.

This still holds.

phongnhhn92 commented 4 years ago

Hi, I am trying to train a GAN model using PyTorch Lightning and I found this example in the repo: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/domain_templates/generative_adversarial_net.py

I have read the comments above about having to add the self.zero_grad() line first. However, I cannot find this line in the example. Is it still necessary to add it?

I found that in the example you have used detach() when training the discriminator: https://github.com/PyTorchLightning/pytorch-lightning/blob/5c35db94fa338b8720175541a3b0dc872d963bdf/pl_examples/domain_templates/generative_adversarial_net.py#L152 Does that mean that if I use detach() there, there is no need to add the self.zero_grad() line? Please clarify. Thanks.

kushalchordiya216 commented 4 years ago

> Yes, it would. Make the following changes to fix it:
>
> Prevent gradients accumulating on the generator during the discriminator pass: D_fake = self.D(x_gen) -> D_fake = self.D(x_gen.detach())
>
> Prevent gradients accumulating on the discriminator during the generator pass: D_fake = self.D(x_gen) ->
>
> with torch.no_grad():
>     D_fake = self.D(x_gen)
>
> Hope that helps.

Hey, I had a question: shouldn't it technically work just as well (or perhaps even a bit faster) if the discriminator training step also used a torch.no_grad() block instead of detaching the tensor? Either way, the computations done in the generator are detached from the computation graph, but with no_grad less memory will be used, right?
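For what it's worth, a minimal sketch of that idea, generating the fake samples under no_grad inside the discriminator branch instead of detaching; the names follow the training_step posted earlier and are otherwise assumptions:

    if optimizer_idx == 0:  # discriminator step
        # No graph is built through the generator at all, which uses less
        # memory than building it and then detaching the output.
        with torch.no_grad():
            x_gen = self.G(z)
        D_real = self.D(x_data)
        D_fake = self.D(x_gen)  # x_gen has no history, so G receives no gradients
        d_loss = self.loss(D_real, real_labels) + self.loss(D_fake, fake_labels)
        return {'loss': d_loss, 'progress_bar': {'d_loss': d_loss}}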

kracwarlock commented 4 years ago
> with torch.no_grad():
>     D_fake = self.D(x_gen)

This should not work unless backward() is also called within the with torch.no_grad(): block, right?
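A tiny standalone check of that point (toy module, arbitrary shapes, purely illustrative): a result computed entirely under no_grad carries no graph, so backward() fails:

    import torch

    D = torch.nn.Linear(2, 1)
    x_gen = torch.randn(4, 2)

    with torch.no_grad():
        d_fake = D(x_gen)

    g_loss = d_fake.mean()
    g_loss.backward()  # RuntimeError: the tensor does not require grad and has no grad_fn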

williamFalcon commented 4 years ago

Here's a better example. We need to remove those and just point to bolts.

https://pytorch-lightning-bolts.readthedocs.io/en/latest/gans.html#basic-gan

kracwarlock commented 4 years ago

Yes, with the PR merged things will work. I was only trying to discuss the workaround mentioned earlier :)