clementchadebec / benchmark_VAE

Unifying Variational Autoencoder (VAE) implementations in Pytorch (NeurIPS 2022)
Apache License 2.0
1.77k stars, 161 forks

About training vae for discrete-valued data #150

Open Melon-Xu opened 1 month ago

Melon-Xu commented 1 month ago

Thank you for your excellent work! It is really helpful and convenient to use.

I have some image data with discrete pixel values, for example only 0 and 255. I tried training a VAE on it, but the loss is very large and the reconstructions are poor.

Do you have any suggestions for training a VAE on this kind of data, such as model selection and hyperparameter choices?

Thank you very much!

clementchadebec commented 1 month ago

Hello @Melon-Xu,

Thank you for the kind words. Can you try rescaling the values to the range [0, 1]?
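
For reference, a minimal sketch of that rescaling, assuming the images are held as a `uint8` NumPy array with values in {0, 255}. The helper name `rescale_to_unit_interval` and the random batch are hypothetical and only serve the illustration:

```python
import numpy as np


def rescale_to_unit_interval(images: np.ndarray) -> np.ndarray:
    """Map uint8 pixel values in [0, 255] to float32 values in [0, 1]."""
    return images.astype(np.float32) / 255.0


# Hypothetical batch of 8 single-channel 256x256 patches with values in {0, 255}.
rng = np.random.default_rng(0)
raw = (rng.random((8, 1, 256, 256)) > 0.99).astype(np.uint8) * 255

scaled = rescale_to_unit_interval(raw)
```

The result is a `float32` array that can be converted to a tensor before training.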

Best,

Clément

Melon-Xu commented 1 month ago

Thank you for your reply, I will try it. Another question: I have now trained a VAE model on MNIST. How do I test its reconstruction ability? Is there a script that tests reconstruction as shown in the README? Thank you very much!

Melon-Xu commented 1 month ago

Hi, Clément,

I tried normalizing the values from {0, 255} to {0, 1}, but the training loss is still very high: around 3500 after convergence. My batch size is 8 because my training set is very small, and my patch size is 256×256. The reconstruction quality is very poor; it looks like this: [image]

Do you have any suggestions on it? Thank you very much!
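
To put the reported number in perspective, here is a small sketch that converts the loss into a per-pixel value. It assumes the loss is summed over the pixels of each patch and averaged over the batch, and that the patches have a single channel; neither is confirmed in this thread. The helper `per_pixel_loss` is hypothetical.

```python
def per_pixel_loss(total_loss: float, height: int, width: int, channels: int = 1) -> float:
    """Divide a per-sample summed loss by the number of pixel values in one sample."""
    return total_loss / (height * width * channels)


# Values quoted above: loss of about 3500 on 256x256 patches.
value = per_pixel_loss(3500.0, 256, 256)
print(f"{value:.4f}")  # prints 0.0534
```

Under those assumptions, a total of 3500 corresponds to roughly 0.05 per pixel, so the raw magnitude mostly reflects the 65,536 pixels in each patch.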

clementchadebec commented 3 weeks ago

Hello @Melon-Xu

Thanks for sharing those results. It does seem that your data is very sparse (only a few points are non-zero), which may explain the loss value. Can you share the encoder-decoder architecture you are currently using? When training a VAE on face landmarks, I also found that VAEs with simple architectures can struggle with this type of data. Another approach would be to modify the data slightly so that it looks more like a continuous distribution, by smoothing the values around the non-zero points (so that it resembles a mixture of Gaussians, for instance). This would make the task easier, and at prediction time you can post-process the output of the VAE to recover the final points.
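
The smoothing and post-processing idea could be sketched as follows, using only NumPy. This is an illustration under assumptions, not code from the library: the helpers `smooth_points` and `recover_points`, the `sigma` value, and the 0.5 threshold are all hypothetical choices, and the heatmap is passed straight back instead of going through a VAE.

```python
import numpy as np


def smooth_points(binary_img: np.ndarray, sigma: float = 2.0) -> np.ndarray:
    """Turn each non-zero pixel into a Gaussian bump; output is scaled to peak at 1."""
    radius = int(4 * sigma)
    x = np.arange(-radius, radius + 1, dtype=np.float64)
    kernel = np.exp(-0.5 * (x / sigma) ** 2)
    kernel /= kernel.sum()
    img = binary_img.astype(np.float64)
    # Separable blur: convolve along rows, then along columns ("same" keeps the shape).
    img = np.apply_along_axis(lambda r: np.convolve(r, kernel, mode="same"), 1, img)
    img = np.apply_along_axis(lambda c: np.convolve(c, kernel, mode="same"), 0, img)
    peak = img.max()
    return img / peak if peak > 0 else img


def recover_points(heatmap: np.ndarray, threshold: float = 0.5) -> np.ndarray:
    """Return (row, col) of pixels above `threshold` that are >= all 8 neighbors."""
    padded = np.pad(heatmap, 1, mode="constant", constant_values=-np.inf)
    h, w = heatmap.shape
    is_peak = heatmap > threshold
    for dy in (-1, 0, 1):
        for dx in (-1, 0, 1):
            if dy == 0 and dx == 0:
                continue
            neighbor = padded[1 + dy : 1 + dy + h, 1 + dx : 1 + dx + w]
            is_peak &= heatmap >= neighbor
    return np.argwhere(is_peak)


# Hypothetical sparse patch with two non-zero points.
patch = np.zeros((256, 256), dtype=np.uint8)
patch[40, 50] = 1
patch[200, 120] = 1

target = smooth_points(patch)       # continuous training target
recovered = recover_points(target)  # post-processing step at prediction time
```

On real VAE outputs the reconstructed heatmaps will be noisier than this clean round trip, so the threshold and neighborhood size would likely need tuning.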

I hope this helps,

Clément