Closed: DianeBouchacourt closed this issue 3 years ago
With a bit more research, it comes from the call to optimizer.step() here.
Commenting it out works fine, but then how do I perform the optimization step of the model before training the discriminator?
Hi @DianeBouchacourt,
This seems to be an issue with PyTorch > 1.4 (e.g. )
That is, the code currently modifies the weights before all the gradients have been computed, which is not allowed from PyTorch 1.5 onward.
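For anyone who wants to see the failure in isolation, here is a minimal sketch of the pattern described above. The tiny `encoder` and `discriminator` modules, the toy `vae_loss`, and the function name are hypothetical stand-ins and not the repository's code. The sketch assumes PyTorch 1.5 or later.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-ins for the VAE encoder and the FactorVAE discriminator.
encoder = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
discriminator = nn.Linear(4, 2)
optimizer = torch.optim.SGD(encoder.parameters(), lr=0.1)
optimizer_d = torch.optim.SGD(discriminator.parameters(), lr=0.1)


def update_before_second_backward():
    """Return the autograd error message, or None if no error was raised."""
    x = torch.randn(32, 8)
    zeros = torch.zeros(32, dtype=torch.long)

    d_z = discriminator(encoder(x))             # one forward pass shared by both losses
    vae_loss = (d_z[:, 0] - d_z[:, 1]).mean()   # toy total-correlation-style term
    d_loss = F.cross_entropy(d_z, zeros)

    optimizer.zero_grad()
    vae_loss.backward(retain_graph=True)
    optimizer.step()                            # updates the encoder weights in place

    optimizer_d.zero_grad()
    try:
        d_loss.backward()                       # needs the encoder weights saved at forward time
    except RuntimeError as err:
        return str(err)
    return None


message = update_before_second_backward()
print(message)
```

The second backward pass goes through the same graph as the first one, so it needs the encoder weights exactly as they were during the forward pass. The optimizer step in between has already changed them.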
I see 2 possibilities:
It should work now with PyTorch 1.7 (I only tested factor_mnist). Let me know if it doesn't, and thank you for pointing out the issue.
All good, it works now, but doesn't it change the optimization dynamics? The model parameters are now no longer updated via optimizer.step() before latent_sample2 = model.sample_latent(data2) is called.
I believe it does slightly change the dynamics: the discriminator now tries to discriminate samples from "the previous step", but that shouldn't make a big difference. I don't see how to bypass that "issue" without computing the first sample twice (once for each optimizer). If you have a better suggestion, I'm happy to change that.
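A rough sketch of the ordering discussed here: both backward passes run first, and both optimizer steps are applied only at the end. The modules, the toy `vae_loss`, and the helper names `permute_dims` and `train_step` are illustrative assumptions, not the repository's implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-ins for the VAE encoder and the discriminator.
encoder = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
discriminator = nn.Linear(4, 2)
optimizer = torch.optim.SGD(encoder.parameters(), lr=0.1)
optimizer_d = torch.optim.SGD(discriminator.parameters(), lr=0.1)


def permute_dims(z):
    """Shuffle each latent dimension independently across the batch."""
    cols = [z[torch.randperm(z.size(0)), j] for j in range(z.size(1))]
    return torch.stack(cols, dim=1)


def train_step(data1, data2):
    ones = torch.ones(data1.size(0), dtype=torch.long)
    zeros = torch.zeros_like(ones)

    d_z = discriminator(encoder(data1))
    vae_loss = (d_z[:, 0] - d_z[:, 1]).mean()   # toy stand-in for the full VAE loss

    optimizer.zero_grad()
    vae_loss.backward(retain_graph=True)        # keep the graph for the second backward

    # The second sample is drawn from the not-yet-updated encoder.
    z_perm = permute_dims(encoder(data2)).detach()
    d_z_perm = discriminator(z_perm)
    d_tc_loss = 0.5 * (F.cross_entropy(d_z, zeros) + F.cross_entropy(d_z_perm, ones))

    optimizer_d.zero_grad()
    d_tc_loss.backward()                        # runs before any parameter is updated

    # Apply both updates only after every backward pass has finished.
    optimizer.step()
    optimizer_d.step()
    return vae_loss.item(), d_tc_loss.item()


losses = train_step(torch.randn(16, 8), torch.randn(16, 8))
print(losses)
```

Compared with stepping the model first, the only difference is that the second sample is drawn from the encoder weights of the previous step.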
What is important is that the gradient computation is correct. I think it is, because we are detaching the second sample, z_perm = _permute_dims(latent_sample2).detach(),
and so no gradients flow into the VAE from that sample when computing the gradients for the discriminator.
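The effect of the detach on the permuted term can be checked with a small sketch. The linear `encoder` is a hypothetical stand-in for model.sample_latent, and a single row permutation replaces _permute_dims for brevity.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

encoder = nn.Linear(8, 4)        # hypothetical stand-in for model.sample_latent
discriminator = nn.Linear(4, 2)

data2 = torch.randn(16, 8)
ones = torch.ones(16, dtype=torch.long)

latent_sample2 = encoder(data2)
# Simplified permutation (whole rows), then cut the graph with detach().
z_perm = latent_sample2[torch.randperm(16)].detach()

loss_perm = F.cross_entropy(discriminator(z_perm), ones)
loss_perm.backward()

encoder_has_grad = any(p.grad is not None for p in encoder.parameters())
discriminator_has_grad = all(p.grad is not None for p in discriminator.parameters())
print(encoder_has_grad, discriminator_has_grad)  # prints: False True
```

This only covers the term built from the detached sample. The encoder parameters receive no gradient from it, while the discriminator parameters do.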
OK, thanks for the quick reply!
When running python factor_coloredmnist -x factor_coloredmnist on Python 3.8.5, I get the following error:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [256, 20]], which is output 0 of TBackward, is at version 3; expected version 2 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
It comes from d_tc_loss.backward().
I tried setting inplace=False in the leaky_relu of the discriminator, without success. The error comes from the call to F.cross_entropy(d_z, zeros) in d_tc_loss (the term F.cross_entropy(d_z_perm, ones) poses no problem).
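The in-place operation named in the message does not have to be an activation. An optimizer step also writes to the parameters in place, which would be consistent with inplace=False not helping. A small sketch of how to observe this through the tensor's version counter follows. It reads the private `_version` attribute, which is an implementation detail and may change, and the single linear layer is a hypothetical example.

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 2)
optimizer = torch.optim.SGD(layer.parameters(), lr=0.1)

# Populate .grad so that the optimizer has something to apply.
layer(torch.randn(3, 4)).sum().backward()

version_before = layer.weight._version
optimizer.step()                      # in-place parameter update
version_after = layer.weight._version
print(version_before, version_after)  # the counter increases after the step
```

Autograd compares this counter with the value recorded at forward time, which is what the "is at version 3; expected version 2" part of the message refers to.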
Any help would be appreciated :)