daib13 / TwoStageVAE


Could you please elaborate on your loss function? #7

mmderakhshani opened this issue 4 years ago

mmderakhshani commented 4 years ago

I am trying to reimplement your code in PyTorch, and I would like to know the difference between your loss function and the loss function of a vanilla VAE. Based on my experience, your KL-divergence formula either is not correct or is not the same as the one we see in regular VAE implementations; there is a subtle difference between them. Could you please explain it a little more? I also have a question about self.gen_loss1. Could you please explain it?

Another thing worth mentioning: when I optimize the stage 1 network, by the end of training I get negative values for loss_gen1, which I think are related to self.loggamma_x. When I disabled it and left it constant at 0, the loss values did not become negative.

daib13 commented 4 years ago

Hi @mmderakhshani

  1. KL loss. What do you mean by a regular implementation? The KL loss is exactly the same as Equation 6 in "Tutorial on Variational Autoencoders" (it is spelled out in the sketch after item 2 below). Note that the approximate posterior for the j-th dimension is N(\mu_j, \sigma_j^2). Some implementations may use N(\mu_j, \sigma_j). This could be a potential difference, but it makes no difference in the generative performance.

  2. self.gen_loss1 = -log p_\theta(x|z), where p_\theta(x|z) is a Gaussian distribution, i.e. N(x | \hat{x}, \gamma I). Then we have

self.gen_loss1 = \sum ( (x- \hat{x})^2 / \gamma / 2 - log \gamma ) + constant,

which is our implementation. Yes, self.gen_loss1 can be negative. As we argued in our paper "Diagnosing and Enhancing VAE Models", \gamma will converge to 0 and self.gen_loss1 will go to negative infinity when the objective function is globally optimized. Of course you can fix self.loggamma_x to be 0, but this will make the reconstruction blurry. Intuitively speaking, as \hat{x} becomes exactly the same as x, meaning the model produces a perfect reconstruction, the only term related to \gamma in the objective is the log \gamma term, which will push \gamma to 0 and the objective to negative infinity.
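For readers following the PyTorch reimplementation, here is a minimal sketch of the two terms described above. The repository itself is in TensorFlow; the function and variable names below are illustrative, and additive constants are dropped:

```python
import torch

def stage1_losses(x, x_hat, mu, logvar, loggamma_x):
    # KL( N(mu, diag(exp(logvar))) || N(0, I) ), the closed form from the VAE tutorial:
    #   0.5 * sum( mu^2 + sigma^2 - log sigma^2 - 1 )
    kl_loss = 0.5 * torch.sum(mu.pow(2) + logvar.exp() - logvar - 1.0)

    # gen_loss1 = -log N(x | x_hat, gamma * I), dropping the (D/2) * log(2*pi) constant:
    #   sum( (x - x_hat)^2 / (2 * gamma) + 0.5 * log(gamma) )
    # With loggamma_x learnable, this term can indeed become negative as gamma -> 0.
    gamma = loggamma_x.exp()
    gen_loss1 = torch.sum((x - x_hat).pow(2) / (2.0 * gamma) + 0.5 * loggamma_x)

    return kl_loss, gen_loss1
```

Here loggamma_x would be a learnable scalar, e.g. torch.nn.Parameter(torch.zeros(())). Note the 1/2 coefficient on log \gamma under the N(\hat{x}, \gamma I) convention used in this discussion; the repository code parameterizes the variance as \gamma^2 (see the exchange further down), in which case the coefficient becomes 1.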

mmderakhshani commented 4 years ago

Ok. Thanks for the paper you referred to and again thanks for your great explanation.

Could you please tell me how you handled the negative loss case? Did you follow some kind of policy?

As another question, I have seen in your code that you use the Adam optimizer to optimize the parameters of the network, and that at the beginning of each epoch you change the learning rate (some kind of decay strategy). As far as I know, Adam on its own already adapts the effective step size for each parameter based on its update rule. Could you please tell me why you change the learning rate manually?

daib13 commented 4 years ago

@mmderakhshani

  1. About the negative loss. We just leave it negative. There is no need to force the loss to be positive.

  2. About the Adam optimizer and the learning rate. I just picked an optimization strategy somewhat arbitrarily; I don't know much about optimization. Maybe there is no need to manually change the learning rate, as you said. I am not sure which way is better.
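As a side note for the PyTorch reimplementation: a per-epoch decay on top of Adam can be expressed with a standard scheduler. A minimal sketch, with a placeholder model and an illustrative decay factor rather than the repository's actual hyperparameters:

```python
import torch

model = torch.nn.Linear(784, 64)  # placeholder model for illustration
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# Multiply the base learning rate by 0.97 after each epoch; Adam's per-parameter
# adaptation still applies on top of this global schedule. (The scheduler's `gamma`
# argument is unrelated to the decoder variance \gamma discussed above.)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.97)

num_epochs = 400  # illustrative
for epoch in range(num_epochs):
    # ... run the training steps for this epoch, calling optimizer.step() ...
    scheduler.step()
```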

chanshing commented 4 years ago

self.gen_loss1 = \sum ( (x- \hat{x})^2 / \gamma / 2 - log \gamma ) + constant,

I think the terms in the summation should be added, not subtracted:

\sum ( (x- \hat{x})^2 / \gamma / 2 + log \gamma )

daib13 commented 4 years ago

@chanshing yes you are correct. I made a typo in the response. Thanks for pointing this out.

mago876 commented 4 years ago

I'm having trouble with \gamma: in your code it seems that p_\theta(x|z) = N(x | \hat{x}, \gamma^2 I), so that

self.gen_loss1 = \sum ( (x - \hat{x})^2 / \gamma^2 / 2 + log \gamma ) + constant.

Is that right?

daib13 commented 4 years ago

@mago876 Yeah, in the code we use N(x | \hat{x}, \gamma^2 I). In the paper and the discussion, we use N(x | \hat{x}, \gamma I). Sorry for the confusion. In our original paper draft, we used the same formulation as in the code, but we changed it to the current version for convenience and didn't change the code accordingly.
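To make the two conventions concrete: with p_\theta(x|z) = N(x | \hat{x}, \gamma^2 I), the per-dimension negative log-likelihood is (x - \hat{x})^2 / (2 \gamma^2) + log \gamma plus a constant, so the coefficient on log \gamma is exactly 1, whereas the N(x | \hat{x}, \gamma I) convention gives 1/2. A minimal PyTorch sketch of the \gamma^2 version, assuming loggamma_x stores log \gamma (the repository's self.loggamma_x may be defined differently, so treat this as an illustration of the convention rather than a drop-in match):

```python
import torch

def gen_loss1_gamma_squared(x, x_hat, loggamma_x):
    # p(x|z) = N(x | x_hat, gamma^2 * I), with loggamma_x assumed to be log(gamma).
    # -log p = sum( (x - x_hat)^2 / (2 * gamma^2) + log(gamma) ) + constant
    gamma_sq = (2.0 * loggamma_x).exp()
    return torch.sum((x - x_hat).pow(2) / (2.0 * gamma_sq) + loggamma_x)
```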