NVlabs / NVAE

The Official PyTorch Implementation of "NVAE: A Deep Hierarchical Variational Autoencoder" (NeurIPS 2020 spotlight paper)
https://arxiv.org/abs/2007.03898

Tips for latent space manipulation and interpolation? #10

Closed by loc-trinh 3 years ago

loc-trinh commented 3 years ago

Hi there, I was wondering if I can ask a question and get a bit of help about the representation power of NVAE. I would really appreciate it.

In one scenario, I have two people, A and B, and I encode them to get their representations Z_a and Z_b (all 25 Z's for 3 scales). I can somewhat change person A into person B by interpolating, new Z_a = (1 - alpha) Z_a + alpha Z_b, and then decoding. However, the transition is not very smooth, and alpha needs to be high (0.8) before I see a change. Have the authors tried something like this and seen something similar?
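For readers following along, the interpolation described above can be sketched as below. This is only an illustration under assumptions: the latents are represented as plain Python lists of floats (one inner list per latent group), and `lerp_latents` is a hypothetical helper, not a function from the NVAE repo.

```python
def lerp_latents(z_a, z_b, alpha):
    """Blend two sets of latent groups: (1 - alpha) * z_a + alpha * z_b.

    z_a and z_b are lists of latent groups; each group is a list of floats.
    (Hypothetical layout, chosen only to keep the sketch dependency-free.)
    """
    if len(z_a) != len(z_b):
        raise ValueError("both encodings must have the same number of latent groups")
    return [
        [(1.0 - alpha) * a + alpha * b for a, b in zip(group_a, group_b)]
        for group_a, group_b in zip(z_a, z_b)
    ]


# Two toy "encodings", each with two latent groups.
z_a = [[0.0, 2.0], [1.0]]
z_b = [[1.0, 4.0], [3.0]]
z_mid = lerp_latents(z_a, z_b, 0.5)
```

With `alpha = 0` this returns Z_a unchanged, and with `alpha = 1` it returns Z_b, so sweeping alpha from 0 to 1 walks from person A to person B.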

Now suppose I only have Z_a, encoded from person A. I cannot modify Z_a in any way to get a different person C, no matter what I add to it. If I add a large offset along some direction, I get a grainier/noisier image, but not a new person C.

I have also tried disabling the flows, but no luck. I suspect that the mu and sigma coming from the encoder are so strong that I cannot change Z to get a new person (because there is a part where NVAE concatenates the encoding and decoding mu and sigma). But then again, I'm not sure.

One last point. When sampling, I see that we generate a new z_0 by sampling from N(0, 1), and after going through the hierarchical Gaussian we get a nice new image. But if I replace that z_0 with the z_0 from the encoded Z's of person A, I get a junky/bad/trippy image. I am at a loss, because the z_0 of person A going through the hierarchical NVAE does not yield an image of a person.

loc-trinh commented 3 years ago

Hi again, I wanted to ask a follow-up :( I was wondering whether you have tried training NVAE with the red branches in the encoder cut out. For example, in model.py, line 396, dist = Normal(mu_p + mu_q, log_sig_p + log_sig_q) if self.res_dist else Normal(mu_q, log_sig_q) would become dist = Normal(mu_p, log_sig_p).

arash-vahdat commented 3 years ago

Hi @loc-trinh

Regarding your first question, I would recommend generating the interpolations by traversing the epsilon space. To do this, you have to replace every call to sample() in the Normal class with sample_given_eps(). If you change eps smoothly, your images will also change smoothly: https://github.com/NVlabs/NVAE/blob/master/distributions.py#L38
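A rough sketch of what traversing the epsilon space might look like is given below. It assumes that `sample_given_eps()` follows the usual reparameterization z = mu + sigma * eps. `ToyNormal` and `eps_path` are hypothetical stand-ins written for this illustration and are not the repo's `Normal` class, so check distributions.py for the actual behavior.

```python
import math


class ToyNormal:
    """Stand-in diagonal Gaussian (NOT the repo's Normal class)."""

    def __init__(self, mu, log_sigma):
        self.mu = list(mu)
        self.sigma = [math.exp(s) for s in log_sigma]

    def sample_given_eps(self, eps):
        # Assumed reparameterization: z = mu + sigma * eps
        return [m + s * e for m, s, e in zip(self.mu, self.sigma, eps)]


def eps_path(eps_a, eps_b, steps):
    """Return `steps` eps vectors linearly spaced from eps_a to eps_b (inclusive)."""
    if steps < 2:
        raise ValueError("steps must be at least 2")
    path = []
    for i in range(steps):
        t = i / (steps - 1)
        path.append([(1.0 - t) * a + t * b for a, b in zip(eps_a, eps_b)])
    return path


# Fixed distribution parameters, smoothly varying noise.
dist = ToyNormal(mu=[1.0, -1.0], log_sigma=[0.0, 0.0])
zs = [dist.sample_given_eps(eps) for eps in eps_path([0.0, 0.0], [1.0, 2.0], 3)]
```

In a hierarchical model, each latent group would get its own eps tensor, and the same interpolation weight would be applied to all of them at every step of the path.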

Regarding your next question, we cannot remove those red lines. Without them, your encoder becomes independent of the input.
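To make that point concrete, here is a toy sketch of the three parameterizations quoted in the thread. It uses scalars instead of tensors, and the `posterior_params` helper and its mode names are hypothetical. Only the three `Normal(...)` argument patterns come from the discussion above.

```python
def posterior_params(mu_p, log_sig_p, mu_q, log_sig_q, mode):
    """Return the (mu, log_sigma) pair that would be passed to Normal(...)."""
    if mode == "residual":      # Normal(mu_p + mu_q, log_sig_p + log_sig_q)
        return mu_p + mu_q, log_sig_p + log_sig_q
    if mode == "encoder_only":  # Normal(mu_q, log_sig_q)
        return mu_q, log_sig_q
    if mode == "prior_only":    # Normal(mu_p, log_sig_p), the proposed change
        return mu_p, log_sig_p
    raise ValueError(f"unknown mode: {mode}")


prior = (0.5, -1.0)          # (mu_p, log_sig_p) from the decoder side
enc_image_1 = (2.0, 0.25)    # (mu_q, log_sig_q) for one input image
enc_image_2 = (-4.0, 0.5)    # (mu_q, log_sig_q) for a different input image

same_1 = posterior_params(*prior, *enc_image_1, mode="prior_only")
same_2 = posterior_params(*prior, *enc_image_2, mode="prior_only")
res_1 = posterior_params(*prior, *enc_image_1, mode="residual")
res_2 = posterior_params(*prior, *enc_image_2, mode="residual")
```

In the `prior_only` case the two inputs produce identical parameters, because the encoder terms mu_q and log_sig_q never enter the distribution. In the `residual` case they differ.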