SamPruden closed this 3 years ago.
Thanks for the clear writeup!
It's ambiguous, and there are defensible arguments either way (as you've pointed out), but I tend to come down on the side of saying the current behavior is okay. In floating-point arithmetic, dividing by the value of an expit isn't really safe even when it's nonzero (i.e. even if we clamped its value to something very small). I think foisting this kind of issue on the user, in the sense of requiring the user to write an explicit `jnp.where` or `jnp.clip` in their code rather than hiding it in our implementation, is actually a good thing.
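A minimal sketch of what handling this explicitly in user code might look like, assuming the goal is to compute `1 / expit(x)`. `inv_expit_where`, `inv_expit_rewrite`, and the `fallback` parameter are hypothetical names, not JAX APIs. The first function uses the common "double `jnp.where`" pattern; the second swaps in a different technique, rewriting the expression with the identity 1 / expit(x) = 1 + exp(-x) so that no division happens at all.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import expit


def inv_expit_where(x, fallback=jnp.inf):
    """Hypothetical helper: 1 / expit(x), with the division guarded by jnp.where."""
    p = expit(x)
    ok = p > 0
    # "Double where": substitute a harmless denominator wherever p == 0 so that
    # neither the forward pass nor the backward pass evaluates 1 / 0.
    safe_p = jnp.where(ok, p, 1.0)
    return jnp.where(ok, 1.0 / safe_p, fallback)


def inv_expit_rewrite(x):
    """Hypothetical helper: 1 / expit(x) via the identity 1 / expit(x) = 1 + exp(-x).

    There is no division; the result overflows to inf only once exp(-x)
    exceeds the largest float32, i.e. only when the true value does too.
    """
    return 1.0 + jnp.exp(-x)


xs = jnp.array([-1000.0, -87.5, 0.0], dtype=jnp.float32)
print(inv_expit_where(xs))
print(inv_expit_rewrite(xs))
# The gradient stays finite where expit(x) == 0, because the selected branch
# is a constant and the guarded division never sees a zero denominator.
print(jax.grad(lambda x: inv_expit_where(x, fallback=0.0))(-1000.0))
```

With the rewrite, an input such as -87.5 gives a finite result (about 1e38), because only `exp` has to stay in range; the `jnp.where` version depends on whether `expit` itself has already rounded to 0 at that input.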
As for documentation, given that we tend to inherit NumPy/SciPy docstrings (programmatically), and also that this behavior exists in SciPy, maybe the best way to improve everyone's docstrings would be to make a change in SciPy. WDYT?
I agree. In fact, 20 minutes after posting this I felt a bit silly for doing so. Your job here is to re-implement those other libraries, and you're currently doing that correctly. The bug that led to me posting this was my own fault. I had naively moved an expression from Mathematica into JAX and was surprised to find functions that are supposedly the same behaving differently. That was silly of me. It's floating point; of course that's expected! I posted here too rashly.
The argument in favour of putting this into documentation would be to prevent other people making the same silly mistake that I did, but honestly it's probably a rare enough scenario that it's not justified. As you say, the SciPy repo is probably the place to post this if it's worth documenting, although I probably won't bother. I'm happy to consider the issue closed here if you are.
You're more familiar with the ecosystem than I am, so if you think there's real value in posting it to the SciPy repo I'm happy to be encouraged, but I doubt that there is.
It seems like making a doc update to SciPy would have nonnegative value, but it doesn't seem like it'd be a high priority relative to other stuff we could spend time on!
Thanks for raising this. I think it was worth thinking about. And if it comes up again, we'll have some context and perhaps be able to re-evaluate whether we should take a particular action. For now, though, I'll close this!
The `expit` function outputs 0 for negative inputs of only moderate magnitude. On CPU, `jax.scipy.special.expit(-87.337)` outputs 0, and on GPU, `jax.scipy.special.expit(-88.723)` outputs 0. The logistic function this is implementing has range 0 < y < 1, so to be fully correct the implementation should never be able to output 0 or 1. I assume that we're just hitting the limits of floating-point precision here.

A consequence of this is that certain mathematical expressions break. For example, dividing by the output of `expit` is not safe in JAX, although it is safe in pure maths. I just hit that NaN bug in my code. I obviously should have anticipated that given the floating-point asymptote.

It might be a good idea to clamp the outputs to the strict range 0 < expit(x) < 1 somehow, so that the value it approaches is slightly above 0.

The counterargument
This behaviour also exists in base SciPy, so this is actually a correct and faithful implementation. The performance cost of the clamping operation is probably not justified by the small improvement in correctness. PyTorch does the same thing, so this behaviour appears to be the consensus.
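For concreteness, the clamping proposed above might look like the following sketch. `clamped_expit` is a hypothetical wrapper, not part of JAX, and it assumes float32 inputs; it clips the output to the interval [smallest normal float32, largest float32 below 1].

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import expit

# Bounds for a hypothetical float32-only wrapper.
_F32 = jnp.finfo(jnp.float32)
_LO = _F32.tiny  # smallest positive normal float32, about 1.18e-38
_HI = jnp.nextafter(jnp.float32(1.0), jnp.float32(0.0))  # largest float32 below 1


@jax.jit
def clamped_expit(x):
    """Hypothetical: expit with its output forced into the open interval (0, 1)."""
    x = jnp.asarray(x, dtype=jnp.float32)
    # One extra elementwise clip per call: this is the cost the
    # counterargument weighs against the gain in strict correctness.
    return jnp.clip(expit(x), _LO, _HI)


print(clamped_expit(jnp.array([-1000.0, 0.0, 1000.0])))
```

Even with the clamp, `c / clamped_expit(x)` overflows to inf in float32 once `c` exceeds about 4, because the largest float32 (about 3.4e38) times the smallest normal float32 (about 1.18e-38) is about 4. Clamping alone therefore does not make the division safe.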
I think this probably shouldn't be changed, and could perhaps even be described as working as intended. However, I thought it was worth raising, as this probably isn't carefully designed behaviour and may be worth a little thought. Perhaps a documentation note could be added stating the domain over which 0 < y < 1 is guaranteed to hold? I found it surprisingly small, as -88.337 is well within the range of values that might occur in the wild.
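The two cutoffs reported in the issue (-87.337 on CPU, -88.723 on GPU) appear to line up with float32 limits, which would also give the domain such a documentation note could state. The sketch below computes both values from `jnp.finfo`; the names `overflow_x` and `subnormal_x` are hypothetical. The explanation in the comments (GPU result caused by `exp` overflowing, CPU result caused by subnormal results being flushed to zero) is an assumption inferred from the numbers, not something confirmed against XLA's source.

```python
import math

import jax.numpy as jnp

f32 = jnp.finfo(jnp.float32)

# Assumed GPU mechanism: for x < overflow_x, exp(-x) exceeds the largest finite
# float32 and becomes inf, so a formula like 1 / (1 + exp(-x)) is exactly 0.
overflow_x = -math.log(float(f32.max))

# Assumed CPU mechanism: for x < subnormal_x, the true expit(x) (roughly exp(x)
# for very negative x) is below the smallest normal float32; if subnormal
# results are flushed to zero, expit(x) becomes 0 from here on.
subnormal_x = math.log(float(f32.tiny))

print(f"exp(-x) overflows float32 for x < {overflow_x:.3f}")  # about -88.723
print(f"expit(x) is subnormal in float32 for x < {subnormal_x:.3f}")  # about -87.337
```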