1Konny / FactorVAE

PyTorch implementation of FactorVAE, proposed in Disentangling by Factorising (http://arxiv.org/abs/1802.05983)
MIT License

Inplace problem when calling D_tc_loss.backward() #15

Open anewusername77 opened 4 years ago

anewusername77 commented 4 years ago

The error message is:

one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [4096, 6]], which is output 0 of TBackward, is at version 2; expected version 1 instead.

It seems the problem is that a term (D_z) used in D_tc_loss, which was computed during the vae_loss stage, was somehow modified in between.

D_tc_loss = 0.5*(F.cross_entropy(D_z, zeros) + F.cross_entropy(D_z_pperm, ones))

To give more detail: I computed the VAE loss and updated the VAE first:

 vae_loss.backward(retain_graph=True)
 self.optim_VAE.step()

Then, when updating the discriminator:

D_z_pperm = self.D(z_pperm)
D_tc_loss = 0.5*(F.cross_entropy(D_z, zeros) + F.cross_entropy(D_z_pperm, ones))

self.optim_D.zero_grad()
D_tc_loss.backward()
self.optim_D.step()

the error message described at the beginning occurs.

When I delete the term F.cross_entropy(D_z, zeros) from D_tc_loss, or change D_tc_loss into

D_tc_loss = 0.5*(F.cross_entropy(D_z.detach(), zeros) + F.cross_entropy(D_z_pperm, ones))

everything works fine. But I'm not sure whether using `.detach()` here is correct, and I'm wondering what the exact problem is (a toy repro is sketched below). Waiting for your reply, thanks a lot.
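
For completeness, here is a minimal standalone sketch that I believe reproduces the same error outside the repo. The Linear layers are toy stand-ins for the encoder and discriminator, not the actual FactorVAE models:

import torch
import torch.nn as nn
import torch.nn.functional as F

enc = nn.Linear(3, 3)    # toy stand-in for the VAE encoder
disc = nn.Linear(3, 2)   # toy stand-in for the discriminator
optim_VAE = torch.optim.SGD(enc.parameters(), lr=0.1)
optim_D = torch.optim.SGD(disc.parameters(), lr=0.1)

x = torch.randn(4, 3)
z = enc(x)
D_z = disc(z)
zeros = torch.zeros(4, dtype=torch.long)

vae_loss = (D_z[:, 0] - D_z[:, 1]).mean()   # stands in for the full VAE loss
optim_VAE.zero_grad()
vae_loss.backward(retain_graph=True)
optim_VAE.step()                            # in-place update of enc's weights

D_tc_loss = F.cross_entropy(D_z, zeros)     # mirrors the real-sample term
optim_D.zero_grad()
D_tc_loss.backward()                        # raises "... modified by an inplace operation"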

FrankBrongers commented 3 years ago

I ran into the same issue and think it's because the VAE optimizer performs a step before the backward pass for the discriminator is done, modifying in place the weights that the computation graph of z still needs.

The solution would then be to move self.optim_VAE.step() to after D_tc_loss.backward(), which seems to work for me. If we detach D_z in the D_tc_loss calculation, I would imagine that term contributes nothing to the backward pass or to the optimizer step, so the discriminator would not learn from the real (non-permuted) samples.
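
Concretely, a minimal sketch of the reordered update, reusing the names from the snippets above (only the position of the optimizer steps changes; this is how I would expect it to look, not the exact code from the repo):

self.optim_VAE.zero_grad()
vae_loss.backward(retain_graph=True)   # keep the graph, since D_tc_loss reuses D_z

D_z_pperm = self.D(z_pperm)
D_tc_loss = 0.5*(F.cross_entropy(D_z, zeros) + F.cross_entropy(D_z_pperm, ones))

self.optim_D.zero_grad()
D_tc_loss.backward()

# step only after both backward passes, so no weight that the retained
# graph still needs has been modified in place
self.optim_VAE.step()
self.optim_D.step()

One caveat: since D_z depends on z, D_tc_loss.backward() also accumulates gradients on the VAE encoder, and with this ordering those gradients end up included in optim_VAE.step(); with the original ordering they were discarded by the next zero_grad().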