Thanks for the positive words about JAX!
This is a pretty subtle issue, and it's actually one that's fundamental to autodiff in bulk-array programming languages (i.e. languages that look like NumPy, PyTorch, or TF). This comment on #1052 discusses it in some detail. There's also an FAQ entry, but that entry may be misleading in this case because it makes the issue sound specific to jnp.where (even though this often comes up when people try to remove singularities using jnp.where).
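As a minimal sketch of the jnp.where case (a hypothetical scalar function, not code from this thread): the forward value is finite, yet the gradient at the singular point is nan. Reverse mode still differentiates the discarded 1. / x branch, and it multiplies that branch's infinite derivative by the zero cotangent that jnp.where sends to it.

```python
import jax
import jax.numpy as jnp

def masked_reciprocal(x):
    # Hypothetical example: jnp.where selects 0. at x == 0, but 1. / x is
    # still evaluated (inf at x == 0) and still differentiated.
    return jnp.where(x == 0., 0., 1. / x)

x0 = jnp.array(0.)
value = masked_reciprocal(x0)               # 0.0: the forward pass looks fine
gradient = jax.grad(masked_reciprocal)(x0)  # nan: 0 * (-inf) in the backward pass
print(value, gradient)
```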
The short (over-simplified!) answer is that to prevent nans from contaminating the computation in reverse-mode autodiff, you need to guard against the singularity both before and after the line that would generate it:
import jax.numpy as jnp
from jax import grad

def sliced_sum_array(a):
    points = jnp.linspace(0., a, 100)
    points = points.at[0].set(1.)  # added guard *before* the reciprocal: never compute 1/0
    points = 1 / points
    points_cleaned = points.at[0].set(0.)  # guard *after*: drop the placeholder term from the sum
    return jnp.sum(points_cleaned)

grad_slice = grad(sliced_sum_array, argnums=0)
print("the sum skipping the singularity is {}".format(sliced_sum_array(12.)))
print("the gradient skipping the singularity is {}".format(grad_slice(12.)))
Output:
the sum skipping the singularity is 42.7133674621582
the gradient skipping the singularity is -3.5594468116760254
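To see why the guard has to come before the reciprocal, it can help to look at the reciprocal's backward pass on its own. The sketch below (an illustration built on jax.vjp, not code from this thread) uses a cotangent that is zero at the singular index, which is what overwriting or slicing away that element downstream amounts to. The entry at index 0 still comes back as nan, because the reciprocal's VJP multiplies that zero by the derivative -1/0**2 = -inf; in the full function, that nan then propagates into the gradient with respect to a.

```python
import jax
import jax.numpy as jnp

points = jnp.linspace(0., 12., 5)  # [0., 3., 6., 9., 12.]; points[0] is the singularity

# Backward pass (VJP) of the elementwise reciprocal.
_, reciprocal_vjp = jax.vjp(lambda p: 1. / p, points)

# A cotangent that ignores element 0 entirely, as if that element had been
# overwritten with a constant or sliced away further downstream.
cotangent = jnp.zeros(5).at[1].set(1.)
(points_cotangent,) = reciprocal_vjp(cotangent)

# points_cotangent[0] is nan (0 * -inf); points_cotangent[1] is -1/3**2.
print(points_cotangent)
```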
See the thread on #1052 for a more complete explanation.
WDYT?
Thank you so much! This really helped me understand how the nans propagate and how to avoid them in gradients (I'd read #1052 but hadn't immediately seen the connection).
I am loving JAX, and coding with it has been amazing.
I am encountering unexpected behavior when I try to calculate the gradient of integrals which contain singularities, even when I replace these singularities with constant values or skip over them by array slicing. The simplest example of this behavior is in the following code:
which yields
This also occurs if the infinity is replaced with a constant value, as in:
In fact, returning any element from the array in which a divide-by-zero is encountered results in a nan in the gradient:
output:
Here is a more elaborate example in which the singularity occurs in the middle of an array:
Here the approximation to the integral is pretty good, but the gradient is nan:
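A hypothetical stand-in for the interior-singularity case (not the reporter's code, which is not shown here): a Riemann-style sum of 1/|x| over a symmetric grid on [-a, a] whose middle point is exactly 0. Replacing the singular term only after the division still gives a nan gradient; masking it both before and after the division, using two jnp.where calls, gives a finite one. The integrand, the grid construction, and the names make_points, guard_after_only, and guard_before_and_after are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

N_HALF = 50  # grid points on each side of the singularity (illustrative choice)

def make_points(a):
    # Symmetric grid on [-a, a] built from integer offsets, so the middle
    # point (index N_HALF) is exactly 0. rather than merely close to it.
    offsets = jnp.arange(2 * N_HALF + 1, dtype=jnp.float32) - N_HALF
    return a * offsets / N_HALF

def guard_after_only(a):
    points = make_points(a)
    values = 1. / jnp.abs(points)  # inf at the singular point
    # Replacing the inf afterwards fixes the forward value only.
    return jnp.sum(jnp.where(points == 0., 0., values))

def guard_before_and_after(a):
    points = make_points(a)
    is_singular = points == 0.
    # Guard before: divide by a harmless placeholder at the singular point.
    safe_points = jnp.where(is_singular, 1., points)
    values = 1. / jnp.abs(safe_points)
    # Guard after: drop the placeholder term from the sum.
    return jnp.sum(jnp.where(is_singular, 0., values))

g_bad = jax.grad(guard_after_only)(1.)
g_good = jax.grad(guard_before_and_after)(1.)
print(g_bad)   # nan
print(g_good)  # finite, about -449.92 (= -100 * H_50 for a = 1)
```

Because the mask is computed from the values rather than from a fixed index, the same pattern applies wherever the singular point lands in the array.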