NVlabs / NVAE

The Official PyTorch Implementation of "NVAE: A Deep Hierarchical Variational Autoencoder" (NeurIPS 2020 spotlight paper)
https://arxiv.org/abs/2007.03898

Understanding the relationship between the code and the paper #4

Closed. siavashk closed this issue 3 years ago.

siavashk commented 3 years ago

I think this work is likely to win the outstanding paper award. I am looking forward to the oral presentation.

First of all, the code works out-of-the-box. I am currently training on MNIST and the results are reasonable. I do have a couple of questions, which I think would benefit others as well. Some of the questions might be trivial, given that I am not an expert in VAEs.

  1. From the paper, Figure 2 (caption): "... and h is a trainable parameter." The description of h does not appear again in the paper (I might have missed it). What is it? Does it correspond to self.prior_ftr0 in line 147?

  2. What are the magic numbers (multiplication and subtraction) in line 334? Is this basically transforming the intensity range [0.0, 1.0] to [-1.0, 1.0]?

  3. Do the combiner_enc cells in line 344 correspond to the red ⊕ symbols along the encoder path in Figure 2?

  4. What does the enc0 function (or equivalently ftr) in line 355 represent? Is this the initial diamond residual layer that immediately follows x in Figure 2?

  5. What is the function of the pre-processing layers, more precisely down_pre and normal_pre, in the init_pre_process function in line 201? I can tell they have something to do with the bottom-up and top-down paths in Figure 2, but I am not sure what they do.

  6. A similar question to 5 regarding the "post-processing" layers: what is their function, and why do you use them?

  7. Would it be fair to say that the latent representation of an image produced by the encoder network is the mean of each residual normal distribution? If so, would this essentially be mu_q in line 394 (corresponding to z1 in the decoder path) and z in line 396 for z2 and subsequent latents?
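
Question 2 asks whether the input transform is an affine map from [0.0, 1.0] to [-1.0, 1.0]. A minimal sketch of such a map (2x - 1) and its inverse in plain Python, for illustration only: the function names are hypothetical, and the exact expression on line 334 is not quoted in this thread.

```python
def to_symmetric_range(pixels):
    """Map intensities in [0.0, 1.0] to [-1.0, 1.0] using 2x - 1 (hypothetical helper)."""
    return [2.0 * p - 1.0 for p in pixels]


def from_symmetric_range(values):
    """Invert the map above: (v + 1) / 2 sends [-1.0, 1.0] back to [0.0, 1.0]."""
    return [(v + 1.0) / 2.0 for v in values]


print(to_symmetric_range([0.0, 0.5, 1.0]))  # [-1.0, 0.0, 1.0]
```

Centering the inputs around zero is a common normalization choice; the inverse is only needed if you want to map network-space values back to pixel intensities.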

Thank you for sharing your code. I wish you the best.

arash-vahdat commented 3 years ago

Hi Siavash,

I am glad that this codebase is helpful. Regarding your questions:

  1. Yes, self.prior_ftr0 is the same as h in the paper. Sorry for the poor choice of the variable name.

  2. You guessed right. We just want to make sure that the input to the network is roughly normalized.

  3. Yes, the combiner_enc cells correspond to the red ⊕ symbols along the encoder path.

  4. No, enc0 is actually the last layer before producing a distribution over z_0. In the code, we number the z's from 0, and z_0 is the one at the top of Fig. 2.

  5. The pre-processing layers are the first diamond after x in the encoder. We just apply a few residual cells to x before starting the main trunk of the encoder. In these residual cells, we may reduce the spatial dimensions of x, or we may just extract a representation that is used in the bottom-up network.

  6. The post-processing layers correspond to the last diamond before x in Fig. 2b. You can think of them as residual cells in the decoder that apply some non-linearity before generating x.

  7. I haven't used NVAE for representation learning, so I am not sure which part would be best for extracting representations for a downstream task. The mu_q's or z's are usually important for reconstruction. I would probably start with the last layer of the bottom-up network to extract representations from x.
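
Question 7 refers to residual normal distributions. As described in the NVAE paper, each approximate posterior is parameterized relative to its prior: if the prior is N(mu_p, sigma_p), the posterior is N(mu_p + delta_mu, sigma_p * delta_sigma), and the KL term has the closed form 0.5 * (delta_mu^2 / sigma_p^2 + delta_sigma^2 - log(delta_sigma^2) - 1). Below is a one-dimensional, plain-Python sketch of that parameterization. All function names are hypothetical and do not reflect the repository's API; the real model works on tensors, with the deltas predicted by networks. One consequence worth checking before using a variable as a "posterior mean": under this parameterization the mean is mu_p + delta_mu, so a quantity that holds only the residual is not by itself the posterior mean.

```python
import math
import random


def residual_posterior(mu_p, sigma_p, delta_mu, delta_sigma):
    """Posterior (mean, std) expressed relative to the prior (hypothetical helper)."""
    return mu_p + delta_mu, sigma_p * delta_sigma


def residual_kl(sigma_p, delta_mu, delta_sigma):
    """Closed-form KL(q || p) for one dimension of a residual normal."""
    return 0.5 * (delta_mu ** 2 / sigma_p ** 2 + delta_sigma ** 2
                  - math.log(delta_sigma ** 2) - 1.0)


def gaussian_kl(mu_q, sigma_q, mu_p, sigma_p):
    """Generic KL between two 1-D Gaussians, used here as a cross-check."""
    return (math.log(sigma_p / sigma_q)
            + (sigma_q ** 2 + (mu_q - mu_p) ** 2) / (2.0 * sigma_p ** 2) - 0.5)


def reparameterized_sample(mu, sigma, rng):
    """Draw z = mu + sigma * eps with eps ~ N(0, 1)."""
    return mu + sigma * rng.gauss(0.0, 1.0)


# Example: prior N(0.2, 0.7), residual parameters (0.3, 1.2).
mu_q, sigma_q = residual_posterior(0.2, 0.7, 0.3, 1.2)
z = reparameterized_sample(mu_q, sigma_q, random.Random(0))
print(mu_q, residual_kl(0.7, 0.3, 1.2), gaussian_kl(mu_q, sigma_q, 0.2, 0.7))
```

Note that the residual form of the KL depends only on sigma_p and the two deltas, not on mu_p; the generic formula gives the same value, which is what the cross-check function confirms.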

I hope this helps.

siavashk commented 3 years ago

Thank you.