Closed MilesCranmer closed 2 years ago
edit: I just fixed sampling for context-less base distributions. I think the default behaviour is that it should still use the shape of the context to figure out the number of samples to create. Before the latest change it was only using `num_samples`.
A bit hacky, but nice in that we could use more PyTorch distributions out-of-the-box, without having to create wrappers.
Thanks @MilesCranmer @arturbekasov, this is a nice addition, but I cannot get sampling to work. Maybe you see a quick fix?
```python
from nflows import transforms, distributions, flows
import torch

# Define an invertible transformation.
transform = transforms.CompositeTransform([
    transforms.MaskedAffineAutoregressiveTransform(features=2, hidden_features=4),
    transforms.RandomPermutation(features=2),
])

# Define a base distribution.
base_distribution = torch.distributions.Independent(
    torch.distributions.Uniform(torch.zeros(2), torch.ones(2)), 1)

# Combine into a flow.
flow = flows.Flow(transform=transform, distribution=base_distribution)
flow.sample(1)
```
gives the error message:

```
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/nflows/distributions/base.py", line 70, in sample
    return self._sample(num_samples, context)
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/nflows/flows/base.py", line 56, in _sample
    repeat_noise = self._distribution.sample(num_samples*embedded_context.shape[0])
AttributeError: 'NoneType' object has no attribute 'shape'
```
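The traceback shows `_sample` unconditionally reading `embedded_context.shape[0]`, which fails when no context was passed (so `embedded_context` is `None`). A minimal sketch of the missing guard, using a standalone helper (`sample_with_optional_context` is hypothetical, not part of the nflows API; note that plain `torch.distributions` objects take a sample *shape* tuple rather than an integer count):

```python
import torch

def sample_with_optional_context(distribution, num_samples, embedded_context=None):
    # Hypothetical guard: only scale by the context batch size
    # when a context is actually present.
    if embedded_context is not None:
        return distribution.sample((num_samples * embedded_context.shape[0],))
    return distribution.sample((num_samples,))

base = torch.distributions.Independent(
    torch.distributions.Uniform(torch.zeros(2), torch.ones(2)), 1)

# No context: just draw num_samples draws from the base distribution.
print(sample_with_optional_context(base, 3).shape)   # torch.Size([3, 2])

# With a (batch of 4) context: draw num_samples per context row.
ctx = torch.zeros(4, 8)
print(sample_with_optional_context(base, 3, ctx).shape)  # torch.Size([12, 2])
```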
This allows one to use flow base distributions which lack a `context` keyword argument, like anything in `torch.distributions.Independent`, which is used for `BoxUniform`. It basically just checks at initialization if `context` is an available keyword argument to `self._distribution.log_prob`, and uses that to decide whether to pass context or not. Otherwise such distributions will give an error. @arturbekasov @imurray
This also fixes #17
Cheers, Miles