This repository contains the Python implementation of our generative clustering method VaDE (Variational Deep Embedding).
Further details about VaDE can be found in our paper:
Important: replace keras/engine/training.py in your Keras installation with the training.py file provided in this repository.
The modified keras/engine/training.py lets our model update the GMM parameters and the network parameters simultaneously.
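The GMM parameters are the mixture weights π, the cluster means μ_c and the per-cluster variances σ_c² of the Gaussian-mixture prior over latent codes, p(c) = Cat(π) and p(z|c) = N(μ_c, σ_c² I). The minimal pure-Python sketch below, which is not the repository's Keras code, shows how these parameters turn a latent code z into cluster probabilities γ_c = p(c|z). The function names `log_gaussian_diag` and `cluster_posterior` are hypothetical.

```python
import math
from typing import List, Sequence


def log_gaussian_diag(z: Sequence[float], mu: Sequence[float], var: Sequence[float]) -> float:
    """Log-density of z under a Gaussian with mean mu and diagonal covariance var."""
    return sum(
        -0.5 * (math.log(2.0 * math.pi * v) + (zi - mi) ** 2 / v)
        for zi, mi, v in zip(z, mu, var)
    )


def cluster_posterior(
    z: Sequence[float],
    pi: Sequence[float],
    mus: Sequence[Sequence[float]],
    vars_: Sequence[Sequence[float]],
) -> List[float]:
    """gamma_c = pi_c N(z; mu_c, var_c) / sum_c' pi_c' N(z; mu_c', var_c').

    Hypothetical helper: pi, mus and vars_ play the role of the GMM
    parameters that are trained jointly with the network weights.
    """
    # Unnormalized log-posterior of each cluster.
    logits = [math.log(p) + log_gaussian_diag(z, mu, var) for p, mu, var in zip(pi, mus, vars_)]
    # Log-sum-exp trick: subtract the max before exponentiating for stability.
    m = max(logits)
    weights = [math.exp(x - m) for x in logits]
    total = sum(weights)
    return [w / total for w in weights]
```

For example, with two equally weighted unit-variance clusters centered at (0, 0) and (4, 4), the code z = (0.1, -0.2) is assigned almost entirely to the first cluster. The point (2, 2), which is equidistant from both centers, gets probability 0.5 for each.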
To train the VaDE model on the MNIST, Reuters, or HHAR dataset, run:
python ./VaDE.py db
where db is one of mnist, reuters10k, or har.
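The `db` argument selects both the data and the number of mixture components. Below is a minimal sketch of how such an argument could be validated. It is a hypothetical illustration, not the actual argument handling in VaDE.py. The `DATASETS` table and `parse_args` are assumptions: they record the input dimensionality and number of classes of each dataset (MNIST: 784 pixels, 10 digits; Reuters-10K: 2000 tf-idf features, 4 categories; HAR: 561 features, 6 activities), and VaDE.py may store these values differently.

```python
import argparse
from typing import List, Optional

# Hypothetical lookup table; the values describe the datasets themselves,
# not necessarily how VaDE.py stores them.
DATASETS = {
    "mnist": {"input_dim": 784, "n_clusters": 10},
    "reuters10k": {"input_dim": 2000, "n_clusters": 4},
    "har": {"input_dim": 561, "n_clusters": 6},
}


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse the single positional `db` argument (hypothetical CLI sketch)."""
    parser = argparse.ArgumentParser(description="Train VaDE on one dataset.")
    # Restrict db to the three supported names; argparse exits with an
    # error message for anything else.
    parser.add_argument("db", choices=sorted(DATASETS))
    return parser.parse_args(argv)
```

With this sketch, `DATASETS[parse_args(["har"]).db]["n_clusters"]` evaluates to 6. An unsupported name such as `cifar10` makes argparse print a usage error and raise `SystemExit`.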
To reproduce the 94.46% clustering accuracy on MNIST and generate class-specific digits, run the following (note that, unlike a conditional GAN, we do not use any supervised information during training):
python ./VaDE_test_mnist.py
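Clustering accuracy compares unlabeled cluster ids with ground-truth labels, so it is usually computed as the best accuracy over all one-to-one mappings from clusters to labels. The sketch below assumes this common definition; the scripts' own implementation may differ. The helper name `cluster_acc` is hypothetical. For clarity, the sketch brute-forces the permutations. With 10 clusters that already means 10! mappings, so practical evaluations solve the same matching with the Hungarian algorithm instead.

```python
from itertools import permutations
from typing import Sequence


def cluster_acc(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    """Best accuracy over one-to-one mappings from cluster ids to labels.

    Assumes both sequences use integer ids 0..K-1. Brute force over
    permutations, so only practical for a small number of clusters.
    """
    if len(y_true) != len(y_pred) or len(y_true) == 0:
        raise ValueError("y_true and y_pred must be non-empty and of equal length")
    k = max(max(y_true), max(y_pred)) + 1
    # counts[p][t] = number of samples put in cluster p whose true label is t
    counts = [[0] * k for _ in range(k)]
    for t, p in zip(y_true, y_pred):
        counts[p][t] += 1
    # mapping[p] is the label assigned to cluster p; keep the best total match.
    best = max(
        sum(counts[p][mapping[p]] for p in range(k))
        for mapping in permutations(range(k))
    )
    return best / len(y_true)
```

A clustering that only permutes the label ids, for example predicting `[1, 1, 0, 0, 2, 2]` for true labels `[0, 0, 1, 1, 2, 2]`, scores 1.0.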
To reproduce the 79.38% clustering accuracy on the Reuters (685K) dataset, run:
cd $VaDE_ROOT/dataset/reuters
./get_data.sh
cd $VaDE_ROOT
python ./VaDE_test_reuters_all.py
Note: the data preprocessing code for the Reuters dataset is taken from DEC (https://github.com/piiswrong/dec).
Figure (DCGAN-like network architecture). Rows 1-6: 1. black/short hair, man; 2. black/long hair, woman; 3. gold/long hair, woman; 4. bald, sunglasses, man; 5. left side face, woman; 6. right side face, woman.
Figure: interpolation between cluster centers in latent space.
Figure: vector arithmetic in latent space: right + left = front.
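The interpolation figure above walks through latent space from one cluster center to another. Below is a minimal sketch of the linear path z(t) = (1 - t) μ_a + t μ_b, for t from 0 to 1. It assumes straight-line interpolation between two cluster means. Each resulting point would then be passed through the trained decoder to produce an image; that decoding step is omitted here. The function name `interpolate` is hypothetical, and the sketch does not cover the vector arithmetic figure.

```python
from typing import List, Sequence


def interpolate(mu_a: Sequence[float], mu_b: Sequence[float], steps: int) -> List[List[float]]:
    """Evenly spaced points on the segment from mu_a to mu_b, endpoints included.

    Hypothetical helper: mu_a and mu_b stand for two GMM cluster means.
    """
    if steps < 2:
        raise ValueError("steps must be at least 2")
    if len(mu_a) != len(mu_b):
        raise ValueError("mu_a and mu_b must have the same dimensionality")
    points = []
    for i in range(steps):
        t = i / (steps - 1)  # t runs from 0.0 to 1.0
        points.append([(1.0 - t) * a + t * b for a, b in zip(mu_a, mu_b)])
    return points
```

For instance, three steps between (0, 0) and (2, 4) give (0, 0), (1, 2) and (2, 4).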