clementchadebec / benchmark_VAE

Unifying Variational Autoencoder (VAE) implementations in Pytorch (NeurIPS 2022)
Apache License 2.0

How does the sampler work? #43

Closed: jprachir closed this issue 2 years ago

jprachir commented 2 years ago

Hi Clément: Great work on introducing this VAE-oriented library! You have made it modular, with predefined models, pipelines, and so forth. Could you share some details on how the samplers work under the hood for generation?

Prachi

clementchadebec commented 2 years ago

Hi @jprachir,

Thank you for opening this issue, and sorry for the late reply. I will try to make my answer as detailed as possible in case future questions arise.

Samplers design

The samplers are designed the same way as the models. This means that for every sampling technique (GMM, MAF, ...) you will find in pythae.samplers a folder containing a nameofsampler_config.py file, with a dataclass defining the parameters of the sampling scheme, and a nameofsampler_sampler.py file, where the actual sampling technique is implemented.

Sampler definition

The samplers can then be used with any suitable model to generate new data. A pythae.samplers instance takes as input the trained model you want to generate from; this input is required. Optionally, you can also provide a custom configuration when building your sampler to use different sampler hyper-parameters.
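The config/sampler split described above can be sketched as follows. `MySamplerConfig`, `MySampler` and the `n_components` default are hypothetical illustrations of the pattern, not pythae's actual classes or base-class API: the config is a plain dataclass holding hyper-parameters, while the sampler requires a trained model and falls back to a default config when none is given.

```python
from dataclasses import dataclass
from typing import Any, Optional


# Would live in a hypothetical mysampler_config.py: hyper-parameters only.
@dataclass
class MySamplerConfig:
    n_components: int = 10


# Would live in a hypothetical mysampler_sampler.py: the sampling logic.
class MySampler:
    def __init__(self, model: Any, sampler_config: Optional[MySamplerConfig] = None):
        # The trained model is required; the config is optional.
        if sampler_config is None:
            sampler_config = MySamplerConfig()
        self.model = model
        self.sampler_config = sampler_config
        self.n_components = sampler_config.n_components


# "trained-vae-placeholder" stands in for a trained model instance.
default_sampler = MySampler(model="trained-vae-placeholder")
custom_sampler = MySampler(
    model="trained-vae-placeholder",
    sampler_config=MySamplerConfig(n_components=12),
)
```

Keeping hyper-parameters in a dataclass makes them easy to save, reload and compare, while the sampler class only holds behaviour.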

Example - GMM sampler

For instance, let's say we have trained a VAE and want to sample from it using a pythae.samplers.GaussianMixtureSampler. All you have to do to build your sampler is the following:

>>> from pythae.samplers import GaussianMixtureSampler, GaussianMixtureSamplerConfig
>>> # Define your sampler configuration
... gmm_sampler_config = GaussianMixtureSamplerConfig(
... n_components=12
... )
>>> # Build your sampler
... my_sampler = GaussianMixtureSampler(
... sampler_config=gmm_sampler_config,
... model=my_trained_vae # A trained `pythae.models` instance
... )

The same pattern applies to every other sampler implemented in pythae.

The sample method

Once your sampler has been instantiated, you can easily generate new data with its sample method, which saves and/or returns the samples generated from your trained model. Under the hood, the sample method first uses the sampler's technique to generate embeddings (latent variables), then passes them to the decoder of the model we are sampling from to obtain the generated samples in the data space.
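The generic shape of such a sample method (generate embeddings, decode them, batch by batch, then optionally return them) can be sketched in plain Python. The function names and default values below are hypothetical, saving to an output directory is left out, and the lambdas are toy stand-ins for a real sampling scheme and a real decoder.

```python
import random
from typing import Callable, List, Optional

Vector = List[float]


def batch_sizes(num_samples: int, batch_size: int) -> List[int]:
    """Split num_samples into full batches plus a smaller final batch if needed."""
    sizes = [batch_size] * (num_samples // batch_size)
    if num_samples % batch_size > 0:
        sizes.append(num_samples % batch_size)
    return sizes


def sample(
    generate_embeddings: Callable[[int], List[Vector]],
    decode: Callable[[List[Vector]], List[Vector]],
    num_samples: int = 1,
    batch_size: int = 500,
    return_gen: bool = True,
) -> Optional[List[Vector]]:
    generated: List[Vector] = []
    for n in batch_sizes(num_samples, batch_size):
        z = generate_embeddings(n)  # 1. latent codes from the sampling scheme
        generated.extend(decode(z))  # 2. decoded samples in data space
    # Saving to an output directory would happen here; omitted in this sketch.
    return generated if return_gen else None


rng = random.Random(0)
gen = sample(
    generate_embeddings=lambda n: [[rng.gauss(0.0, 1.0) for _ in range(2)] for _ in range(n)],
    decode=lambda z: [[2.0 * v for v in zi] for zi in z],  # toy "decoder"
    num_samples=25,
    batch_size=10,
)
print(len(gen), batch_sizes(25, 10))  # 25 [10, 10, 5]
```

Batching keeps memory bounded when many samples are requested; only step 1 changes from one sampler to another.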

Example - N(0,1) sampler

For instance, if we take the simplest example with the pythae.samplers.NormalSampler, you can see that in the sample method we:

  1. Generate the embeddings with the defined technique (here a N(0, 1)) with https://github.com/clementchadebec/benchmark_VAE/blob/a9a5388785b74ea9b01f07ee92113a5d903b5d0a/src/pythae/samplers/normal_sampling/normal_sampler.py#L54

  2. Then, we pass them to the decoder of the model we are sampling from with https://github.com/clementchadebec/benchmark_VAE/blob/a9a5388785b74ea9b01f07ee92113a5d903b5d0a/src/pythae/samplers/normal_sampling/normal_sampler.py#L55
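Those two steps can be mirrored with NumPy. `ToyDecoder` and `NormalSamplerSketch` are hypothetical stand-ins (pythae itself works with PyTorch models), and the sketch leaves out batching, device handling and the model's output wrapper; it only shows that a prior sampler needs no fitting: draw z ~ N(0, I), then decode.

```python
import numpy as np


class ToyDecoder:
    """Hypothetical stand-in for a trained decoder: a fixed random linear map."""

    def __init__(self, latent_dim: int, data_dim: int, seed: int = 0):
        rng = np.random.default_rng(seed)
        self.weight = rng.normal(size=(latent_dim, data_dim))

    def __call__(self, z: np.ndarray) -> np.ndarray:
        return z @ self.weight


class NormalSamplerSketch:
    """Prior sampler: z ~ N(0, I), then x = decoder(z)."""

    def __init__(self, decoder, latent_dim: int, seed: int = 0):
        self.decoder = decoder
        self.latent_dim = latent_dim
        self.rng = np.random.default_rng(seed)

    def sample(self, num_samples: int) -> np.ndarray:
        # 1. Generate embeddings from the standard normal prior.
        z = self.rng.standard_normal((num_samples, self.latent_dim))
        # 2. Pass them through the decoder to get samples in data space.
        return self.decoder(z)


sampler = NormalSamplerSketch(ToyDecoder(latent_dim=2, data_dim=5), latent_dim=2)
x_gen = sampler.sample(8)
print(x_gen.shape)  # (8, 5)
```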

The other samplers work the same way but may use fancier schemes to generate the embeddings in step 1. In particular, some of them need to be fitted before generating.

The fit method

As you may have noticed, the sample method described above only takes arguments that are relevant to the generation itself (number of samples, batch size, whether to return the generated samples, etc.). However, some samplers need to be fitted before the sample method is called. For instance, a GaussianMixtureSampler first needs to fit a Gaussian mixture on the embeddings learned by your model. Similarly, the TwoStageVAESampler, MAFSampler, IAFSampler and PixelCNNSampler instances require you to call the fit method before sampling, since they first need a model (a VAE, a normalizing flow or an autoregressive model) to be fitted on the learned embeddings.
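The fit-before-sample contract can be sketched with a flag checked in the sample method. `FitThenSampleSketch` is hypothetical: a per-dimension Gaussian stands in for the real density model (GMM, flow, ...), decoding is left out, and the error type mirrors the ArithmeticError pythae raises when sample is called before fit.

```python
import random
import statistics
from typing import List, Sequence


class FitThenSampleSketch:
    """Hypothetical sampler whose density model must be fitted before sampling."""

    def __init__(self, seed: int = 0):
        self.is_fitted = False
        self.rng = random.Random(seed)

    def fit(self, embeddings: Sequence[Sequence[float]]) -> None:
        # Stand-in density model: an independent Gaussian per latent dimension.
        dims = list(zip(*embeddings))
        self.means = [statistics.fmean(d) for d in dims]
        self.stds = [statistics.pstdev(d) for d in dims]
        self.is_fitted = True

    def sample(self, num_samples: int) -> List[List[float]]:
        # Refuse to sample until fit() has been called.
        if not self.is_fitted:
            raise ArithmeticError(
                "The sampler needs to be fitted by calling sampler.fit() "
                "method before sampling."
            )
        # Latent codes only; a real sampler would then decode them.
        return [
            [self.rng.gauss(m, s) for m, s in zip(self.means, self.stds)]
            for _ in range(num_samples)
        ]
```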

Example - GMM sampler

Hence, if you use one of these samplers and call sample without calling fit first, you will get the following error:

>>> my_sampler = GaussianMixtureSampler(
... sampler_config=gmm_sampler_config,
... model=my_trained_vae
... )
>>> my_sampler.sample(10)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/clement/Documents/these/implem/benchmark_VAE/src/pythae/samplers/gaussian_mixture/gaussian_mixture_sampler.py", line 128, in sample
    raise ArithmeticError(
ArithmeticError: The sampler needs to be fitted by calling smapler.fit() method before sampling.

The correct usage is:

>>> my_sampler = GaussianMixtureSampler(
... sampler_config=gmm_sampler_config,
... model=my_trained_vae
... )
>>> # Fit the sampler on your training data first
>>> my_sampler.fit(train_dataset)
>>> # Generate samples
>>> gen_data = my_sampler.sample(
... num_samples=50,
... batch_size=10,
... output_dir=None,
... return_gen=True
... )

Now, what happens in the fit method? Calling it fits your sampler on elements coming from your trained model (for instance, the learned embeddings/latent variables). This is why this method lets you pass your train/eval data: the sampler uses it to retrieve the embeddings, for instance.

Example 1 - GMM sampler

If we look at the example of the GaussianMixtureSampler again, you will see that in the fit method we:

  1. Retrieve the needed train embeddings from our trained model https://github.com/clementchadebec/benchmark_VAE/blob/a9a5388785b74ea9b01f07ee92113a5d903b5d0a/src/pythae/samplers/gaussian_mixture/gaussian_mixture_sampler.py#L75
  2. Then, use these embeddings to fit our Gaussian mixture https://github.com/clementchadebec/benchmark_VAE/blob/a9a5388785b74ea9b01f07ee92113a5d903b5d0a/src/pythae/samplers/gaussian_mixture/gaussian_mixture_sampler.py#L92-L99
  3. Finally, assign the fitted GMM to the sampler so that it can be used in the sample method https://github.com/clementchadebec/benchmark_VAE/blob/a9a5388785b74ea9b01f07ee92113a5d903b5d0a/src/pythae/samplers/gaussian_mixture/gaussian_mixture_sampler.py#L101
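The three fit steps can be sketched end to end with NumPy. This is not the pythae code: `encode` is a hypothetical linear stand-in for the trained encoder, and instead of a library GMM the sketch fits a diagonal-covariance Gaussian mixture with a short hand-written EM loop (farthest-first initialisation, then alternating E and M steps).

```python
import numpy as np


def encode(x: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for the trained encoder: returns 2-D embeddings."""
    return 0.5 * x[:, :2]


def fit_diag_gmm(z: np.ndarray, n_components: int, n_iter: int = 100) -> dict:
    """Fit a diagonal-covariance Gaussian mixture with plain EM."""
    n, _ = z.shape
    # Farthest-first initialisation spreads the initial means over the data.
    means = [z[0]]
    for _ in range(1, n_components):
        dist = np.min([np.sum((z - m) ** 2, axis=1) for m in means], axis=0)
        means.append(z[np.argmax(dist)])
    means = np.array(means)
    variances = np.tile(z.var(axis=0) + 1e-6, (n_components, 1))
    weights = np.full(n_components, 1.0 / n_components)
    for _ in range(n_iter):
        # E-step: responsibilities from log pi_k + log N(z | mu_k, diag(var_k)).
        diff = z[:, None, :] - means[None, :, :]
        log_p = -0.5 * (np.sum(diff**2 / variances, axis=2)
                        + np.sum(np.log(2 * np.pi * variances), axis=1))
        log_p += np.log(weights)
        log_p -= log_p.max(axis=1, keepdims=True)
        resp = np.exp(log_p)
        resp /= resp.sum(axis=1, keepdims=True)
        # M-step: re-estimate weights, means and variances.
        nk = resp.sum(axis=0) + 1e-10
        weights = nk / n
        means = (resp.T @ z) / nk[:, None]
        diff = z[:, None, :] - means[None, :, :]
        variances = np.einsum("nk,nkd->kd", resp, diff**2) / nk[:, None] + 1e-6
    return {"weights": weights, "means": means, "variances": variances}


class GMMSamplerSketch:
    def __init__(self, encoder, n_components: int = 10):
        self.encoder = encoder
        self.n_components = n_components
        self.is_fitted = False

    def fit(self, train_data: np.ndarray) -> None:
        # 1. Retrieve the train embeddings from the trained model.
        z = self.encoder(train_data)
        # 2. Fit the Gaussian mixture on these embeddings.
        gmm = fit_diag_gmm(z, self.n_components)
        # 3. Assign the fitted mixture to the sampler for the sample method.
        self.gmm = gmm
        self.is_fitted = True


rng = np.random.default_rng(1)
train_data = np.concatenate([rng.normal(-4.0, 0.3, size=(100, 3)),
                             rng.normal(4.0, 0.3, size=(100, 3))])
sampler = GMMSamplerSketch(encode, n_components=2)
sampler.fit(train_data)
print(np.sort(sampler.gmm["means"][:, 0]))  # close to -2 and 2
```

A full-covariance mixture would capture correlations between latent dimensions; the diagonal version is kept here only for brevity.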

Then, as explained above for the NormalSampler, the fitted GMM is used in the sample method to

  1. Generate embeddings https://github.com/clementchadebec/benchmark_VAE/blob/a9a5388785b74ea9b01f07ee92113a5d903b5d0a/src/pythae/samplers/gaussian_mixture/gaussian_mixture_sampler.py#L140-L144
  2. Retrieve the generated samples in the data space. https://github.com/clementchadebec/benchmark_VAE/blob/a9a5388785b74ea9b01f07ee92113a5d903b5d0a/src/pythae/samplers/gaussian_mixture/gaussian_mixture_sampler.py#L145
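The sample side can be sketched by drawing latent codes from a fitted mixture and decoding them. The mixture parameters below are made-up stand-ins for a fitted diagonal-covariance GMM, `decode` is a hypothetical linear decoder, and the mixture is sampled by hand (pick a component by weight, then draw from that Gaussian) rather than through a library call.

```python
import numpy as np


def sample_diag_gmm(weights, means, variances, num_samples, rng):
    """Draw latent codes from a diagonal-covariance Gaussian mixture."""
    # Choose a mixture component for every sample according to the weights.
    components = rng.choice(len(weights), size=num_samples, p=weights)
    # Draw from the chosen Gaussian: mu_k + sigma_k * eps, eps ~ N(0, I).
    eps = rng.standard_normal((num_samples, means.shape[1]))
    return means[components] + np.sqrt(variances[components]) * eps


def decode(z: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for the trained decoder (a fixed linear map)."""
    weight = np.array([[1.0, 0.0, 1.0], [0.0, 1.0, -1.0]])
    return z @ weight


rng = np.random.default_rng(0)
weights = np.array([0.5, 0.5])
means = np.array([[-2.0, -2.0], [2.0, 2.0]])
variances = np.full((2, 2), 0.01)
z = sample_diag_gmm(weights, means, variances, num_samples=6, rng=rng)  # 1. embeddings
x_gen = decode(z)  # 2. samples in data space
print(x_gen.shape)  # (6, 3)
```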

Example 2 - MAF sampler

The other samplers work the same way.

For instance, the MAFSampler will

  1. Retrieve the needed train and eval embeddings from our trained model https://github.com/clementchadebec/benchmark_VAE/blob/a9a5388785b74ea9b01f07ee92113a5d903b5d0a/src/pythae/samplers/maf_sampler/maf_sampler.py#L93-L95 https://github.com/clementchadebec/benchmark_VAE/blob/a9a5388785b74ea9b01f07ee92113a5d903b5d0a/src/pythae/samplers/maf_sampler/maf_sampler.py#L126-L128

  2. Then use these embeddings to fit the normalizing flow https://github.com/clementchadebec/benchmark_VAE/blob/a9a5388785b74ea9b01f07ee92113a5d903b5d0a/src/pythae/samplers/maf_sampler/maf_sampler.py#L140-L147

  3. The flow model is assigned to the sampler for further use in the sample method https://github.com/clementchadebec/benchmark_VAE/blob/a9a5388785b74ea9b01f07ee92113a5d903b5d0a/src/pythae/samplers/maf_sampler/maf_sampler.py#L149-L151
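A full MAF is too long for a short example, so the sketch below swaps in the simplest normalizing flow, a single elementwise affine transform, whose maximum-likelihood fit has a closed form. It keeps the three fit steps; `encode`, `AffineFlow` and `FlowSamplerSketch` are hypothetical, and using the eval embeddings only to report a held-out negative log-likelihood is an assumption about their role, not a description of pythae's trainer.

```python
import numpy as np


class AffineFlow:
    """Elementwise affine flow: z = mu + exp(log_sigma) * u, with u ~ N(0, I)."""

    def __init__(self, dim: int):
        self.mu = np.zeros(dim)
        self.log_sigma = np.zeros(dim)

    def inverse(self, z: np.ndarray) -> np.ndarray:
        # Latent space -> base space.
        return (z - self.mu) * np.exp(-self.log_sigma)

    def log_prob(self, z: np.ndarray) -> np.ndarray:
        # Change of variables: log N(u; 0, I) + log|det du/dz|.
        u = self.inverse(z)
        log_base = -0.5 * np.sum(u**2 + np.log(2 * np.pi), axis=1)
        return log_base - np.sum(self.log_sigma)

    def fit(self, z: np.ndarray) -> None:
        # Maximum likelihood has a closed form for this flow.
        self.mu = z.mean(axis=0)
        self.log_sigma = np.log(z.std(axis=0) + 1e-6)


def encode(x: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for the trained encoder."""
    return 0.5 * x


class FlowSamplerSketch:
    def __init__(self, encoder, latent_dim: int):
        self.encoder = encoder
        self.latent_dim = latent_dim
        self.is_fitted = False

    def fit(self, train_data: np.ndarray, eval_data: np.ndarray) -> None:
        # 1. Retrieve the train and eval embeddings from the trained model.
        train_z = self.encoder(train_data)
        eval_z = self.encoder(eval_data)
        # 2. Fit the flow on the train embeddings; eval embeddings only monitor it.
        flow = AffineFlow(self.latent_dim)
        flow.fit(train_z)
        self.eval_nll = float(-flow.log_prob(eval_z).mean())
        # 3. Assign the fitted flow to the sampler for the sample method.
        self.flow = flow
        self.is_fitted = True


rng = np.random.default_rng(0)
sampler = FlowSamplerSketch(encode, latent_dim=2)
sampler.fit(rng.normal(3.0, 2.0, size=(500, 2)), rng.normal(3.0, 2.0, size=(100, 2)))
print(sampler.flow.mu, sampler.eval_nll)
```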

Then, the flow model is used in the sample method to

  1. Generate embeddings https://github.com/clementchadebec/benchmark_VAE/blob/a9a5388785b74ea9b01f07ee92113a5d903b5d0a/src/pythae/samplers/maf_sampler/maf_sampler.py#L194-L195

  2. Retrieve the generated samples in the data space. https://github.com/clementchadebec/benchmark_VAE/blob/a9a5388785b74ea9b01f07ee92113a5d903b5d0a/src/pythae/samplers/maf_sampler/maf_sampler.py#L196
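Sampling through a fitted flow follows the same two steps: draw base noise, push it through the flow to get embeddings, then decode. In this sketch a hypothetical elementwise affine flow with fixed parameters stands in for the fitted MAF, and `decode` is a hypothetical decoder. In an actual MAF, generating samples inverts an autoregressive transform one dimension at a time, so it is sequential; the affine stand-in hides that cost.

```python
import numpy as np


class AffineFlow:
    """Fitted elementwise affine flow: z = mu + exp(log_sigma) * u."""

    def __init__(self, mu, log_sigma):
        self.mu = np.asarray(mu, dtype=float)
        self.log_sigma = np.asarray(log_sigma, dtype=float)

    def forward(self, u: np.ndarray) -> np.ndarray:
        # Base space -> latent space.
        return self.mu + np.exp(self.log_sigma) * u


def decode(z: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for the trained decoder (inverse of a 0.5 * x encoder)."""
    return 2.0 * z


def sample_from_flow(flow: AffineFlow, num_samples: int, rng) -> np.ndarray:
    # 1. Draw base noise u ~ N(0, I) and push it through the flow to get embeddings.
    u = rng.standard_normal((num_samples, flow.mu.shape[0]))
    z = flow.forward(u)
    # 2. Decode the embeddings to obtain samples in data space.
    return decode(z)


rng = np.random.default_rng(0)
flow = AffineFlow(mu=[1.5, 1.5], log_sigma=[0.0, 0.0])
x_gen = sample_from_flow(flow, num_samples=1000, rng=rng)
print(x_gen.shape)  # (1000, 2)
```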

Conclusion

I hope this helps you better understand how the samplers work. In any case, do not hesitate to ask if you have any other questions or need me to clarify some points :)

Best,

Clément

jprachir commented 2 years ago

Hi @clementchadebec, thanks for the detailed reply. It clearly conveys the gist of the sampler components. I will use this issue for any future questions related to samplers.

Best, Prachi