LTH14 / mar

PyTorch implementation of MAR+DiffLoss https://arxiv.org/abs/2406.11838
MIT License

Where is the 0.2325 from? #2

Closed · LZY-the-boys closed this issue 3 months ago

LZY-the-boys commented 3 months ago

Hi,

Thank you for your excellent work. I am very interested in the use of 0.2325 in your VAE implementation:

x = posterior.sample().mul_(0.2325)
vae.decode(sampled_tokens / 0.2325)

Could you please explain the origin of the 0.2325 value? Was it calculated from the mean and logvar of a custom VAE latent space on ImageNet? If possible, could you also provide the code for this calculation?

Thank you for your assistance!

LTH14 commented 3 months ago

Thanks for your interest! I simply computed the standard deviation of the latents produced by the VAE encoder; the scaling factor is its reciprocal: scale_factor = 1 / torch.std(vae.encode(x).sample()). In my computation, x is the entire training set, so I store all the latents and then compute the standard deviation over them.
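
For reference, a minimal sketch of that computation might look like the following. It assumes an LDM-style VAE whose encode() returns the posterior directly with a .sample() method (matching the vae.encode(x).sample() call above), and a hypothetical DataLoader named loader yielding (images, labels) batches from the training set; neither name comes from this repo.

import torch

@torch.no_grad()
def compute_scale_factor(vae, loader, device="cuda"):
    # Estimate scale_factor = 1 / std of the VAE latents over a dataset.
    latents = []
    for images, _ in loader:
        # encode() is assumed to return the posterior q(z|x), as in the
        # vae.encode(x).sample() snippet quoted above.
        posterior = vae.encode(images.to(device))
        latents.append(posterior.sample().flatten().cpu())
    # Standard deviation over every latent element of the full set.
    std = torch.cat(latents).std()
    # Scaling latents by 1/std normalizes them to unit variance; for the
    # VAE used in this repo the result is approximately 0.2325.
    return (1.0 / std).item()

Note that storing all latents for a dataset the size of ImageNet takes a lot of memory; an equivalent alternative is to accumulate running sums of z and z**2 across batches and derive the variance from those.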