NVlabs / NVAE

The Official PyTorch Implementation of "NVAE: A Deep Hierarchical Variational Autoencoder" (NeurIPS 2020 spotlight paper)
https://arxiv.org/abs/2007.03898

Importance of preprocessing the Gaussian parameters? #27

Closed: BimDav closed this issue 3 years ago

BimDav commented 3 years ago

Hi, thank you for your outstanding work in making VAEs great again!

My question is about the preprocessing of the Gaussian parameters in distributions.py:

```python
import torch


def soft_clamp5(x: torch.Tensor):
    # 5. * torch.tanh(x / 5.) <--> soft differentiable clamp between [-5, 5]
    # (out-of-place div, so the caller's tensor is not modified in place)
    return x.div(5.).tanh_().mul(5.)

[...]

        self.mu = soft_clamp5(mu)
        log_sigma = soft_clamp5(log_sigma)
        self.sigma = torch.exp(log_sigma) + 1e-2
```
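To see what this preprocessing does numerically, here is a scalar, plain-Python sketch of the same formula. It uses `math.tanh` instead of PyTorch, so it is an illustration rather than the repository's code. It assumes the `+ 1e-2` floor from the quoted line, and `sigma_from_log_sigma` is a hypothetical helper name.

```python
import math


def soft_clamp5(x: float) -> float:
    # Smooth, differentiable clamp: 5 * tanh(x / 5) never exceeds 5 in magnitude
    # and has slope 1 at x = 0, so small values pass through almost unchanged.
    return 5.0 * math.tanh(x / 5.0)


def sigma_from_log_sigma(log_sigma: float) -> float:
    # Hypothetical helper mirroring the quoted line:
    # exponentiate the clamped log-std, then add a 1e-2 floor.
    return math.exp(soft_clamp5(log_sigma)) + 1e-2


for x in [-100.0, -5.0, -0.1, 0.0, 0.1, 5.0, 100.0]:
    print(f"x={x:8.1f}  clamp={soft_clamp5(x):8.4f}  sigma={sigma_from_log_sigma(x):10.4f}")

# Every sigma produced this way lies in [exp(-5) + 0.01, exp(5) + 0.01].
SIGMA_MIN = math.exp(-5.0) + 1e-2
SIGMA_MAX = math.exp(5.0) + 1e-2
print(f"sigma range: [{SIGMA_MIN:.4f}, {SIGMA_MAX:.4f}]")
```

The clamp is not the identity inside [-5, 5]: an input of 5 comes out as 5 * tanh(1) ≈ 3.81, so moderate values are already compressed. Combined with the offset, sigma always stays between exp(-5) + 0.01 ≈ 0.0167 and exp(5) + 0.01 ≈ 148.4.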

I don't think this is discussed in the paper. What is the role of this preprocessing? When I remove it, the model's stability seems to be affected. Do you have results on how it relates to the other stabilization methods discussed in the paper?

Thank you

arash-vahdat commented 3 years ago

Yeah, we added this in the hope of improving model stability. It doesn't completely fix the stability issues, but it does help a little. This processing keeps mu and log_sigma within [-5, 5] so that the KL divergence doesn't blow up.
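The closed-form KL between two scalar Gaussians makes this bound concrete. The plain-Python sketch below is an illustration, not NVAE code: it assumes that both the encoder distribution q and the prior p are built with the same soft clamp and `+ 1e-2` offset as in the question, and it searches the corners of that clamped parameter box for the largest KL a single latent variable can reach.

```python
import itertools
import math

MU_BOUND = 5.0                      # |mu| <= 5 after the soft clamp
SIGMA_MIN = math.exp(-5.0) + 1e-2   # smallest sigma this parameterization can produce
SIGMA_MAX = math.exp(5.0) + 1e-2    # largest sigma this parameterization can produce


def gaussian_kl(mu_q: float, sigma_q: float, mu_p: float, sigma_p: float) -> float:
    """Closed-form KL(N(mu_q, sigma_q^2) || N(mu_p, sigma_p^2)) for one scalar latent."""
    return (math.log(sigma_p / sigma_q)
            + (sigma_q ** 2 + (mu_q - mu_p) ** 2) / (2.0 * sigma_p ** 2)
            - 0.5)


# With the other arguments fixed, the KL has a single interior minimum in sigma_q and in
# sigma_p, and it grows with |mu_q - mu_p|, so its maximum over the clamped box is at a corner.
corners = itertools.product([-MU_BOUND, MU_BOUND], [SIGMA_MIN, SIGMA_MAX],
                            [-MU_BOUND, MU_BOUND], [SIGMA_MIN, SIGMA_MAX])
max_kl = max(gaussian_kl(*c) for c in corners)
print(f"largest possible KL for a single latent variable: {max_kl:.3e} nats")
```

Under these assumptions the worst case is finite but huge (about 3.95e7 nats), reached when q is as broad as possible, p is as narrow as possible, and their means are 10 apart.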

BimDav commented 3 years ago

Thank you. Are all experiments in the paper done with this preprocessing? I don't really understand the unbounded KL problem if the parameters are clamped: isn't the KL bounded because of this?

arash-vahdat commented 3 years ago

Yes, all experiments are done with this.

The KL per latent variable is bounded, but it can still be large for each variable. We have many latent variables in the model, and when instability happens, even a small mismatch between the encoder and the prior makes these per-variable KL values add up to an extremely large total.
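The accumulation effect can be illustrated with the same closed-form KL. In this hypothetical plain-Python sketch, every latent variable has a unit-variance encoder and prior whose means differ by `mean_gap`; the latent counts are arbitrary examples, not NVAE's actual sizes. Assuming factorized (diagonal) Gaussians, the total KL is the sum of the per-variable terms.

```python
import math


def gaussian_kl(mu_q: float, sigma_q: float, mu_p: float, sigma_p: float) -> float:
    """Closed-form KL(N(mu_q, sigma_q^2) || N(mu_p, sigma_p^2)) for one scalar latent."""
    return (math.log(sigma_p / sigma_q)
            + (sigma_q ** 2 + (mu_q - mu_p) ** 2) / (2.0 * sigma_p ** 2)
            - 0.5)


def total_kl(mean_gap: float, num_latents: int) -> float:
    # Hypothetical setting: every latent has the same unit-variance encoder and prior,
    # with the encoder mean shifted by `mean_gap`. For factorized Gaussians the total
    # KL is the sum of the per-variable terms.
    return sum(gaussian_kl(mean_gap, 1.0, 0.0, 1.0) for _ in range(num_latents))


for gap in (0.1, 1.0, 3.0):
    for n in (1_000, 100_000):
        print(f"mean gap {gap:>3}, {n:>7} latents -> total KL {total_kl(gap, n):12,.1f} nats")
```

With equal variances the per-variable KL is mean_gap² / 2, so a gap of 1 costs only 0.5 nats per variable but already 50,000 nats over 100,000 latents, far below the per-variable bound yet large in total.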

BimDav commented 3 years ago

Thank you