jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

expit outputs 0 when maybe it shouldn't - but it's probably fine #4939

Closed: SamPruden closed this issue 3 years ago

SamPruden commented 3 years ago

The expit function outputs 0 for inputs of fairly modest magnitude. On CPU, jax.scipy.special.expit(-87.337) outputs 0, and on GPU, jax.scipy.special.expit(-88.723) outputs 0. The logistic function being implemented has range 0 < y < 1, so a fully correct implementation would never output exactly 0 or 1. I assume we're just hitting the limits of floating-point precision here.
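A minimal sketch of the underflow, assuming JAX's default float32 (x64 disabled). Very negative inputs such as -200 land well past any float32 threshold, so the result is exactly 0; the natural log of float32's smallest positive normal number is printed for comparison with the CPU threshold quoted above. The exact cut-off on a given backend is not asserted here.

```python
import jax.numpy as jnp
from jax.scipy.special import expit

# float32 inputs (JAX's default dtype unless x64 mode is enabled)
x = jnp.array([-10.0, -50.0, -87.0, -200.0], dtype=jnp.float32)
y = expit(x)
print(y)        # the last entry is exactly 0.0
print(y == 0)

# Smallest positive *normal* float32 and its natural log
tiny = jnp.finfo(jnp.float32).tiny
print(tiny, jnp.log(tiny))  # log(tiny) is about -87.3365
```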

A consequence is that certain mathematical expressions break. For example, dividing by the output of expit is not safe in JAX, although it is safe in pure maths. I just hit exactly that NaN bug in my code, which I obviously should have anticipated given the floating-point asymptote.
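One way around the division, sketched under the assumption that what the calling code really needs is the logarithm of 1 / expit(x): mathematically 1 / expit(x) = 1 + exp(-x), and its log equals -log_sigmoid(x) = softplus(-x), which stays finite in float32 where the naive reciprocal does not.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import expit

x = jnp.float32(-200.0)

# Naive: expit(x) underflows to exactly 0 in float32, so the reciprocal is inf.
naive_recip = 1.0 / expit(x)

# Log-space alternative: log(1 / expit(x)) = -log_sigmoid(x) = softplus(-x).
log_recip = -jax.nn.log_sigmoid(x)

print(naive_recip, log_recip)  # inf, and a finite value of about 200
```

This only helps when the surrounding computation can stay in log space; it does not make the reciprocal itself representable.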

It might be a good idea to somehow clamp the outputs to the strict range 0 < expit(x) < 1, so that the value it approaches is slightly above 0.
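A sketch of what such a clamp could look like, written as a hypothetical user-side helper (`clamped_expit` is not part of JAX). It assumes float32 and clips into [smallest positive normal float32, largest float32 below 1]. That keeps the output strictly inside (0, 1) at the cost of one extra elementwise op.

```python
import jax.numpy as jnp
from jax.scipy.special import expit

def clamped_expit(x):
    """Hypothetical helper: expit clipped into the open interval (0, 1), float32 only."""
    x = jnp.asarray(x, dtype=jnp.float32)
    lo = jnp.finfo(jnp.float32).tiny                          # smallest positive normal float32
    hi = jnp.nextafter(jnp.float32(1.0), jnp.float32(0.0))    # largest float32 strictly below 1
    return jnp.clip(expit(x), lo, hi)

print(clamped_expit(-200.0), clamped_expit(0.0), clamped_expit(200.0))
```

Note that 1 / clamped_expit(-200.0) is roughly 8.5e37, within a factor of about 4 of float32's maximum, so any further arithmetic on it can still overflow.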

The counterargument

This behaviour also exists in SciPy itself, so this is actually a correct and faithful implementation. The performance cost of a clamping operation is probably not justified by the small gain in correctness. PyTorch does the same thing, so this implementation appears to be the consensus.

I think this probably shouldn't be changed; it could even be described as working as intended. However, I thought it was worth raising, since this probably isn't deliberately designed behaviour and may deserve a little thought. Perhaps a documentation note could state the domain over which 0 < y < 1 is guaranteed to hold? I found it surprisingly small, as -88.337 is well within the range of values that might occur in the wild.
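For reference, here is a rough derivation of where float32 runs out, assuming the textbook formula expit(x) = 1 / (1 + e^{-x}). The two thresholds reported above line up with these bounds. That suggests, but does not prove, that the CPU path flushes the subnormal result to zero while the GPU path overflows in e^{-x}.

```latex
\begin{aligned}
\operatorname{expit}(x) &= \frac{1}{1 + e^{-x}} \approx e^{x} && (x \ll 0) \\
e^{x} < 2^{-126} &\iff x < -126 \ln 2 \approx -87.3365 && \text{(below the smallest normal float32)} \\
e^{-x} \gtrsim 2^{128} &\iff x \lesssim -128 \ln 2 \approx -88.7228 && \text{($e^{-x}$ overflows, so } 1/(1+\infty) = 0\text{)}
\end{aligned}
```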

mattjj commented 3 years ago

Thanks for the clear writeup!

It's ambiguous, and there are defensible arguments either way (as you've pointed out), but I tend to come down on the side of saying the current behavior is okay. Dividing by the value of an expit isn't really safe in floating-point arithmetic even when it's nonzero (i.e. even if we clamped its value to something very small). I think pushing this kind of issue onto the user, in the sense of requiring the user to write an explicit jnp.where or jnp.clip in their code rather than hiding it in our implementation, is actually a good thing.
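A sketch of the explicit `jnp.where` guard on the user side, using a hypothetical helper `inverse_expit_or` and a caller-chosen `fallback`. The inner `where` keeps the discarded branch finite. Without it, `jax.grad` would differentiate 1/p at p = 0 and the masked-out inf would still turn the gradient into NaN.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import expit

def inverse_expit_or(x, fallback):
    """Hypothetical guard: 1 / expit(x) where expit(x) > 0, otherwise `fallback`."""
    p = expit(x)
    ok = p > 0
    safe_p = jnp.where(ok, p, 1.0)                 # avoid dividing by 0 even in the unused branch
    return jnp.where(ok, 1.0 / safe_p, fallback)   # select the user's fallback where p underflowed

print(inverse_expit_or(0.0, 1e30))                 # 2.0
print(inverse_expit_or(-200.0, 1e30))              # the fallback
print(jax.grad(inverse_expit_or)(-200.0, 1e30))    # finite gradient, no NaN
```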

As for documentation, given that we tend to inherit NumPy/SciPy docstrings (programmatically), and also that this behavior exists in SciPy, maybe the best way to improve everyone's docstrings would be to make a change in SciPy. WDYT?

SamPruden commented 3 years ago

I agree. In fact, 20 minutes after posting this I felt a bit silly for doing so. Your job here is to re-implement those other libraries, and you're currently doing that correctly. The bug that led to me posting this was my own fault. I had naively moved an expression from Mathematica into JAX and was surprised to find functions that are supposedly the same behaving differently. That was silly of me. It's floating point; of course that's expected! I posted here too rashly.

The argument in favour of putting this into documentation would be to prevent other people making the same silly mistake that I did, but honestly it's probably a rare enough scenario that it's not justified. As you say, the SciPy repo is probably the place to post this if it's worth documenting, although I probably won't bother. I'm happy to consider the issue closed here if you are.

You're more familiar with the ecosystem than I am, so if you think there's real value in posting it to the SciPy repo I'm happy to be encouraged, but I doubt that there is.

mattjj commented 3 years ago

It seems like making a doc update to SciPy would have nonnegative value, but it doesn't seem like it'd be a high priority relative to other stuff we could spend time on!

Thanks for raising this. I think it was worth thinking about. And if it comes up again, we'll have some context and perhaps be able to re-evaluate whether we should take a particular action. For now, though, I'll close this!