bayesiains / nflows

Normalizing flows in PyTorch
MIT License
845 stars, 118 forks

Fix splines crashing with all-tail inputs #25

Closed: johannbrehmer closed this 3 years ago

johannbrehmer commented 4 years ago

This PR fixes issue #23: unconstrained spline transforms crash when they receive input that lies entirely outside the [-tail_bound, tail_bound] interval. The fix adds a single guard of the form if torch.any(inside_interval_mask): to every unconstrained_*_spline() function, so the bounded spline is only evaluated when at least one input falls inside the interval.
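The effect of the guard can be illustrated with a plain-Python stand-in for the tensor code. This is a sketch under stated assumptions, not the nflows implementation: inner_spline and unconstrained_spline are hypothetical names, lists stand in for tensors, the bounded spline is replaced by a placeholder identity map, and its min()/max() domain check is assumed to be the kind of reduction that fails on an empty selection. Tails are treated as linear (identity outside the interval, zero log-determinant).

```python
from typing import List, Tuple


def inner_spline(xs: List[float], left: float, right: float) -> Tuple[List[float], List[float]]:
    """Hypothetical stand-in for the bounded spline on [left, right].

    The domain check uses min()/max(), which raise ValueError on an
    empty list, mimicking a reduction over an empty selection.
    """
    if min(xs) < left or max(xs) > right:
        raise ValueError("input outside domain")
    # Placeholder monotone map: y = x, so log|dy/dx| = 0 everywhere.
    return list(xs), [0.0] * len(xs)


def unconstrained_spline(inputs: List[float], tail_bound: float = 1.0,
                         guard: bool = True) -> Tuple[List[float], List[float]]:
    """Hypothetical unconstrained spline with linear (identity) tails."""
    inside = [-tail_bound <= x <= tail_bound for x in inputs]

    # Tail region: identity map with zero log-determinant.
    outputs = list(inputs)
    logabsdet = [0.0] * len(inputs)

    # The guard: only call the bounded spline if some input is inside.
    # With guard=False, all-tail input sends an empty list to inner_spline.
    if not guard or any(inside):
        xs = [x for x, m in zip(inputs, inside) if m]
        ys, lds = inner_spline(xs, -tail_bound, tail_bound)
        results = iter(zip(ys, lds))
        # Scatter the inside results back into their original positions.
        for i, m in enumerate(inside):
            if m:
                outputs[i], logabsdet[i] = next(results)
    return outputs, logabsdet
```

With the guard, unconstrained_spline([5.0, -7.0], tail_bound=3.0) passes both values through the tails unchanged; calling it with guard=False on the same input raises ValueError inside inner_spline, which is the failure mode the guard avoids.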

It also adds corresponding unit tests in tests/transforms/splines/*_test.py.

johannbrehmer commented 4 years ago

Just checking in, @arturbekasov: is there anything else you'd like me to do for this PR?

arturbekasov commented 3 years ago

LGTM. (Sorry it took me a while to get to this!)