bayesiains / nflows

Normalizing flows in PyTorch
MIT License

Allow base distributions which lack context #33

Closed: MilesCranmer closed this 2 years ago

MilesCranmer commented 3 years ago

This allows one to use flow base distributions that lack a context keyword argument, such as any torch.distributions.Independent distribution (which BoxUniform is built on). It checks at initialization whether context is a keyword argument of self._distribution.log_prob, and uses that to decide whether to pass context. Without this check, such distributions raise an error, because their log_prob does not accept context.
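The diff itself isn't reproduced here, so the following is only a minimal sketch of the check described above, using Python's inspect.signature. FlowSketch, accepts_context, ContextAwareBase and ContextFreeBase are hypothetical names for illustration, not nflows API. In a real flow the transform would still receive context; only the call to the base distribution drops it.

```python
import inspect


def accepts_context(log_prob):
    """Return True if the callable `log_prob` has a parameter named `context`."""
    try:
        params = inspect.signature(log_prob).parameters
    except (TypeError, ValueError):
        # No introspectable signature: assume context is not supported.
        return False
    return "context" in params


class FlowSketch:
    """Hypothetical stand-in for a flow's base log-density call (not nflows code)."""

    def __init__(self, distribution):
        self._distribution = distribution
        # Decide once, at construction, whether the base distribution takes context.
        self._context_used_in_base = accepts_context(distribution.log_prob)

    def base_log_prob(self, noise, context=None):
        if self._context_used_in_base:
            return self._distribution.log_prob(noise, context=context)
        # Context-free bases (e.g. torch.distributions objects) get only the inputs.
        return self._distribution.log_prob(noise)


# Two hypothetical base distributions with the two signatures in question.
class ContextAwareBase:
    def log_prob(self, inputs, context=None):
        return -((inputs - context) ** 2)


class ContextFreeBase:
    def log_prob(self, value):
        return -(value ** 2)
```

Checking the signature once in __init__, as the description says, keeps the per-call path to a single boolean test instead of re-inspecting log_prob on every evaluation.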

@arturbekasov @imurray

This also fixes #17

Cheers, Miles

MilesCranmer commented 3 years ago

Edit: I've now also fixed sampling for context-less base distributions. I think the default behaviour should be that the flow still uses the batch size (first dimension) of the context to determine how many samples to draw; before the latest change, it used only num_samples.

arturbekasov commented 2 years ago

A bit hacky, but nice in that we can use more PyTorch distributions out of the box, without having to write wrappers.

mlkrock commented 2 years ago

Thanks @MilesCranmer @arturbekasov, this is a nice addition, but I can't get sampling to work. Do you see a quick fix?

from nflows import transforms, distributions, flows
import torch

# Define an invertible transformation.
transform = transforms.CompositeTransform([
    transforms.MaskedAffineAutoregressiveTransform(features=2, hidden_features=4),
    transforms.RandomPermutation(features=2)
])

# Define a base distribution.
base_distribution = torch.distributions.Independent(torch.distributions.Uniform(torch.zeros(2), torch.ones(2)), 1)

# Combine into a flow.
flow = flows.Flow(transform=transform, distribution=base_distribution)
flow.sample(1)

This gives the following error:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/nflows/distributions/base.py", line 70, in sample
    return self._sample(num_samples, context)
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/nflows/flows/base.py", line 56, in _sample
    repeat_noise = self._distribution.sample(num_samples*embedded_context.shape[0])
AttributeError: 'NoneType' object has no attribute 'shape'
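In the traceback, embedded_context is None because flow.sample(1) was called without a context, so embedded_context.shape[0] fails. Below is a minimal sketch of sampling from a context-free base that handles both cases. sample_context_free_base is a hypothetical standalone helper, not the nflows code, and it assumes the (batch, num_samples, ...) layout that nflows' context-aware bases such as StandardNormal return when given a context. It passes a tuple to sample, since torch.distributions objects take a sample shape rather than a bare integer.

```python
import torch
from torch import distributions as D


def sample_context_free_base(base, num_samples, context=None):
    """Draw noise from a base distribution that ignores context.

    Hypothetical helper for illustration; not part of nflows.
    - context is None: returns shape (num_samples, *event_shape).
    - context has batch size B: returns shape (B, num_samples, *event_shape).
    """
    if context is None:
        # No context means no batch dimension to read, so never touch context.shape.
        return base.sample((num_samples,))
    batch_size = context.shape[0]
    # One block of num_samples draws per context row, drawn in a single call.
    noise = base.sample((batch_size * num_samples,))
    return noise.reshape(batch_size, num_samples, *noise.shape[1:])


# The base distribution from the report above.
base = D.Independent(D.Uniform(torch.zeros(2), torch.ones(2)), 1)
print(sample_context_free_base(base, 1).shape)  # torch.Size([1, 2])
print(sample_context_free_base(base, 3, context=torch.zeros(4, 5)).shape)  # torch.Size([4, 3, 2])
```

Drawing all batch_size * num_samples samples in one call and reshaping afterwards avoids a Python loop over context rows.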