jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Inconsistent reverse-mode derivatives with out-of-bounds array indexing #5760

Open matt-graham opened 3 years ago

matt-graham commented 3 years ago

tl;dr: While JAX clamps indices to the array bounds when evaluating array indexing operations, the same logic does not appear to be applied when back-propagating derivatives through the indexing operations in reverse-mode differentiation.

I came across an error in some code I had written that showed up only in the gradients computed by JAX, not in the forward function evaluation. After a bit of digging I discovered this was caused by an off-by-one error when indexing the last element in an array. Due to the clamping in JAX's out-of-bounds array indexing behaviour, the last element in the array was gathered even though the index referred to the non-existent (last + 1)th element.

As this is documented behaviour, and there seems to be good reason for the departure from NumPy's behaviour, this is not in itself an issue. What caused me some confusion, however, was that the clamping of the indices does not appear to be accounted for when back-propagating derivatives: the gradient computed with jax.grad was inconsistent with, for example, a numerical finite-difference check. The clamping is accounted for when derivatives are instead evaluated with forward-mode differentiation, e.g. using jax.jacfwd. As a simple example, the following code,

import jax.numpy as np
from jax import grad, jacfwd

def func(x):
    # Deliberately index one element past the end of the array; JAX clamps
    # this to the last element rather than raising an error.
    return x[x.shape[0]]

def numerical_grad(f, h=1e-8):
    # Central finite-difference approximation to the gradient of a scalar-valued f.

    def grad_f(x):
        return np.array([
            (f(x + h * e_i) - f(x - h * e_i)) / (2 * h)
            for e_i in np.eye(x.shape[0])
        ])

    return grad_f

x = np.arange(2.)
print(f'func(x) = {func(x)}')
print(f'grad(func)(x) = {grad(func)(x)}')
print(f'jacfwd(func)(x) = {jacfwd(func)(x)}')
print(f'numerical_grad(func)(x) = {numerical_grad(func)(x)}')

produces the output

func(x) = 1.0
grad(func)(x) = [0. 0.]
jacfwd(func)(x) = [0. 1.]
numerical_grad(func)(x) = [0. 1.]

From looking at the jaxprs for func and grad(func), the issue seems to arise because, while the indexing logic in gather clamps out-of-bounds indices, the scatter_add primitive used to back-propagate derivatives through the indexing operation does not appear to apply the same clamping (whereas another gather operation is used when forward-propagating derivatives).
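
For reference, the jaxprs can be printed directly with jax.make_jaxpr; a minimal sketch, building on func, grad, and x from the example above:

from jax import make_jaxpr

# Primal evaluation: the indexing lowers to a gather primitive.
print(make_jaxpr(func)(x))

# Reverse-mode: the transposed indexing appears as a scatter-add primitive,
# which is where the clamping seems to be missing.
print(make_jaxpr(grad(func))(x))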

I am not sure what the expected behaviour should actually be here, and so whether this is actually a useful issue to raise or not. On the one hand, if the out-of-bounds indexing behaviour is a documented part of JAX's API, it seems that derivatives of functions (ab)using this behaviour should be consistent with the indexing semantics used to evaluate the functions themselves. Equally, however, the current inconsistency did help alert me to a bug in my code, and it's not clear that having consistent gradients of buggy code is desirable!

shoyer commented 3 years ago

I think it might be better to think of out-of-bounds indexing as undefined behavior, in which case the docs should be updated to clarify this point. In particular, I would not want to significantly slow down the calculation of gradients of indexing for all use-cases to handle this case.

For reference: the main reason why XLA clamps rather than raising an error is that XLA doesn't raise errors at runtime (only for invalid types).

jekbradbury commented 3 years ago

It's definitely safest to think of out-of-bounds indexing as undefined behavior, but I think there are also algorithmic situations when it's nice to be able to rely on those semantics (maybe one way to find out would be to implement a translation of lax.gather with different OOB behavior and see what tests break). In any case a docs note seems worthwhile.
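
One way to avoid relying on undefined behaviour while keeping the clamping semantics is to clamp the index explicitly, so that the clamp is part of the traced computation and reverse-mode sees it too. A minimal sketch (the clamped_last helper is hypothetical, not part of JAX):

import jax.numpy as np
from jax import grad, jacfwd

def clamped_last(x, i):
    # Clip the index into bounds before gathering so that both the primal
    # evaluation and its derivatives use the same (clamped) index.
    i = np.clip(i, 0, x.shape[0] - 1)
    return x[i]

x = np.arange(2.)
print(grad(lambda x: clamped_last(x, 2))(x))    # [0. 1.]
print(jacfwd(lambda x: clamped_last(x, 2))(x))  # [0. 1.]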

matt-graham commented 3 years ago

Thanks for the clarifications @shoyer and @jekbradbury, that all makes sense. The additional documentation note in #5813 resolves the issue from my perspective, so I can either close this issue now, or it can be linked to #5813 so that it is closed when that's merged?

x10000year commented 2 years ago

Why is the behavior of gather inconsistent with TensorFlow's gather? TensorFlow's gather returns 0 for an out-of-bounds index. How is TensorFlow's gather translated to XLA?
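
For what it's worth, more recent JAX versions let you choose the out-of-bounds policy per indexing operation through the .at property. A minimal sketch (assuming a JAX version whose .at[...].get supports the mode and fill_value arguments) that emulates returning 0 for out-of-bounds indices:

import jax.numpy as np

x = np.arange(2.)
# mode='fill' returns fill_value for out-of-bounds gathers instead of clamping.
print(x.at[2].get(mode='fill', fill_value=0.0))  # 0.0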