psanch21 / VAE-GMVAE

This repository contains TensorFlow implementations of the VAE and the Gaussian Mixture VAE (GMVAE), with several network architectures.
Apache License 2.0

Hyperparameters of GMVAE #4

Closed aihanb closed 5 years ago

aihanb commented 5 years ago

Hey psanch21!

I just ran the code: python3 GMVAE_main.py --model_type=2 --dataset_name=MNIST --sigma=0.001 --z_dim=8 --w_dim=2 --K_clusters=8 --hidden_dim=64 --num_layers=2 --epochs=20 --batch_size=32 --drop_prob=0.3 --l_rate=0.01 --train=1 --results=1 --plot=1 --restore=1 --early_stopping=1

But I cannot reproduce the GMVAE results you posted on GitHub; the output images look very different. Are the GMVAE hyperparameters wrong? Could you give me some suggestions?

Thanks a lot.

psanch21 commented 5 years ago

Hi aihanb!

That specific example was only meant to show the usage; it does not correspond to the results provided. To get something similar, you should at least increase the number of epochs (20 epochs is not enough steps of stochastic gradient descent) and increase the number of clusters to 10 (K_clusters=10), since we know beforehand that there are 10 different classes.
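For example, a run along those lines could look like the following (only epochs and K_clusters are changed from your command; 200 epochs is just a guess and you should still check when training converges):

python3 GMVAE_main.py --model_type=2 --dataset_name=MNIST --sigma=0.001 --z_dim=8 --w_dim=2 --K_clusters=10 --hidden_dim=64 --num_layers=2 --epochs=200 --batch_size=32 --drop_prob=0.3 --l_rate=0.01 --train=1 --results=1 --plot=1 --restore=1 --early_stopping=1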

Hope this helps!

aihanb commented 5 years ago

Thanks for your reply! I tried the hyperparameter values you report in the bachelor thesis ("Table 5.5 Experiment 2"): z=10, w=[2,10,20], K=10, hidden layers=3, hidden dimension=128. Yes, the results look better. But one question remains: I increased batch_size to 128 or 256 and the number of epochs to 200 (training overfits before reaching that many epochs), yet the scatter plots of z and w still do not look as good as Fig. 6.4 & 6.5. Could you tell me which batch_size, or anything else, you used? Or did you just use a larger MNIST dataset? I'm new to this, so sorry if my question seems silly.
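For reference, the command I am running looks roughly like this (I picked w_dim=2 from the table's options; sigma, drop_prob, and l_rate are simply carried over from my earlier command):

python3 GMVAE_main.py --model_type=2 --dataset_name=MNIST --sigma=0.001 --z_dim=10 --w_dim=2 --K_clusters=10 --hidden_dim=128 --num_layers=3 --epochs=200 --batch_size=128 --drop_prob=0.3 --l_rate=0.01 --train=1 --results=1 --plot=1 --restore=1 --early_stopping=1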

psanch21 commented 5 years ago

Happy to hear your results look better! Regarding your question, I used the MNIST dataset, so results for other datasets may vary. You should validate the different hyperparameters yourself. As for batch_size, I would say the larger the better, as long as it fits in your machine's memory. Cheers!
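As a sketch of what that validation could look like (a hypothetical shell loop, not part of the repository), you could sweep batch_size while keeping the Experiment 2 settings fixed and compare the resulting plots:

for bs in 32 64 128 256; do
  # Hypothetical sweep: all other flags follow Experiment 2 of the thesis.
  python3 GMVAE_main.py --model_type=2 --dataset_name=MNIST --sigma=0.001 \
    --z_dim=10 --w_dim=2 --K_clusters=10 --hidden_dim=128 --num_layers=3 \
    --epochs=200 --batch_size=$bs --drop_prob=0.3 --l_rate=0.01 \
    --train=1 --results=1 --plot=1 --restore=1 --early_stopping=1
done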