bayesiains / nflows

Normalizing flows in PyTorch
MIT License

RuntimeError when splines encounter all-tail inputs #23

Closed: johannbrehmer closed this issue 3 years ago

johannbrehmer commented 4 years ago

Hi all,

very rarely, I get an error when using spline flows in nflows:

Issue

When a spline transformation encounters inputs that all lie outside the (-tail_bound, tail_bound) range, it throws a RuntimeError. For instance,

import torch
import nflows.transforms

x = torch.tensor([[5.], [-6.]])
trf = nflows.transforms.PiecewiseLinearCDF(shape=(1,), tails="linear", tail_bound=4.0)

trf(x)

gives me

RuntimeError                              Traceback (most recent call last)
<ipython-input-19-39a279c296f6> in <module>
      5 trf = nflows.transforms.PiecewiseLinearCDF(shape=(1,), tails="linear", tail_bound=4.0)
      6 
----> 7 trf(x)

~/anaconda3/envs/flow_processes/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

~/anaconda3/envs/flow_processes/lib/python3.8/site-packages/nflows/transforms/nonlinearities.py in forward(self, inputs, context)
    240 
    241     def forward(self, inputs, context=None):
--> 242         return self._spline(inputs, inverse=False)
    243 
    244     def inverse(self, inputs, context=None):

~/anaconda3/envs/flow_processes/lib/python3.8/site-packages/nflows/transforms/nonlinearities.py in _spline(self, inputs, inverse)
    229             )
    230         else:
--> 231             outputs, logabsdet = splines.unconstrained_linear_spline(
    232                 inputs=inputs,
    233                 unnormalized_pdf=unnormalized_pdf,

~/anaconda3/envs/flow_processes/lib/python3.8/site-packages/nflows/transforms/splines/linear.py in unconstrained_linear_spline(inputs, unnormalized_pdf, inverse, tail_bound, tails)
     22         raise RuntimeError("{} tails are not implemented.".format(tails))
     23 
---> 24     outputs[inside_interval_mask], logabsdet[inside_interval_mask] = linear_spline(
     25         inputs=inputs[inside_interval_mask],
     26         unnormalized_pdf=unnormalized_pdf[inside_interval_mask, :],

~/anaconda3/envs/flow_processes/lib/python3.8/site-packages/nflows/transforms/splines/linear.py in linear_spline(inputs, unnormalized_pdf, inverse, left, right, bottom, top)
     42     > Müller et al., Neural Importance Sampling, arXiv:1808.03856, 2018.
     43     """
---> 44     if torch.min(inputs) < left or torch.max(inputs) > right:
     45         raise InputOutsideDomain()
     46 

RuntimeError: operation does not have an identity.

(This was with nflows v0.12 from PyPI.)
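The last frame is the domain check torch.min(inputs) < left. When no element lies inside the interval, inputs[inside_interval_mask] is an empty tensor, and a full reduction such as torch.min over an empty tensor has no identity element to return, so PyTorch raises. The sketch below isolates just that step without nflows. The helper min_of_masked is hypothetical, and the exact error text differs between PyTorch versions (the message above is the one from the version used in this report).

```python
import torch


def min_of_masked(x, tail_bound):
    """Hypothetical helper: reproduces only the reduction that fails in the traceback."""
    # Which elements lie inside [-tail_bound, tail_bound]?
    inside_interval_mask = (x >= -tail_bound) & (x <= tail_bound)
    selected = x[inside_interval_mask]  # 1-D; empty if nothing is inside
    return torch.min(selected)  # full reduction: fails on an empty tensor


x = torch.tensor([[5.], [-6.]])  # every element is outside [-4, 4]
try:
    min_of_masked(x, tail_bound=4.0)
except RuntimeError as err:
    # The message depends on the installed PyTorch version.
    print("RuntimeError:", err)
```

With at least one in-interval element, for example torch.tensor([[5.], [1.], [-2.]]), the same call returns a value instead of raising.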

Fix

This is very simple to fix by adding a check like if torch.any(inside_interval_mask): before the calls to the bounded spline functions (e.g. the call to linear_spline inside unconstrained_linear_spline in the traceback above). I would be happy to open a PR if you are interested.
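To make the proposed guard concrete, here is a minimal sketch of an "unconstrained" wrapper with linear tails. It is not nflows' actual code: apply_with_linear_tails and placeholder_inner_fn are hypothetical names, the inner function is an identity stand-in for a bounded spline (it only copies the same domain check as in the traceback), and logabsdet is kept per element rather than summed over features.

```python
import torch


def apply_with_linear_tails(inputs, inner_fn, tail_bound):
    """Hypothetical sketch of the proposed guard; not nflows' actual code.

    inner_fn maps the 1-D tensor of in-interval inputs to (outputs, logabsdet).
    Outside [-tail_bound, tail_bound], "linear" tails are the identity, so
    outputs equal inputs and logabsdet is 0 there.
    """
    inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
    outside_interval_mask = ~inside_interval_mask

    outputs = torch.zeros_like(inputs)
    logabsdet = torch.zeros_like(inputs)  # already 0 for the tails

    # Identity (linear) tails: out-of-bound inputs pass through unchanged.
    outputs[outside_interval_mask] = inputs[outside_interval_mask]

    # The fix: skip the inner call when no element is inside the interval,
    # so inner_fn never receives an empty tensor.
    if torch.any(inside_interval_mask):
        outputs[inside_interval_mask], logabsdet[inside_interval_mask] = inner_fn(
            inputs[inside_interval_mask]
        )
    return outputs, logabsdet


def placeholder_inner_fn(x, left=-4.0, right=4.0):
    """Stand-in for a bounded spline: same domain check as the traceback, then identity."""
    if torch.min(x) < left or torch.max(x) > right:
        raise ValueError("input outside domain")
    return x.clone(), torch.zeros_like(x)


x = torch.tensor([[5.], [-6.]])
outputs, logabsdet = apply_with_linear_tails(x, placeholder_inner_fn, tail_bound=4.0)
print(outputs.flatten().tolist(), logabsdet.flatten().tolist())  # [5.0, -6.0] [0.0, 0.0]
```

Guarding the call, rather than making the inner function tolerate empty input, keeps the bounded spline's domain check strict for the cases where it matters.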

Cheers, Johann

arturbekasov commented 4 years ago

Hi Johann,

Hope you've been well!

Good catch! I thought we had fixed this in nflows, but we hadn't: there was a PR in the nsf repo to fix it, but it never got merged because we were in the middle of factoring things out into nflows (argh!).

You're right: a check like if torch.any(inside_interval_mask): is all that's needed. Ideally, a few simple tests should also be added to check this behaviour.

If you're happy to open a PR, that'd be amazing. No worries otherwise, I can pick it up.

Thanks!

Artur

johannbrehmer commented 4 years ago

Hi Artur,

thanks for the reply (and I hope you're doing well too!).

I just submitted PR #25 including some simple tests.

Cheers, Johann

arturbekasov commented 3 years ago

This has been resolved in #25. Thank you very much, @johannbrehmer!