bayesiains / nflows

Normalizing flows in PyTorch
MIT License

float precision check in rational quadratic spline discriminant #71

Open francesco-vaselli opened 1 year ago

francesco-vaselli commented 1 year ago

Hello, and thanks again for the terrific package!

The present pull request addresses an issue that my group and I have been hitting frequently when using model.sample(). When inverting a rational quadratic spline, one must compute $\sqrt{b^{2} - 4ac}$, where $a$, $b$ and $c$ are defined as in Eq. 6 and onward of the original paper (Neural Spline Flows).

The rational quadratic spline code has a check to ensure that this discriminant satisfies $b^{2} - 4ac \geq 0$:

        discriminant = b.pow(2) - 4 * a * c
        assert (discriminant >= 0).all()
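
(For context — paraphrasing the code that follows the assert from memory rather than quoting it exactly — the discriminant goes straight into the numerically stable root of Eq. 29, so without the assert a negative value would instead surface as a NaN from torch.sqrt:)

        # Roughly what follows the assert: the stable root of a*xi^2 + b*xi + c = 0
        # (Eq. 29); torch.sqrt of a negative discriminant would produce NaN here.
        root = (2 * c) / (-b - torch.sqrt(discriminant))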

However, if the two terms are equal up to float precision, $b^2 = 4ac$, their computed difference is not always 0 as expected: round-off sometimes leaves a small value of arbitrary sign (e.g. -3.0518e-5), which trips the AssertionError and crashes the sampling.

To avoid this unwanted behaviour we implemented a simple check on the relative magnitude of the discriminant, which seems to solve the issue effectively in our use case:

        discriminant = b.pow(2) - 4 * a * c

        # Treat discriminants that are tiny relative to b^2 as round-off and set them
        # to exactly zero; the 1e-8 guards against division by zero when b^2 is 0.
        float_precision_mask = (torch.abs(discriminant) / (b.pow(2) + 1e-8)) < 1e-6
        discriminant[float_precision_mask] = 0

        assert (discriminant >= 0).all()

Please let me know if you need anything else on my part. Best regards, Francesco

PLEASE NOTE: on my fork this commit causes two tests to fail with the following errors:

        FAILED tests/transforms/autoregressive_test.py::MaskedPiecewiseQuadraticAutoregressiveTranformTest::test_forward_inverse_are_consistent - AssertionError: The tensors are different!
        FAILED tests/transforms/splines/cubic_test.py::CubicSplineTest::test_forward_inverse_are_consistent - AssertionError: The tensors are different!

However, these failures involve two transforms left untouched by my pull request, so I am not sure whether they are actually related to my modification. Do you have any idea why they may be failing?

arturbekasov commented 1 year ago

P.S. Re. failing tests: unfortunately, the tests in this package can be flaky due to numerical instability. Re-running usually helps.

francesco-vaselli commented 1 year ago

Thanks for reviewing! I have incorporated the requested changes; let me know if there is anything else I can do before we proceed with a merge! Best, Francesco

imurray commented 1 year ago

From a quick look, I have some picky comments that suggest a deeper dive might be useful, but no answers:

The 1e-6 is a few times single-precision floating-point eps and makes sense given how the discriminant is then used. The 1e-8 seems like a hack: there doesn't seem to be an obvious absolute scale that determines what makes b "small" here?

If b is zero then this operation could introduce a new divide-by-zero that wasn't there before. If both b and the discriminant are zero, the correct root is usually 0, not NaN. The other solution to the quadratic would pick that up. I wonder if there is a reason that the paper assumed the solution form in the code (Eq. 29, for 4ac small) was the correct one, whereas apparently you're hitting cases where 4ac is as large as it can be?
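
(If I have the forms right, the two algebraically-equivalent expressions for that root are $\frac{-b + \sqrt{b^2 - 4ac}}{2a}$ and $\frac{2c}{-b - \sqrt{b^2 - 4ac}}$, the latter being the Eq. 29 form: with $b$ and the discriminant both zero, the first gives $0$, provided $a \neq 0$, while the second gives $0/0$, i.e. NaN.)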

I'm not seeing the need to modify small discriminants. It seems safer to only modify negative ones? Perhaps (without having worked out the details) it would be better to first assert that none of them are "too" negative, and then zero out the negative ones.

francesco-vaselli commented 1 year ago

Hello and thank you for taking the time to review the pull request and for your insightful comments!

  1. On Magic Constants: I understand the concern regarding the arbitrary choice of the constants (1e-6 and 1e-8). These were selected primarily based on empirical observations to ensure numerical stability in our specific use-cases. I'm open to exploring a more principled approach. Would you recommend any alternatives?

  2. Division by Zero: Your point about introducing a potential division by zero is valid. My intention was not to override the cases where both b and the discriminant are zero, as the correct root in such a scenario should indeed be zero, not NaN.

  3. Modifying Small Discriminants: The primary reason for modifying small discriminants was to mitigate the effects of floating-point arithmetic errors that can sometimes cause the discriminant to be a tiny negative number, even when b**2 and 4*a*c are theoretically identical. I opted to alter small discriminants to prevent these tiny errors from propagating further into calculations.

    However, I do agree that a more conservative approach of only modifying the negative ones could be more appropriate. Just to check that I understand, would something like this be preferable?

    discriminant = torch.where(discriminant < 0,
                               torch.zeros_like(discriminant), discriminant)
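
    Or, if it reads more clearly, the same zeroing as a single call (purely a stylistic alternative, assuming we go this route):

        discriminant = torch.clamp(discriminant, min=0)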

Would love to hear your thoughts on these points. Once again, thank you for your time and expertise.

Best regards, Francesco

imurray commented 1 year ago

My advice would be to try to understand why you are hitting this issue when others haven't. Numerical problems often hint that you are trying to do something unreasonable or strange, and understanding that can encourage you to improve what you're doing. It's also possible you're in a regime that the code could generally serve better, for example by solving the quadratic differently, but we'd want to understand it.

Regardless, I think we can address this issue without the magic 1e-8 by asserting:

        assert (discriminant >= -1e-6 * b.pow(2)).all()

and then zeroing out any negative values as you suggest. I think we can then remove the later assert.

This new assert allows the discriminant to be "slightly" negative, due to round-off errors of a few eps times the final numbers involved. After that we can zero out those negative values as you suggest, so we don't crash. We'll still get a divide by zero if both b and discriminant are zero. But by leaving all positive discriminants alone, at least we haven't introduced any new crashes.
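
Concretely, something like this (a sketch only — I haven't run it, and clamp is just one way to do the zeroing; the torch.where version above would work equally well):

        discriminant = b.pow(2) - 4 * a * c

        # Tolerate a slightly negative discriminant (a few float32 eps, relative
        # to b^2), but still catch genuinely invalid values.
        assert (discriminant >= -1e-6 * b.pow(2)).all()

        # Zero out the round-off negatives so torch.sqrt doesn't produce NaNs.
        discriminant = torch.clamp(discriminant, min=0)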

I think this should solve your problem? But I haven't tried it, so please do. It's possible you'll need to replace b.pow(2) with something like torch.maximum(b.pow(2), torch.abs(4*a*c)) if you're getting really tiny values where b.pow(2) could be zero, and the other part tiny but non-zero. Erm, but then you'd probably be seeing divide by zero problems too, which we'd need to address, so I doubt it.

I think @arturbekasov was asking you to name the 1e-6. I'm not sure exactly what he wanted: eps_like_tol? Or whether some existing tolerance should be re-used.
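
If a new name is what's wanted, something like this at module level would do (the name is very much up for debate):

        # Relative tolerance (w.r.t. b^2) below which a negative discriminant is
        # treated as floating-point round-off rather than a genuine problem.
        DISCRIMINANT_REL_TOL = 1e-6

        assert (discriminant >= -DISCRIMINANT_REL_TOL * b.pow(2)).all()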