matt-graham opened this issue 3 years ago
I think it might be better to think of out-of-bounds indexing as undefined behavior, in which case the docs should be updated to clarify this point. In particular, I would not want to significantly slow down the calculation of gradients of indexing for all use-cases to handle this case.
For reference: the main reason why XLA clamps rather than raising an error is that XLA doesn't raise errors at runtime (only for invalid types).
It's definitely safest to think of out-of-bounds indexing as undefined behavior, but I think there are also algorithmic situations when it's nice to be able to rely on those semantics (maybe one way to find out would be to implement a translation of lax.gather with different OOB behavior and see what tests break). In any case a docs note seems worthwhile.
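As a hypothetical sketch of the kind of algorithmic use this enables (the `forward_diff` function below is purely illustrative, not something from JAX or from this issue): a forward-difference computation can lean on the clamping semantics so that the difference at the final element comes out as zero without any explicit boundary handling.

```python
import jax.numpy as jnp

def forward_diff(x):
    # Gather x[i + 1] for every i; the out-of-range index n is clamped to
    # n - 1, so the last difference is zero with no special-case code.
    idx = jnp.arange(x.shape[0]) + 1
    return x[idx] - x

print(forward_diff(jnp.array([1.0, 4.0, 9.0])))  # [3. 5. 0.]
```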
Thanks for the clarifications @shoyer and @jekbradbury, that all makes sense. The additional documentation note in #5813 resolves the issue from my perspective, so I can either close this issue now, or it could be linked to #5813 so that it is closed when that's merged?
Why is the behavior of gather inconsistent with TensorFlow's gather? TensorFlow's gather returns 0 for out-of-bounds indices. How is TensorFlow's gather translated to XLA?
tl;dr: While JAX clamps indices to the array bounds when evaluating array indexing operations, the same logic does not appear to be applied when back-propagating derivatives through the indexing operations in reverse-mode differentiation.
I came across an error in some code I had written which was only being reflected in the gradients computed by JAX but not in the forward function evaluation. After a bit of digging I discovered this was caused by an off-by-one error when indexing the last element in an array. Due to the clamping in JAX's out-of-bounds array indexing behaviour, the last element in the array was gathered even though the index was for the non-existent (last + 1)th element.
As this is documented behaviour, and there seems to be good reason for the departure from NumPy's behaviour, this is not in itself an issue. What caused me some confusion, however, was that the clamping of the indices does not appear to be accounted for when back-propagating derivatives, meaning that the gradient of the function evaluated using `jax.grad` was inconsistent with, for example, a numerical finite-difference check. The clamping is accounted for, on the other hand, when instead evaluating derivatives using forward-mode differentiation, e.g. using `jax.jacfwd`. The example below illustrates the inconsistency.
From looking at the jaxprs for `func` and `grad(func)`, the issue seems to arise because, while the indexing logic in `gather` clamps out-of-bounds indices, the same does not appear to be the case for the `scatter_add` primitive that seems to be used to back-propagate derivatives through the indexing operation (whereas another `gather` operation is used when instead forward-propagating derivatives).
I am not sure what the expected behaviour should actually be here, and so whether this is actually a useful issue to raise or not. On the one hand, if the out-of-bounds indexing behaviour is a documented part of JAX's API, it seems that derivatives of functions (ab)using this behaviour should be consistent with the indexing semantics used to evaluate the functions themselves. Equally, however, the current inconsistency did help alert me to a bug in my code, and it's not clear that having consistent gradients of buggy code is desirable!