Closed ayberkydn closed 4 years ago
sorry for the trouble! good catch though. Thanks for the PR, we’ll merge it once tests pass :)
feel free to submit the PR when ready
Whoops, I just have failed to submit the PR. I think it is ready now. Appreciate your work very much. Thanks for the great library :)
hi, at work we train GANs successfully without any optimizer overrides. We also have the zero_grad before optimizer step
my teammate recommends that you put no_grad around your discriminator forward, and your runtime would be faster than calling zero_grad in the proposed way. Would that work for you?
Hi, I'm that teammate. Strictly speaking you shouldn't even need to zero gradients at the start of optimizer_step, or have to overload optimizer_step. Between with torch.no_grad(): and Tensor.detach(), you can prevent unwanted gradient accumulation.
Hi, teammates. I was getting good results as I added a zero_grad before my training loop but the problem is, the training loop should not include these tricks for the generic training regimes like GAN in the first place. The mentality of Lightning is to just define a loss with respect to your models and leave that kind of boilerplate to the framework, that's why I thought it would be better to submit a PR instead of hacking my training loop or advising it in the docs as I mentioned in the PR message :)
can you guys post pseudocode for both versions to see what you mean more clearly?
def training_step(self, batch, batch_nb, optimizer_idx):
    self.zero_grad()
    x_data = batch
    batch_size = x_data.shape[0]
    z = torch.randn(batch_size, 100).cuda()
    x_gen = self.G(z)
    real_labels = torch.ones(batch_size).cuda()
    fake_labels = torch.zeros(batch_size).cuda()
    if optimizer_idx == 0:
        D_real = self.D(x_data)
        D_fake = self.D(x_gen)
        real_loss = self.loss(D_real, real_labels)
        fake_loss = self.loss(D_fake, fake_labels)
        d_loss = real_loss + fake_loss
        return {
            'loss': d_loss,
            'progress_bar': {'d_loss': d_loss}
        }
    if optimizer_idx == 1:
        D_fake = self.D(x_gen)
        g_loss = self.loss(D_fake, real_labels)
        return {
            'loss': g_loss,
            'progress_bar': {'g_loss': g_loss}
        }
This code works, but it fails without self.zero_grad().
Yes, it would. Make the following changes to fix it:
Prevent gradients accumulating on the generator during the discriminator pass:
D_fake = self.D(x_gen)
-> D_fake = self.D(x_gen.detach())
Prevent gradients accumulating on the discriminator during the generator pass:
D_fake = self.D(x_gen)
->
with torch.no_grad():
    D_fake = self.D(x_gen)
Hope that helps.
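To make the detach() suggestion for the discriminator pass concrete, here is a minimal sketch in plain PyTorch. The two nn.Linear layers are hypothetical stand-ins for G and D, and the tensor sizes are arbitrary; this is not the code from the thread, only an illustration that the generator receives no gradients when its output is detached.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Tiny stand-ins for the generator and discriminator (hypothetical sizes).
G = nn.Linear(4, 8)
D = nn.Linear(8, 1)
loss_fn = nn.BCEWithLogitsLoss()

z = torch.randn(16, 4)
x_gen = G(z)
fake_labels = torch.zeros(16, 1)

# Discriminator pass: detach() cuts the graph at x_gen, so backward()
# stops at the discriminator's input and never reaches G.
d_fake_loss = loss_fn(D(x_gen.detach()), fake_labels)
d_fake_loss.backward()

g_grads_after_d_pass = [p.grad for p in G.parameters()]
d_grads_after_d_pass = [p.grad for p in D.parameters()]
```

After the backward call, every entry of g_grads_after_d_pass is still None, while the discriminator's parameters all carry gradients.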
I've tried these and it works. I mean don't get me wrong, it's a really smart and elegant way to handle things, it even speeds up the code by ~10% but I think people shouldn't think about what gradients to prevent and what to let when training different models when it can be solved by a simple change at the high level that covers most of the cases. I really can't think of a case why one needs to hold any gradients after optimizer step. Also, your no_grad method would still be relevant after this PR.
Ok, sounds like we actually want to make sure users think about this when training GANs. But totally agree that it’s a bit annoying.
Let’s make these changes:
If it turns out that other users run into this as well, we can make it the default.
thank you all for a great discussion!
Looking at the gan.py example: does line 121 there (https://fburl.com/p8jnvbpf) need to be updated to include 'with torch.no_grad():' beforehand, as suggested by @dbarnhart?
It seems like currently the discriminator is accumulating gradients from both the generator and the discriminator phases of the training step.
Yes, it is currently incorrect but I think it won't be an issue after #603 is done.
I've attempted to include no_grad() at the generator step, but using no_grad() removes the gradient history entirely, which in effect cuts the feedback from the discriminator when training the generator model.
The alternatives I came up with were either wrapping the discriminator call with "for param in self.discriminator.parameters(): param.requires_grad = False", or call discriminator.zero_grad() before returning the discriminator loss (faster): I've submitted fix #666 in the meantime to fix the GAN example, as it seems people are blocked on training GANs with PTL due to the same issue (see #557).
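The difference between the two approaches for the generator step can be checked with a small sketch. It assumes toy nn.Linear stand-ins for the generator and discriminator (hypothetical, not the example's real models) and compares a discriminator forward under no_grad() with temporarily freezing the discriminator's parameters.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical toy models standing in for the generator and discriminator.
G = nn.Linear(4, 8)
D = nn.Linear(8, 1)
loss_fn = nn.BCEWithLogitsLoss()

z = torch.randn(16, 4)
real_labels = torch.ones(16, 1)

# Variant 1: discriminator forward under no_grad. No graph is recorded
# for this call, so the loss cannot be backpropagated to the generator.
x_gen = G(z)
with torch.no_grad():
    blocked_loss = loss_fn(D(x_gen), real_labels)
loss_is_differentiable = blocked_loss.requires_grad

# Variant 2: freeze the discriminator's parameters for the generator pass.
for p in D.parameters():
    p.requires_grad_(False)

g_loss = loss_fn(D(G(z)), real_labels)
g_loss.backward()  # gradients flow through D's computation into G

# Unfreeze so the discriminator can be trained on its own pass.
for p in D.parameters():
    p.requires_grad_(True)

g_has_grads = all(p.grad is not None for p in G.parameters())
d_has_grads = any(p.grad is not None for p in D.parameters())
```

In this sketch the no_grad() loss is not differentiable at all, whereas the frozen-parameter variant gives the generator its gradients and leaves the discriminator's .grad fields untouched.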
Resetting gradients this way would rule out the gradient accumulation option.
By the way, on_before_zero_grad does not seem to be called anywhere right now.
This still holds.
Hi, I am trying to train a GAN model using PyTorch Lightning and I found this example in this repo: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/domain_templates/generative_adversarial_net.py
I have read the comment above about having to add the line self.zero_grad() first. However, I cannot find this line in the above example. Is it still necessary to add it?
I found that the above example uses detach() when training the discriminator:
https://github.com/PyTorchLightning/pytorch-lightning/blob/5c35db94fa338b8720175541a3b0dc872d963bdf/pl_examples/domain_templates/generative_adversarial_net.py#L152
Does it mean that if I use detach() there, there is no need to add the line self.zero_grad() first? Please clarify. Thanks
Yes, it would. Make the following changes to fix it:
Prevent gradients accumulating on the generator during the discriminator pass:
D_fake = self.D(x_gen)
-> D_fake = self.D(x_gen.detach())
Prevent gradients accumulating on the discriminator during the generator pass:
D_fake = self.D(x_gen)
->
with torch.no_grad():
    D_fake = self.D(x_gen)
Hope that helps.
Hey, I had a question: shouldn't it technically work just as well (or perhaps even a bit faster) if the discriminator train step uses torch.no_grad() instead of detaching the tensor? Either way, the computations done in the generator are detached from the computation graph, but with no_grad less memory will be used, right?
with torch.no_grad():
    D_fake = self.D(x_gen)
This should not work unless backward() is also called within the with torch.no_grad(): block, right?
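One way to read the question about the discriminator step is to run only the generator under no_grad() and keep the discriminator's forward outside the block. The sketch below, which uses hypothetical toy nn.Linear models and two identical copies of the discriminator, suggests that this gives the discriminator the same gradients as detach() without recording the generator's graph.

```python
import copy

import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical toy models; D_no_grad starts as an exact copy of D_detach.
G = nn.Linear(4, 8)
D_detach = nn.Linear(8, 1)
D_no_grad = copy.deepcopy(D_detach)
loss_fn = nn.BCEWithLogitsLoss()

z = torch.randn(16, 4)
fake_labels = torch.zeros(16, 1)

# Option A: build the generator graph, then cut it with detach().
loss_fn(D_detach(G(z).detach()), fake_labels).backward()

# Option B: never build the generator graph. Only G runs under no_grad;
# D's forward stays outside so its own graph is still recorded.
with torch.no_grad():
    x_gen = G(z)
loss_fn(D_no_grad(x_gen), fake_labels).backward()

same_d_grads = all(
    torch.allclose(a.grad, b.grad)
    for a, b in zip(D_detach.parameters(), D_no_grad.parameters())
)
g_untouched = all(p.grad is None for p in G.parameters())
```

Both options leave the generator's gradients untouched. Any memory saving from Option B depends on the generator's forward not being needed for a later generator pass, which is an assumption of this sketch.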
Here's a better example. We need to remove those and just point to bolts.
https://pytorch-lightning-bolts.readthedocs.io/en/latest/gans.html#basic-gan
Yes, with the PR merged things will work. I was only trying to discuss the workaround mentioned earlier :)
I was trying to train a DCGAN on my dataset, but it wouldn't work by any means until I detached the training logic from Lightning and ran the code without it. It was not working when my training logic was in the Lightning module. I checked the GAN examples in the docs and also the multiple-optimizer material. After 2 days of headaches, source code inspection and putting numerous print statements in the Lightning source code, I found the culprit.
GAN training with PyTorch Lightning is simply broken. The culprit is that optimizer.zero_grad() is only called after optimizer.step(), so it clears the gradients of either the Generator or the Discriminator only. Before the weights are updated for, say, the Generator, loss.backward() is called and it populates gradients for all parameters, but when optimizer.zero_grad() is called after the parameters are updated, only the Generator's gradients are cleared. So when it comes to the Discriminator's loss.backward(), the leftover gradients are accumulated on the Discriminator's parameters and it just messes everything up. Any kind of GAN training is impossible with these settings. That's why you cannot find any GAN implementations with PyTorch Lightning on the internet.
Possible solutions:
- Putting a warning in the docs or on the console after detecting that multiple optimizers are defined, like "if you train GAN, don't forget to zero all gradients by overriding the optimizer_step method or just reset gradients in your training loop before returning loss dictionary". (would be weird, honestly)
- Calling zero_grad before calling backward in optimizer_closure (but it would mess with gradient accumulation, I suppose)
- Calling self.zero_grad() instead of optimizer.zero_grad() by default in optimizer_step.
- Calling zero_grad for all optimizers of the Lightning module after the gradient step by default, so that the hook on_before_zero_grad is called whenever appropriate.
(By the way, on_before_zero_grad does not seem to be called anywhere right now. Maybe the issue can be fixed with a default behavior of that method, alternatively.)
I just sent a pull request which implements the last option and updates the docs.
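The leak described above can be reproduced outside Lightning with a short sketch. TinyGAN, its layer sizes and the helper functions are hypothetical names introduced only for illustration; the sketch shows that a per-optimizer zero_grad() after the generator step leaves gradients on the discriminator, while a module-wide zero_grad() clears them.

```python
import torch
from torch import nn

torch.manual_seed(0)


class TinyGAN(nn.Module):
    """Hypothetical container holding a toy generator and discriminator."""

    def __init__(self):
        super().__init__()
        self.G = nn.Linear(4, 8)
        self.D = nn.Linear(8, 1)


model = TinyGAN()
g_opt = torch.optim.SGD(model.G.parameters(), lr=0.1)
loss_fn = nn.BCEWithLogitsLoss()


def generator_backward():
    # The generator loss flows through D, so D's parameters get gradients too.
    z = torch.randn(16, 4)
    g_loss = loss_fn(model.D(model.G(z)), torch.ones(16, 1))
    g_loss.backward()


def d_has_leftover():
    # True if any discriminator parameter still holds a non-zero gradient.
    return any(
        p.grad is not None and bool(p.grad.abs().sum() > 0)
        for p in model.D.parameters()
    )


generator_backward()
g_opt.step()
g_opt.zero_grad()  # clears only the parameters owned by g_opt (the generator)
leftover_after_optimizer_zero_grad = d_has_leftover()

model.zero_grad()  # clears every parameter registered on the module
leftover_after_module_zero_grad = d_has_leftover()
```

Without the module-wide reset, the next discriminator backward() would add its gradients on top of the ones left over from the generator pass.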