AntixK / PyTorch-VAE

A Collection of Variational Autoencoders (VAE) in PyTorch.
Apache License 2.0

In FactorVAE: an in-place problem occurs when calling loss.backward() #14

Open anewusername77 opened 3 years ago

anewusername77 commented 3 years ago

The error message is:

one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [4096, 6]], which is output 0 of TBackward, is at version 2; expected version 1 instead.

The problem seems to be that a term (self.D_z_reserve), computed at the vae_loss stage and reused in D_tc_loss, was somehow modified:

D_tc_loss = 0.5 * (F.cross_entropy(self.D_z_reserve, false_labels) + F.cross_entropy(D_z_perm, true_labels))
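The "version" in the error message refers to autograd's per-tensor version counter, which every in-place operation increments, and which a view (such as the output of `.t()`) shares with its base tensor. A minimal sketch, assuming a single hypothetical `nn.Linear` layer and plain SGD, shows that `optimizer.step()` bumps that counter. `_version` is an internal, underscore-prefixed attribute and is read here only for illustration.

```python
import torch
import torch.nn as nn


def version_before_and_after_step():
    """Return the version counter of a weight view before and after one SGD step."""
    torch.manual_seed(0)
    layer = nn.Linear(4, 2)  # hypothetical stand-in for any layer of the model
    opt = torch.optim.SGD(layer.parameters(), lr=0.1)

    # A transposed view shares its version counter with the parameter it views.
    weight_t = layer.weight.t()
    before = weight_t._version

    layer(torch.randn(3, 4)).sum().backward()
    opt.step()  # updates layer.weight in place, which bumps the shared counter

    after = weight_t._version
    return before, after


before, after = version_before_and_after_step()
print(before, after)
```

If autograd saved such a tensor during the forward pass, a later backward that needs it compares the saved version with the current one and raises the error quoted above when they differ.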

In more detail: I first computed the VAE loss and updated the VAE, like this:

self.optim_VAE.zero_grad()
vae_loss.backward(retain_graph=True)
self.optim_VAE.step()

Then, when updating the discriminator:

z = z.detach()
z_perm = self.permute_latent(z)
D_z_perm = self.D(z_perm)
D_tc_loss = 0.5 * (F.cross_entropy(self.D_z_reserve, false_labels) + F.cross_entropy(D_z_perm, true_labels))

self.optim_D.zero_grad()
D_tc_loss.backward()
self.optim_D.step()

The error described at the beginning then occurs.
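The failure can be reproduced outside PyTorch Lightning. The sketch below uses a hypothetical minimal setup (a two-layer encoder, a linear discriminator `D`, Adam, random data, a placeholder VAE loss, and a row shuffle standing in for `permute_latent`). `optim_vae.step()` updates the encoder weights in place after `D_z_reserve` was computed, so the later backward through `D_z_reserve` reaches a saved encoder weight whose version no longer matches.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def reproduce_inplace_error():
    """Return the autograd error message, or None if backward succeeds."""
    torch.manual_seed(0)
    batch, latent_dim = 16, 4

    # Hypothetical stand-ins: a two-layer encoder and a linear discriminator.
    encoder = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, latent_dim))
    D = nn.Linear(latent_dim, 2)
    optim_vae = torch.optim.Adam(encoder.parameters(), lr=1e-3)
    optim_d = torch.optim.Adam(D.parameters(), lr=1e-3)

    x = torch.randn(batch, 8)
    z = encoder(x)
    D_z_reserve = D(z)  # kept for the discriminator step, still attached to the encoder graph

    # Placeholder VAE loss: only the discriminator-based term matters here.
    vae_loss = (D_z_reserve[:, 0] - D_z_reserve[:, 1]).mean()
    optim_vae.zero_grad()
    vae_loss.backward(retain_graph=True)
    optim_vae.step()  # modifies the encoder weights in place

    false_labels = torch.zeros(batch, dtype=torch.long)
    true_labels = torch.ones(batch, dtype=torch.long)
    z_perm = z.detach()[torch.randperm(batch)]  # simplified stand-in for permute_latent
    D_z_perm = D(z_perm)
    D_tc_loss = 0.5 * (F.cross_entropy(D_z_reserve, false_labels)
                       + F.cross_entropy(D_z_perm, true_labels))

    optim_d.zero_grad()
    try:
        D_tc_loss.backward()  # walks back through D_z_reserve into the stale encoder graph
    except RuntimeError as err:
        return str(err)
    return None


print(reproduce_inplace_error())
```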

When I remove the term F.cross_entropy(self.D_z_reserve, false_labels) from D_tc_loss, or change D_tc_loss to

D_tc_loss = 0.5 * (F.cross_entropy(self.D_z_reserve.detach(), false_labels) + F.cross_entropy(D_z_perm, true_labels))

everything works fine. But I'm not sure whether using .detach() here is correct, and I'm wondering what exactly the problem is. Looking forward to your reply, thanks a lot.

AntixK commented 3 years ago

The FactorVAE as implemented here has this issue. I believe it is caused by the way PyTorch Lightning calls loss_function: the self.D_z_reserve variable, which is updated by the model, still keeps track of the gradients from before it was updated again.

I think this can be easily rectified with the latest PyTorch Lightning version, which has improved a lot of things.

Using detach() discards all of the tracked gradient history, so when you call backward(), the term that uses self.D_z_reserve.detach() contributes nothing. So I wouldn't recommend it.
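This point can be checked directly. In the sketch below (a hypothetical linear discriminator `D` and random latents), the gradient of `D`'s weights from a loss whose first term is detached equals the gradient from the permuted-sample term alone, so with `.detach()` the `false_labels` term gives `D` no training signal.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def discriminator_grads():
    torch.manual_seed(0)
    batch, latent_dim = 8, 4
    D = nn.Linear(latent_dim, 2)  # hypothetical discriminator
    z = torch.randn(batch, latent_dim)
    z_perm = z[torch.randperm(batch)]
    false_labels = torch.zeros(batch, dtype=torch.long)
    true_labels = torch.ones(batch, dtype=torch.long)

    # Loss as in the proposed change: the first term is detached.
    D.zero_grad()
    loss_detached = 0.5 * (F.cross_entropy(D(z).detach(), false_labels)
                           + F.cross_entropy(D(z_perm), true_labels))
    loss_detached.backward()
    grad_detached = D.weight.grad.clone()

    # Loss with the first term removed entirely.
    D.zero_grad()
    loss_perm_only = 0.5 * F.cross_entropy(D(z_perm), true_labels)
    loss_perm_only.backward()
    grad_perm_only = D.weight.grad.clone()

    return grad_detached, grad_perm_only


g_detached, g_perm_only = discriminator_grads()
print(torch.allclose(g_detached, g_perm_only))  # True: the detached term adds no gradient
```

The detached term still changes the reported loss value, which can make training look normal while the discriminator only ever learns from the permuted samples.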

federicoromeo commented 1 year ago

Still having this issue. All other implementations have the same issue.

Has anyone solved it? Is it right to detach as suggested?

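One possible workaround (a sketch under stated assumptions, not the repository's fix) is to detach z instead of the discriminator's output and recompute D on it after the VAE step. The false_labels term then still trains D, but its backward pass stops at z and never revisits the already-updated encoder. The module sizes, optimizers, placeholder VAE loss, and row shuffle standing in for `permute_latent` are all hypothetical. If the VAE step also changes D's weights, the recomputed logits use the updated weights, which differs from reusing the stored D_z_reserve.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def train_step_with_recomputed_logits():
    """One VAE step followed by one discriminator step; returns D's weight gradient."""
    torch.manual_seed(0)
    batch, latent_dim = 16, 4

    # Hypothetical stand-ins for the encoder and discriminator.
    encoder = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, latent_dim))
    D = nn.Linear(latent_dim, 2)
    optim_vae = torch.optim.Adam(encoder.parameters(), lr=1e-3)
    optim_d = torch.optim.Adam(D.parameters(), lr=1e-3)

    x = torch.randn(batch, 8)
    z = encoder(x)
    D_z = D(z)

    # VAE step with a placeholder loss; in this sketch retain_graph is no longer needed.
    vae_loss = (D_z[:, 0] - D_z[:, 1]).mean()
    optim_vae.zero_grad()
    vae_loss.backward()
    optim_vae.step()

    # Discriminator step: detach z, not the discriminator's output.
    z = z.detach()
    D_z_recomputed = D(z)  # gradients reach D's weights but stop at z
    z_perm = z[torch.randperm(batch)]  # simplified stand-in for permute_latent
    D_z_perm = D(z_perm)

    false_labels = torch.zeros(batch, dtype=torch.long)
    true_labels = torch.ones(batch, dtype=torch.long)
    D_tc_loss = 0.5 * (F.cross_entropy(D_z_recomputed, false_labels)
                       + F.cross_entropy(D_z_perm, true_labels))

    optim_d.zero_grad()
    D_tc_loss.backward()  # no RuntimeError: the stale encoder graph is not revisited
    optim_d.step()
    return D.weight.grad


print(train_step_with_recomputed_logits())
```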