Closed: DianeBouchacourt closed this issue 3 years ago
With a bit more research, it looks like the error comes from the call to optimizer.step() here: https://github.com/YannDubs/disentangling-vae/blob/535bbd2e9aeb5a200663a4f82f1d34e084c4ba8d/disvae/models/losses.py#L283
Commenting it out makes the run work, but then how do we perform the model's optimization step before training the discriminator?
Hi @DianeBouchacourt,
This seems to be an issue with PyTorch > 1.4 (e.g. https://github.com/pytorch/pytorch/issues/39141#issuecomment-636881953 ).
I.e., I'm currently modifying the weights before finishing computing the gradients, which is disallowed from PyTorch 1.5 onwards.
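For context, here is a minimal, self-contained reproducer of what PyTorch >= 1.5 forbids (illustrative only, not the repo's code): stepping an optimizer while a graph that saved those weights still has a pending backward().

```python
import torch

# Minimal reproducer (illustrative, not the repo's code): step() mutates
# a weight in-place while a retained graph still needs its saved value.
lin = torch.nn.Linear(2, 2)
opt = torch.optim.SGD(lin.parameters(), lr=0.1)

x = torch.randn(4, 2, requires_grad=True)
y = lin(x)                        # the graph saves lin.weight (needed for grad w.r.t. x)
loss1 = y.pow(2).mean()
loss1.backward(retain_graph=True)
opt.step()                        # in-place update bumps lin.weight's version counter

loss2 = y.sum()
loss2.backward()                  # RuntimeError on PyTorch >= 1.5: "one of the
                                  # variables needed for gradient computation has
                                  # been modified by an inplace operation"
```

If I understand the linked issue correctly, on PyTorch <= 1.4 the same pattern ran silently, with the second backward just using the already-updated weights, which is why the code appeared to work before.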
I see 2 possibilities: either let the discriminator train on samples from the previous step (detaching them so the VAE's graph is not needed), or compute the first sample twice, once for each optimizer.
It should work now with PyTorch 1.7 (I only tested factor_mnist). Let me know if it doesn't, and thank you for pointing out the issue.
All good, it works now. But doesn't it change the optimization dynamics? The model parameters are no longer updated via optimizer.step() before latent_sample2 = model.sample_latent(data2) is called.
I believe that it does slightly change the dynamics (the discriminator now tries to discriminate samples from "the previous step"), but that shouldn't make a big difference. I don't see how to bypass that "issue" without computing the first sample twice (once for each optimizer)... But if you have a better suggestion I'm happy to change it.
What is important is that the gradient computation is correct, and I think it is, because we detach the second sample, z_perm = _permute_dims(latent_sample2).detach(), so no gradients flow into the VAE when computing the gradients for the discriminator.
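For concreteness, here is a hedged sketch of that update order with stand-in modules (the names mirror losses.py, but this is not the exact code): both backward() calls complete before either optimizer steps, and the permuted sample is detached so that term sends no gradient into the VAE.

```python
import torch
import torch.nn.functional as F

# Sketch of the post-fix update order (stand-in modules, not the repo's
# exact code): all backwards finish before any optimizer step.
batch, x_dim, z_dim = 32, 5, 10
vae = torch.nn.Linear(x_dim, z_dim)            # stand-in for model.sample_latent
discriminator = torch.nn.Linear(z_dim, 2)      # stand-in for the TC discriminator
optimizer = torch.optim.Adam(vae.parameters())
optimizer_d = torch.optim.Adam(discriminator.parameters())

data, data2 = torch.randn(batch, x_dim), torch.randn(batch, x_dim)
zeros = torch.zeros(batch, dtype=torch.long)
ones = torch.ones(batch, dtype=torch.long)

latent_sample = vae(data)
d_z = discriminator(latent_sample)
vae_loss = (d_z[:, 0] - d_z[:, 1]).mean()      # placeholder for the full VAE loss
optimizer.zero_grad()
vae_loss.backward(retain_graph=True)           # keep the graph: d_z is reused below

latent_sample2 = vae(data2)
z_perm = latent_sample2[torch.randperm(batch)].detach()  # stand-in for _permute_dims(...).detach()
d_z_perm = discriminator(z_perm)               # detached input: no gradient reaches the VAE here
d_tc_loss = 0.5 * (F.cross_entropy(d_z, zeros) + F.cross_entropy(d_z_perm, ones))
optimizer_d.zero_grad()
d_tc_loss.backward()

optimizer.step()     # the step that used to sit before d_tc_loss.backward()
optimizer_d.step()
```

Note that d_z itself is not detached, so d_tc_loss.backward() still traverses the VAE's graph; that is why vae_loss.backward needs retain_graph=True, and why stepping the VAE optimizer in between used to trip the in-place check.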
OK, thanks for the quick reply!
When running python main.py factor_coloredmnist -x factor_coloredmnist on Python 3.8.5, I get the following error:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [256, 20]], which is output 0 of TBackward, is at version 3; expected version 2 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
The error is raised from d_tc_loss.backward().
I tried setting inplace=False in the discriminator's leaky_relu, without success. The error comes from calling F.cross_entropy(d_z, zeros) in d_tc_loss (the F.cross_entropy(d_z_perm, ones) term poses no problem).
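To double-check why only that term can fail, I put together a stand-alone contrast (stand-in modules, not the repo's code): the loss built on the detached sample never reaches the producing network, while the loss built on the live sample replays its graph and finds a weight that has since been updated in-place.

```python
import torch

# Illustrative contrast (stand-in modules, not the repo's code):
# a two-layer "producer" plays the VAE, "head" plays the discriminator.
producer = torch.nn.Sequential(torch.nn.Linear(3, 8), torch.nn.Linear(8, 10))
head = torch.nn.Linear(10, 2)
opt = torch.optim.SGD(producer.parameters(), lr=0.1)

z = producer(torch.randn(4, 3))
loss_live = head(z).sum()           # analogous to F.cross_entropy(d_z, zeros)
loss_safe = head(z.detach()).sum()  # analogous to F.cross_entropy(d_z_perm, ones)

loss_live.backward(retain_graph=True)
opt.step()                          # in-place update of the producer's weights

loss_safe.backward()                # fine: this graph stops at the detach
loss_live.backward()                # RuntimeError: a saved weight is at the
                                    # wrong version, like the error above
```

So my error suggests something still modifies a tensor saved in d_z's graph before d_tc_loss.backward() runs.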
Any help would be appreciated :)