Hi @jprachir,
Thank you for opening this issue and sorry for the late reply. I will try to make my answer as detailed as possible so that it can also help if future questions arise.
The samplers are designed the same way as the models. This means that, for any sampling technique (GMM, MAF, ...), you will find in pythae.samplers a folder containing a nameofsampler_config.py file, with a dataclass where the parameters of the sampling scheme are defined, and a nameofsampler_sampler.py file, where the actual sampling technique is implemented.
The samplers can then be used with any suitable model to generate new data. A pythae.samplers instance takes as input the trained model from which you want to generate; this input is required. Optionally, you can also provide a custom configuration when building your sampler to use different sampler hyper-parameters.
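To make this layout more concrete, here is a minimal sketch of the config/sampler pairing in plain Python. MySamplerConfig, MySampler and the n_components default are hypothetical names chosen for illustration, not pythae's actual classes. The sketch only mirrors the pattern described above: a dataclass holding the hyper-parameters, and a sampler that requires a trained model and optionally accepts a config (falling back to default values is an assumption of this sketch).

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class MySamplerConfig:
    """Hypothetical config dataclass (the role of a nameofsampler_config.py file)."""

    n_components: int = 10  # example hyper-parameter of the sampling scheme


class MySampler:
    """Hypothetical sampler (the role of a nameofsampler_sampler.py file)."""

    def __init__(self, model: Any, sampler_config: Optional[MySamplerConfig] = None):
        # Required: the trained model we want to generate from.
        self.model = model
        # Optional: custom hyper-parameters; otherwise fall back to the defaults.
        self.sampler_config = sampler_config if sampler_config is not None else MySamplerConfig()


# A string stands in for a trained model here.
default_sampler = MySampler(model="my_trained_vae")
custom_sampler = MySampler(model="my_trained_vae", sampler_config=MySamplerConfig(n_components=12))
```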
For instance, let's say we have trained a VAE and want to sample from it using a pythae.samplers.GaussianMixtureSampler. All you have to do to build your sampler is the following:
>>> from pythae.samplers import GaussianMixtureSampler, GaussianMixtureSamplerConfig
>>> # Define your sampler configuration
... gmm_sampler_config = GaussianMixtureSamplerConfig(
... n_components=12
... )
>>> # Build your sampler
... my_sampler = GaussianMixtureSampler(
... sampler_config=gmm_sampler_config,
... model=my_trained_vae # A trained `pythae.models` instance
... )
This works the same way for any other sampler implemented in pythae.
The sample method
Once your sampler has been instantiated, you can easily generate new data with its sample method, which saves and/or returns the samples generated from your trained model. The sample method is where the sampler generates the embeddings (latent variables) that are then passed to the decoder of the model we are sampling from, to get the generated samples in the data space.
For instance, if we take the simplest example, the pythae.samplers.NormalSampler, you can see that in the sample method we:
1. Generate the embeddings with the defined technique (here a standard normal distribution N(0, 1)): https://github.com/clementchadebec/benchmark_VAE/blob/a9a5388785b74ea9b01f07ee92113a5d903b5d0a/src/pythae/samplers/normal_sampling/normal_sampler.py#L54
2. Then pass them to the decoder of the model we are sampling from: https://github.com/clementchadebec/benchmark_VAE/blob/a9a5388785b74ea9b01f07ee92113a5d903b5d0a/src/pythae/samplers/normal_sampling/normal_sampler.py#L55
The other samplers work the same way but may use fancier schemes to generate the embeddings in step 1. In particular, some of them need to be fitted before generating.
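The two steps above can be sketched as follows. This is a toy, dependency-free illustration, not pythae's NormalSampler: ToyNormalSampler and toy_decoder (a fixed linear map standing in for a trained decoder) are hypothetical, and the batching and return_gen handling are assumptions modeled on the sample arguments mentioned in this thread. Saving to output_dir is left out.

```python
import random
from typing import Callable, List, Optional

Vector = List[float]


def toy_decoder(z: Vector) -> Vector:
    # Hypothetical stand-in for a trained decoder: a fixed linear map from R^2 to R^3.
    return [z[0] + z[1], z[0] - z[1], 2.0 * z[0]]


class ToyNormalSampler:
    """Sketch of a NormalSampler-like scheme; not pythae's implementation."""

    def __init__(self, decoder: Callable[[Vector], Vector], latent_dim: int, seed: Optional[int] = None):
        self.decoder = decoder
        self.latent_dim = latent_dim
        self.rng = random.Random(seed)

    def sample(self, num_samples: int = 1, batch_size: int = 500, return_gen: bool = True) -> Optional[List[Vector]]:
        generated: List[Vector] = []
        for start in range(0, num_samples, batch_size):
            n = min(batch_size, num_samples - start)
            # Step 1: draw n embeddings z ~ N(0, I) in the latent space.
            z_batch = [[self.rng.gauss(0.0, 1.0) for _ in range(self.latent_dim)] for _ in range(n)]
            # Step 2: pass them through the decoder to get samples in the data space.
            generated.extend(self.decoder(z) for z in z_batch)
        # Only returning is handled here; writing samples to disk is omitted.
        return generated if return_gen else None


sampler = ToyNormalSampler(toy_decoder, latent_dim=2, seed=0)
samples = sampler.sample(num_samples=5, batch_size=2)
```

A fitted sampler would only change step 1: the embeddings would come from a learned distribution instead of N(0, I).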
The fit method
As you may have noticed, the sample method above only takes arguments that are relevant to the generation itself (number of samples, batch size, whether to return the generated samples or not, etc.). However, some samplers need to be fitted before the sample method is called. For instance, a GaussianMixtureSampler first needs to fit a Gaussian mixture on the embeddings learned by your model. Similarly, TwoStageVAESampler, MAFSampler, IAFSampler and PixelCNNSampler instances require you to call the fit method before sampling with the sample method, since they also need a model (a VAE, a normalizing flow or an autoregressive model) to be fitted on the learned embeddings first.
Hence, if you call sample on one of these samplers without calling the fit method first, you should get the following error:
>>> my_sampler = GaussianMixtureSampler(
... sampler_config=gmm_sampler_config,
... model=my_trained_vae
... )
>>> my_sampler.sample(10)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/clement/Documents/these/implem/benchmark_VAE/src/pythae/samplers/gaussian_mixture/gaussian_mixture_sampler.py", line 128, in sample
    raise ArithmeticError(
ArithmeticError: The sampler needs to be fitted by calling smapler.fit() method before sampling.
The correct usage, on the other hand, is:
>>> my_sampler = GaussianMixtureSampler(
... sampler_config=gmm_sampler_config,
... model=my_trained_vae
... )
>>> # fit the sampler
>>> my_sampler.fit(train_dataset)
>>> # Generate samples
>>> gen_data = my_sampler.sample(
... num_samples=50,
... batch_size=10,
... output_dir=None,
... return_gen=True
... )
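The fit-then-sample contract behind the error shown earlier can be sketched as a simple flag check. FitBeforeSampleSampler is hypothetical: its fit takes embeddings directly (rather than training data, as pythae's fit does) and merely stores them, and its sample just resamples them, so that only the guard logic stays in focus.

```python
import random
from typing import List, Optional

Vector = List[float]


class FitBeforeSampleSampler:
    """Hypothetical sketch of the fit-then-sample contract; not pythae's code."""

    def __init__(self, seed: Optional[int] = None):
        self.is_fitted = False
        self.rng = random.Random(seed)
        self.embeddings: List[Vector] = []

    def fit(self, embeddings: List[Vector]) -> "FitBeforeSampleSampler":
        # Placeholder "fit": just store the embeddings. A real sampler would fit
        # a GMM, a flow, etc. on them here.
        self.embeddings = [list(e) for e in embeddings]
        self.is_fitted = True
        return self

    def sample(self, num_samples: int = 1) -> List[Vector]:
        if not self.is_fitted:
            # Same exception type as the pythae error in the traceback.
            raise ArithmeticError("The sampler needs to be fitted by calling sampler.fit() before sampling.")
        # Placeholder generation: resample the stored embeddings with replacement.
        return [self.rng.choice(self.embeddings) for _ in range(num_samples)]
```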
Now, what happens in the fit method? Calling this method fits your sampler using elements coming from your trained model (for instance, the learned embeddings/latent variables). This is why you can pass your train/eval data to this method: it is used, for instance, to retrieve the embeddings.
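If we look at the GaussianMixtureSampler example again, you will see that the fit method retrieves the embeddings of the training data from the trained model, fits a Gaussian mixture on them, and assigns the fitted GMM to the sampler for further use in the sample method: https://github.com/clementchadebec/benchmark_VAE/blob/a9a5388785b74ea9b01f07ee92113a5d903b5d0a/src/pythae/samplers/gaussian_mixture/gaussian_mixture_sampler.py#L101
Then, as explained in the previous section for the NormalSampler, the GMM is used in the sample method to generate embeddings, which are then passed to the decoder to get the generated samples in the data space.
The other samplers work the same way.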
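As an illustration of this fit/sample split, here is a self-contained sketch of a GMM-based sampler. It is not pythae's implementation: ToyGMMSampler and the encoder/decoder callables are hypothetical stand-ins for a trained model, and a small hand-written EM with diagonal covariances (with a farthest-point initialization of the means) replaces whatever GMM routine the library uses, so that the example has no dependencies.

```python
import math
import random
from typing import Callable, List, Optional

Vector = List[float]


def log_gauss_diag(x: Vector, mean: Vector, var: Vector) -> float:
    # Log-density of N(mean, diag(var)) evaluated at x.
    return sum(-0.5 * (math.log(2 * math.pi * v) + (xi - m) ** 2 / v) for xi, m, v in zip(x, mean, var))


class ToyGMMSampler:
    """Sketch of a GMM-based sampler (diagonal covariances, plain EM); not pythae's code."""

    def __init__(self, encoder: Callable[[Vector], Vector], decoder: Callable[[Vector], Vector],
                 n_components: int = 2, n_iter: int = 50, seed: Optional[int] = None):
        self.encoder, self.decoder = encoder, decoder
        self.k, self.n_iter = n_components, n_iter
        self.rng = random.Random(seed)
        self.is_fitted = False

    def fit(self, train_data: List[Vector]) -> "ToyGMMSampler":
        # 1. Retrieve the embeddings of the training data from the trained encoder.
        z = [self.encoder(x) for x in train_data]
        n, d = len(z), len(z[0])
        # 2. Fit the GMM with EM. Initial means: a random embedding, then repeatedly
        #    the embedding farthest from the means chosen so far.
        means = [list(self.rng.choice(z))]
        while len(means) < self.k:
            farthest = max(z, key=lambda p: min(sum((a - b) ** 2 for a, b in zip(p, m)) for m in means))
            means.append(list(farthest))
        variances = [[1.0] * d for _ in range(self.k)]
        weights = [1.0 / self.k] * self.k
        for _ in range(self.n_iter):
            # E-step: responsibilities p(component j | z_i), normalized in log space.
            resp = []
            for zi in z:
                logs = [math.log(w) + log_gauss_diag(zi, m, v) for w, m, v in zip(weights, means, variances)]
                top = max(logs)
                probs = [math.exp(lp - top) for lp in logs]
                total = sum(probs)
                resp.append([p / total for p in probs])
            # M-step: re-estimate weights, means and (floored) variances.
            for j in range(self.k):
                nj = sum(r[j] for r in resp) + 1e-10
                weights[j] = nj / n
                means[j] = [sum(r[j] * zi[t] for r, zi in zip(resp, z)) / nj for t in range(d)]
                variances[j] = [sum(r[j] * (zi[t] - means[j][t]) ** 2 for r, zi in zip(resp, z)) / nj + 1e-6
                                for t in range(d)]
        # 3. Keep the fitted mixture for later use in `sample`.
        self.weights, self.means, self.variances = weights, means, variances
        self.is_fitted = True
        return self

    def sample(self, num_samples: int = 1) -> List[Vector]:
        if not self.is_fitted:
            raise ArithmeticError("Call fit() before sample().")
        out = []
        for _ in range(num_samples):
            # Pick a component according to the mixture weights, then draw z from it.
            j = self.rng.choices(range(self.k), weights=self.weights)[0]
            z = [self.rng.gauss(m, math.sqrt(v)) for m, v in zip(self.means[j], self.variances[j])]
            # Decode the embedding to get a sample in the data space.
            out.append(self.decoder(z))
        return out
```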
For instance, the MAFSampler will:
1. Retrieve the needed train and eval embeddings from our trained model: https://github.com/clementchadebec/benchmark_VAE/blob/a9a5388785b74ea9b01f07ee92113a5d903b5d0a/src/pythae/samplers/maf_sampler/maf_sampler.py#L93-L95 https://github.com/clementchadebec/benchmark_VAE/blob/a9a5388785b74ea9b01f07ee92113a5d903b5d0a/src/pythae/samplers/maf_sampler/maf_sampler.py#L126-L128
2. Use these embeddings to fit the normalizing flow: https://github.com/clementchadebec/benchmark_VAE/blob/a9a5388785b74ea9b01f07ee92113a5d903b5d0a/src/pythae/samplers/maf_sampler/maf_sampler.py#L140-L147
3. Assign the flow model to the sampler for further use in the sample method: https://github.com/clementchadebec/benchmark_VAE/blob/a9a5388785b74ea9b01f07ee92113a5d903b5d0a/src/pythae/samplers/maf_sampler/maf_sampler.py#L149-L151
Then, the flow model is used in the sample method to:
1. Generate embeddings: https://github.com/clementchadebec/benchmark_VAE/blob/a9a5388785b74ea9b01f07ee92113a5d903b5d0a/src/pythae/samplers/maf_sampler/maf_sampler.py#L194-L195
2. Retrieve the generated samples in the data space: https://github.com/clementchadebec/benchmark_VAE/blob/a9a5388785b74ea9b01f07ee92113a5d903b5d0a/src/pythae/samplers/maf_sampler/maf_sampler.py#L196
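The same structure can be sketched with a toy flow. This is a deliberately simplified assumption, not pythae's MAF: a single affine autoregressive layer on 2-D embeddings whose conditioners are linear, fitted by closed-form maximum likelihood instead of gradient descent (so no eval embeddings are needed). A real MAF stacks several such layers with neural-network conditioners; with linear conditioners, as here, the model reduces to a full-covariance Gaussian. LinearAutoregressiveFlow2D and ToyFlowSampler are hypothetical names.

```python
import math
import random
import statistics
from typing import Callable, List, Optional

Vector = List[float]


class LinearAutoregressiveFlow2D:
    """One affine autoregressive layer on 2-D embeddings, with linear conditioners.

    z1 = a1 + s1 * u1
    z2 = a2 + b * z1 + s2 * u2,  where u ~ N(0, I)
    """

    def fit(self, z: List[Vector]) -> "LinearAutoregressiveFlow2D":
        z1, z2 = [p[0] for p in z], [p[1] for p in z]
        m1, m2 = statistics.fmean(z1), statistics.fmean(z2)
        var1 = statistics.fmean([(x - m1) ** 2 for x in z1])
        cov12 = statistics.fmean([(x - m1) * (y - m2) for x, y in zip(z1, z2)])
        # Closed-form maximum likelihood: z1 is a plain Gaussian and z2 given z1
        # is a least-squares linear regression on z1.
        self.a1, self.s1 = m1, math.sqrt(var1)
        self.b = cov12 / var1
        self.a2 = m2 - self.b * m1
        self.s2 = math.sqrt(statistics.fmean([(y - self.a2 - self.b * x) ** 2 for x, y in zip(z1, z2)]))
        return self

    def forward(self, z: Vector) -> Vector:
        # Density direction (embedding -> noise): both coordinates only depend on z,
        # so they can be computed in parallel.
        return [(z[0] - self.a1) / self.s1, (z[1] - self.a2 - self.b * z[0]) / self.s2]

    def inverse(self, u: Vector) -> Vector:
        # Sampling direction (noise -> embedding): z1 must be computed before z2.
        z1 = self.a1 + self.s1 * u[0]
        return [z1, self.a2 + self.b * z1 + self.s2 * u[1]]


class ToyFlowSampler:
    """Sketch of a MAF-style sampler built on the toy flow above; not pythae's code."""

    def __init__(self, encoder: Callable[[Vector], Vector], decoder: Callable[[Vector], Vector],
                 seed: Optional[int] = None):
        self.encoder, self.decoder = encoder, decoder
        self.rng = random.Random(seed)
        self.flow: Optional[LinearAutoregressiveFlow2D] = None

    def fit(self, train_data: List[Vector]) -> "ToyFlowSampler":
        # Retrieve the train embeddings, fit the flow on them and keep it for `sample`.
        self.flow = LinearAutoregressiveFlow2D().fit([self.encoder(x) for x in train_data])
        return self

    def sample(self, num_samples: int = 1) -> List[Vector]:
        if self.flow is None:
            raise ArithmeticError("Call fit() before sample().")
        out = []
        for _ in range(num_samples):
            # Generate an embedding by pushing N(0, I) noise through the inverse flow,
            z = self.flow.inverse([self.rng.gauss(0.0, 1.0), self.rng.gauss(0.0, 1.0)])
            # then decode it to get a sample in the data space.
            out.append(self.decoder(z))
        return out
```

Note the asymmetry: forward, used to evaluate densities, is parallel over coordinates, while inverse, used for sampling, is sequential; this is the usual trade-off of masked autoregressive flows.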
I hope this helps you better understand how the samplers work. In any case, do not hesitate if you have any other questions or if you need me to clarify some points :)
Best,
Clément
Hi @clementchadebec, thanks for the detailed reply. It clearly conveys the gist of the sampler components. I will use this issue to ask questions related to samplers in the future.
Best, Prachi
Hi Clément: Great work on introducing this VAE-oriented library! You have made it very modular, with predefined models, pipelines, and so forth. Could you share some details on how the samplers work under the hood for generation?
Prachi