bayesiains / nflows

Normalizing flows in PyTorch
MIT License

Fast sampling for varying context #54

Closed francesco-vaselli closed 2 years ago

francesco-vaselli commented 2 years ago

Hello there, and thank you all for your work on this package; it has been tremendously helpful.

I am opening this issue because I would like to know if there is a way to perform fast generation when each sample has a different context. As of now, sampling multiple points from the same context is straightforward, i.e. if y is a vector with six elements, then

flow.sample(10000, context=y.view(-1, 6))

is quite fast, but samples are all conditioned on the same 6 context values. I have a vector y of shape (10000, 6) and I would like to sample 10000 new points, each one conditioned on a different set of values of the y array. At the moment the best I could manage was something like:

```python
import numpy as np

samples = []
for i in range(10000):
    # draw one sample conditioned on the i-th context row
    curr_sample = flow.sample(1, context=y[i].view(-1, 6))
    curr_sample = curr_sample.detach().cpu().numpy()
    curr_sample = np.squeeze(curr_sample, axis=0)
    samples.append(curr_sample)
```

However, being a bare Python for loop, the process is quite slow (30 minutes for 1e4 samples). Is there a way to speed up the sampling process? Or am I missing some specific way to pass the arguments to the sample method? I am more than willing to work on a pull request for this problem if you can provide me with some guidance. Thanks!

leejielong commented 2 years ago

Hi Francesco, the package is actually already able to handle different contexts. All you have to do is provide a context of shape (N, 6), where N is the number of distinct contexts you would like to sample from. The num_samples argument determines the number of samples to draw from each distinct context. Here's an example:

```python
context = torch.ones(10, 6)
num_samples = 1000  # draw 1000 samples from each of the 10 distinct contexts
samples = flow.sample(num_samples, context=context)
```

The shape of the output samples will be (10, 1000, 2), assuming the base distribution has 2 features.

francesco-vaselli commented 2 years ago

Hey there! Thank you so much for taking the time to answer, you totally made my day. It turns out I had actually tried something like that, but I had stumbled upon CUDA out-of-memory errors, as I insisted on plugging in 1e4 different contexts at a time... Trying with a smaller number did the job!!! Many thanks!