anvoynov / GANLatentDiscovery

The authors' official implementation of Unsupervised Discovery of Interpretable Directions in the GAN Latent Space

Any tips for wrapping my trained generator? #12

Closed. jbmaxwell closed this issue 3 years ago.

jbmaxwell commented 4 years ago

First off, very cool work!

I've been working with a modified ClusterGAN and am struggling a bit to wrap it for use with your approach. The generator is a fairly basic 3-layer convolutional model with spectral normalization. The trained ClusterGAN generator operates as a conditional GAN, so it takes two inputs: zn (the latent) and zc (the class label vector). zc is essentially a one-hot class vector, though you can also give it partial class values for interpolation.

Any guidelines on how I might set it up for exploration with your model would be greatly appreciated, particularly on how to handle the conditional/class vector. I see that the code supports conditional generation, but I'm not clear on how target_classes and mixed_classes are handled.

anvoynov commented 4 years ago

Hi @jbmaxwell, thank you!

Take a look at the BigGAN wrapper, as it is conditional as well. You should:

Overall, that's all you need.
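A minimal sketch of one way to approach this, assuming the direction search only ever calls the generator with a latent z and explores shifts of z: hold the class vector zc fixed inside a wrapper so the ClusterGAN looks unconditional from the outside. ToyClusterGenerator, ConditionedWrapper, dim_z and gen_shifted are hypothetical names, only loosely modeled on the idea of the BigGAN wrapper; check that wrapper for the exact attributes and methods the trainer expects. NumPy stands in for PyTorch to keep the example self-contained.

```python
import numpy as np


class ToyClusterGenerator:
    """Hypothetical stand-in for a trained ClusterGAN generator G(zn, zc)."""

    def __init__(self, zn_dim=4, n_classes=3, out_dim=5, seed=0):
        rng = np.random.default_rng(seed)
        self.weight = rng.standard_normal((out_dim, zn_dim + n_classes))

    def __call__(self, zn, zc):
        # Concatenate the latent and the class vector, then apply a fixed map.
        return np.tanh(np.concatenate([zn, zc], axis=-1) @ self.weight.T)


class ConditionedWrapper:
    """Exposes only zn to the direction search; zc stays fixed (hypothetical API)."""

    def __init__(self, generator, zn_dim, zc):
        self.G = generator
        self.dim_z = zn_dim
        # One-hot or partial (soft) class vector, held constant while exploring zn.
        self.zc = np.asarray(zc, dtype=float)

    def __call__(self, zn):
        # Repeat the fixed class vector for every latent in the batch.
        zc = np.tile(self.zc, (zn.shape[0], 1))
        return self.G(zn, zc)

    def gen_shifted(self, zn, shift):
        # Only the continuous latent is shifted; the class condition is untouched.
        return self(zn + shift)


# Usage: fix class 1 out of 3 and shift a batch of two latents.
wrapped = ConditionedWrapper(ToyClusterGenerator(), zn_dim=4, zc=np.eye(3)[1])
zn = np.random.default_rng(1).standard_normal((2, 4))
images = wrapped.gen_shifted(zn, 0.5 * np.ones(4))
```

Keeping zc outside the searched space means the learned directions describe variation within a class; mixed_classes-style behavior would amount to passing a soft zc instead of a one-hot row.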

jbmaxwell commented 4 years ago

Okay, awesome, I managed to get it training. How would this be used in a deployment context? Looking at visualization.interpolate(), it looks like the deformator is used to create an array of offsets that are added to a given latent. Is that correct?
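That reading matches the paper's setup: for direction k and magnitude eps, the deformator maps eps times the k-th basis vector to a latent offset, which is added to z before calling the generator. A minimal NumPy sketch, assuming a linear deformator represented by a plain matrix A (the repo has other deformator types, and its actual interpolate() signature may differ); direction_shifts and shifted_latents are hypothetical helpers.

```python
import numpy as np


def direction_shifts(A, k, shifts_r=8.0, shifts_count=5):
    """Latent offsets along direction k for eps evenly spaced in [-shifts_r, shifts_r].

    A is a hypothetical linear deformator of shape (latent_dim, input_dim).
    """
    offsets = []
    for eps in np.linspace(-shifts_r, shifts_r, shifts_count):
        basis = np.zeros(A.shape[1])
        basis[k] = eps  # eps * e_k is the deformator's input
        offsets.append(A @ basis)  # the deformator's output is the latent offset
    return np.stack(offsets)


def shifted_latents(z, A, k, shifts_r=8.0, shifts_count=5):
    """One row per step, z + A(eps * e_k); feed these rows to the generator."""
    return z[None, :] + direction_shifts(A, k, shifts_r, shifts_count)


rng = np.random.default_rng(0)
A = rng.standard_normal((256, 256))
z = rng.standard_normal(256)
latents = shifted_latents(z, A, k=3)  # 5 latents moving along direction 3
```

For deployment, only the trained deformator and generator are needed: pick k and eps, compute the offset once, and add it to whatever latent you are editing.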

jbmaxwell commented 4 years ago

One last thing I'm somewhat confused by: the deformator's input size seems to be based on my GAN's latent (z) size (e.g., 256). Since my data is somewhat MNIST-like, I've trained for 64 directions (via Params in trainer.py), but I'm not sure what the connection is between the deformator's input and the directions count. Intuitively, I'd expect the input size to match the number of directions learned, but I don't see how that could be the case here. So, how do I target/parametrize a particular direction?

The charts generated after training show 256 rows (in groups of 20), each showing the transformation associated with that index. But how does that index relate to the directions?

Thanks in advance for any light you can shed!

anvoynov commented 4 years ago

By default, once you specify 'directions_count', the learned directions are the leading ones in the deformator's output, indices 0 to 'directions_count' - 1. So you should basically check the first 64 of the 256 rows in the generated charts.
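A minimal sketch of targeting one direction, assuming input_dim = 256 and directions_count = 64 as in this thread, and assuming direction k is selected by feeding the deformator a one-hot input at position k scaled by the shift magnitude. With those defaults, only k in [0, 64) corresponds to a trained direction. direction_input is a hypothetical helper, not a function from the repo, and the deformator is again modeled as a plain matrix A.

```python
import numpy as np


def direction_input(k, eps=1.0, input_dim=256, directions_count=64):
    """eps * e_k, the deformator input that selects direction k (hypothetical helper)."""
    if not 0 <= k < directions_count:
        raise ValueError(f"direction {k} was not trained; use 0 <= k < {directions_count}")
    x = np.zeros(input_dim)
    x[k] = eps
    return x


# With a linear deformator A, the latent offset for direction k is eps * A[:, k].
rng = np.random.default_rng(0)
A = rng.standard_normal((256, 256))
offset = A @ direction_input(10, eps=2.0)
```

Rows 64 to 255 of the chart still render because the deformator has 256 inputs, but under this assumption those coordinates were never sampled during training, so they carry no learned meaning.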

jbmaxwell commented 4 years ago

Excellent, thanks!