jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

grad of function containing skipped singularities is nan #4915

Closed · smiet closed this 3 years ago

smiet commented 3 years ago

I am loving JAX, and coding with it has been amazing.

I am encountering unexpected behavior when I try to calculate the gradient of integrals which contain singularities, even when I replace these singularities with constant values or skip over them by array slicing. The simplest example of this behavior is in the following code:

import jax.numpy as np
from jax import jit, grad
from jax.ops import index_update, index

def sliced_sum_array(a):
    points = 1 / np.linspace(0., a, 100)  # points[0] is 1/0. == inf
    return np.sum(points[1:])             # skip the singular first element

grad_slice = grad(sliced_sum_array, argnums=0)

print("the sum skipping the singularity is {}".format(sliced_sum_array(12.)))
print("the gradient skipping the singularity is {}".format(grad_slice(12.)))

which yields

the sum skipping the singularity is 42.71336364746094
the gradient skipping the singularity is nan

This also occurs if the infinity is replaced with a constant value, as in:

def sliced_sum_array(a):
    points = 1 / np.linspace(0., a, 100)   # points[0] is 1/0. == inf
    points_cleaned = points.at[0].set(0.)  # replace the inf with a constant
    return np.sum(points_cleaned[0:])

In fact, returning any element of an array in which a divide-by-zero was encountered results in a nan in the gradient:

def sliced_sum_array(a):
    points = 1 / np.linspace(0., a, 100)
    return points[-1]

grad_slice = grad(sliced_sum_array, argnums=0)

print("the last element of the array is {}".format(sliced_sum_array(12.)))
print("the gradient of the last element is {}".format(grad_slice(12.)))

output:

the last element of the array is 0.0833333358168602
the gradient of the last element is nan

Here is a more elaborate example in which the singularity occurs in the middle of an array:

import jax.numpy as np
from jax import jit, grad
from jax.ops import index_update, index

def singularity_containing_integral(a, b):
    """
    Approximate the integral of 1/(a*(x-b)) from zero to 10.
    a: scalar
    b: integer
    """
    points = a * (np.linspace(0, 10, 1001) - b)  # calculate the denominator
    inverse = 1 / points  # invert
    inverse_cleaned = index_update(inverse, index[100*b], 0)  # replace the infinity
    return np.sum(inverse_cleaned * (10 / 1001))  # multiply by step size and sum

print("the integral of 1/(5*(3-x)) from 0 to 10 \
        approximately equals {}".format(singularity_containing_integral(5.,
                                                                        3)))

# gradient function w.r.t. first argument:
grad_int = grad(singularity_containing_integral, argnums=0)

print("The gradient does not compute: {}".format(grad_int(5., 3)))

The approximation to the integral is pretty good, but the gradient returns a nan:

the integral of 1/(5*(x-3)) from 0 to 10 approximately equals 0.1691005975008011
The gradient does not compute: nan

mattjj commented 3 years ago

Thanks for the positive words about JAX!

This is a pretty subtle issue, and it's actually one that's fundamental to autodiff in bulk-array programming languages (i.e. languages that look like NumPy, PyTorch, or TF). This comment on #1052 discusses it in some detail. There's also an FAQ entry, but it may be misleading in this case because it makes the issue sound specific to jnp.where (even though the problem often comes up when people try to remove singularities using jnp.where).
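
To make the mechanism concrete, here is a minimal sketch (illustrative values, not taken from the linked thread): in the backward pass the cotangent for the singular element, even when it is zero, gets multiplied by the reciprocal's derivative, which is infinite at zero, and 0 * inf is nan.

import jax.numpy as np
from jax import vjp

x = np.array([0., 1., 2.])
_, f_vjp = vjp(lambda v: 1. / v, x)

# Give the singular first element a zero cotangent, as if it were discarded
# downstream by slicing or an index update.
cotangent = np.array([0., 1., 1.])

# The derivative of 1/v is -1/v**2, which is -inf at v == 0; multiplying that
# by the zero cotangent yields nan, so the nan appears despite the "skip".
print(f_vjp(cotangent))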

The short (over-simplified!) answer is that to prevent nans from contaminating the computation in reverse-mode autodiff, you need to guard against the singularity both before and after the line that would generate it:

import jax.numpy as np
from jax import jit, grad
from jax.ops import index_update, index

def sliced_sum_array(a):
    points = np.linspace(0., a, 100)
    points = points.at[0].set(1.)  # added guard *before* the reciprocal
    points = 1 / points
    points_cleaned = points.at[0].set(0.)  # guard *after* the reciprocal
    return np.sum(points_cleaned[0:])

grad_slice = grad(sliced_sum_array, argnums=0)

print("the sum skipping the singularity is {}".format(sliced_sum_array(12.)))
print("the gradient skipping the singularity is {}".format(grad_slice(12.)))

which yields

the sum skipping the singularity is 42.7133674621582
the gradient skipping the singularity is -3.5594468116760254
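
The same before-and-after guard handles the integral example above, where the singularity sits in the middle of the array. Here is a sketch of that adaptation (using the index-based guard and the .at syntax rather than index_update); it should give a finite gradient:

import jax.numpy as np
from jax import grad

def singularity_containing_integral(a, b):
    x = np.linspace(0, 10, 1001) - b
    x = x.at[100 * b].set(1.)              # guard *before* the division
    inverse = 1 / (a * x)
    inverse = inverse.at[100 * b].set(0.)  # guard *after* the division
    return np.sum(inverse * (10 / 1001))   # multiply by step size and sum

print(grad(singularity_containing_integral, argnums=0)(5., 3))  # finite, no nan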

See the thread on #1052 for a more complete explanation.
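
For reference, the jnp.where-based version of the same guard that the FAQ describes (the "double where" trick) would look roughly like this for the first example; this is a sketch, not the only way to write it:

import jax.numpy as np
from jax import grad

def sliced_sum_array(a):
    x = np.linspace(0., a, 100)
    safe_x = np.where(x == 0., 1., x)            # guard *before* the reciprocal
    points = np.where(x == 0., 0., 1. / safe_x)  # guard *after* the reciprocal
    return np.sum(points)

print(grad(sliced_sum_array)(12.))  # finite gradient, no nan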

WDYT?

smiet commented 3 years ago

Thank you so much! This really helped me understand how the nans propagate and how to avoid them in gradients (I'd read #1052 but not immediately seen the connection).