Closed: johannbrehmer closed this issue 3 years ago.
Hi Johann,
Hope you've been well!
Good catch: in my mind we'd already fixed this in `nflows`, but that's not true. We had a PR in the `nsf` repo to fix this, but it never got merged because we've been in the process of factoring things out into `nflows` (argh!).
You're right, `if torch.any(inside_interval_mask):` is all that's needed. Ideally, a few simple tests should be added to check this behaviour.
If you're happy to open a PR, that'd be amazing. No worries otherwise; I can pick it up.
Thanks!
Artur
Hi Artur,
thanks for the reply (and I hope you're doing well, too!).
I just submitted PR #25 including some simple tests.
Cheers, Johann
This has been resolved in #25. Thank you very much, @johannbrehmer!
Hi all,
very rarely, I get an error when using spline flows in `nflows`.

Issue
When a spline transformation encounters inputs that all lie outside the `(-tail_bound, tail_bound)` range, it throws a `RuntimeError`. For instance, … gives me …
(This was with nflows v0.12 from PyPI.)
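To illustrate the likely mechanism, here is a minimal sketch in plain PyTorch. It assumes, without confirmation from this thread, that the spline code selects the in-range inputs with a boolean mask and then applies a reduction such as `torch.min` to them. The tensor values and `tail_bound = 3.0` are illustrative only, and no `nflows` code is called.

```python
import torch

tail_bound = 3.0
# Illustrative batch in which every input lies outside (-tail_bound, tail_bound).
inputs = torch.tensor([5.0, -7.0, 4.2])
inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)

# Boolean indexing with an all-False mask yields an empty tensor.
selected = inputs[inside_interval_mask]
print(selected.shape)  # torch.Size([0])

# Reductions without an identity element, such as torch.min, fail on empty tensors.
try:
    torch.min(selected)
    raised = False
except RuntimeError as err:
    raised = True
    print(f"RuntimeError: {err}")
```

The exact error message differs between PyTorch versions, but in each case it is a `RuntimeError`.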
Fix
This is simple to fix by adding a check like `if torch.any(inside_interval_mask):` before the calls to the spline functions. I would be happy to open a PR if you are interested.

Cheers, Johann
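A minimal sketch of the proposed guard follows. It is not the `nflows` implementation: `toy_inner_transform` is a hypothetical stand-in for the spline functions (a fixed scaling with a domain check), and `apply_with_tails` is a hypothetical wrapper that mimics identity tails outside the interval. The only part taken from the thread is the `torch.any(inside_interval_mask)` check placed before the inner call.

```python
import math

import torch


def toy_inner_transform(x, left, right):
    """Hypothetical stand-in for a spline function (not the nflows API).

    Its domain check reduces over the input, so it raises a RuntimeError
    if ``x`` is empty.
    """
    if torch.min(x) < left or torch.max(x) > right:
        raise ValueError("input outside domain")
    outputs = 0.5 * x  # placeholder transform y = 0.5 * x
    logabsdet = torch.full_like(x, math.log(0.5))  # log|dy/dx| for y = 0.5 * x
    return outputs, logabsdet


def apply_with_tails(inputs, tail_bound=3.0):
    """Hypothetical wrapper: identity outside the interval, inner transform inside."""
    inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
    outside_interval_mask = ~inside_interval_mask

    outputs = torch.zeros_like(inputs)
    logabsdet = torch.zeros_like(inputs)

    # Identity tails: copy inputs through unchanged; log|det| stays 0.
    outputs[outside_interval_mask] = inputs[outside_interval_mask]

    # The proposed guard: only call the inner transform if at least one
    # element lies inside the interval, so it never receives an empty tensor.
    if torch.any(inside_interval_mask):
        y, lad = toy_inner_transform(
            inputs[inside_interval_mask], -tail_bound, tail_bound
        )
        outputs[inside_interval_mask] = y
        logabsdet[inside_interval_mask] = lad

    return outputs, logabsdet
```

With the guard, `apply_with_tails(torch.tensor([5.0, -7.0, 4.2]))` returns the inputs unchanged with zero log-determinants. Without it, `toy_inner_transform` would receive an empty tensor and `torch.min` would raise a `RuntimeError`.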