google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

forward mode differentiation issue with scipy.special.xlogy #15709

Closed mcmozer-google closed 1 year ago

mcmozer-google commented 1 year ago

Description

A few days ago, a commit to jax.scipy.special started causing my previously working code to bomb out with NaN errors. I believe the root problem is that you've added custom JVPs to xlogy and xlog1py, and those JVP rules need the same zero-handling safety features as the original functions:

@_wraps(osp_special.xlogy, module='scipy.special')
def xlogy(x: ArrayLike, y: ArrayLike) -> Array:
  # Note: xlogy(0, 0) should return 0 according to the function documentation.
  x, y = promote_args_inexact("xlogy", x, y)
  x_ok = x != 0.
  safe_x = jnp.where(x_ok, x, 1.)
  safe_y = jnp.where(x_ok, y, 1.)
  return jnp.where(x_ok, lax.mul(safe_x, lax.log(safe_y)), jnp.zeros_like(x))

# In the JAX source, xlogy is wrapped with jax.custom_jvp (the decorator is not
# shown in the excerpt above), which is what makes the defjvp call below work.
def _xlogy_jvp(primals, tangents):
  (x, y) = primals
  (x_dot, y_dot) = tangents
  result = xlogy(x, y)
  # The tangent uses lax.log(y) and x / y directly, with none of the zero-safe
  # masking above, so it evaluates to nan when x == y == 0.
  return result, (x_dot * lax.log(y) + y_dot * x / y).astype(result.dtype)
xlogy.defjvp(_xlogy_jvp)

My code hits NaNs in _xlogy_jvp when I call jax.scipy.special.entr(p), which in turn calls xlogy(p, p), and one of the elements of p is zero.
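
For concreteness, a minimal sketch of this failure mode (the helper name total_entropy and the example array are illustrative, not taken from the original report):

# Minimal repro sketch: gradient of the entropy of a distribution with a zero entry.
import jax
import jax.numpy as jnp
from jax.scipy.special import entr

def total_entropy(p):
  # entr(p) == -p * log(p) elementwise, and entr calls xlogy(p, p) internally.
  return jnp.sum(entr(p))

p = jnp.array([0.5, 0.5, 0.0])
print(total_entropy(p))            # finite, since entr(0) is defined as 0
print(jax.grad(total_entropy)(p))  # with the custom jvp above, the zero entry yields nan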

What jax/jaxlib version are you using?

jax v0.4.9

Which accelerator(s) are you using?

TPU

Additional system info

Colab

NVIDIA GPU info

No response

mattjj commented 1 year ago

Thanks for raising this, and the diagnosis!

@jakevdp is it okay to assign to you since you have some context on that commit?

mattjj commented 1 year ago

@mcmozer-google IIUC $(x, y) \mapsto x \log y$ is not differentiable with respect to its second argument at $(0,0)$; indeed I don't think $0$ is in the domain of $\log$. (We may be able to say $x \mapsto x \log x$ has domain $[0, \infty)$ and further that it is differentiable on its domain (not all of $\mathbb{R}$) at $0$ with derivative $-\infty$...)

What answer were you expecting?

I think perhaps before it was incorrectly returning 0, but isn't nan a better value at a point of non-differentiability (over the reals, not just nonnegative reals)?

mcmozer-google commented 1 year ago

You have a completely valid point. I am looking at it from the following perspective, which also seems valid: jax.scipy.special.xlogy defines xlogy(0,y)=0 for all y. Also, jax.scipy.special.entr, which calls xlogy, defines entr(0)=0. Don't these definitions imply that the derivative of xlogy(0,y) is 0 for all y and that the derivative of entr(0) is 0?

From a practical perspective, the previous version of jax did treat these derivatives as 0 (thanks to safe_x and safe_y), and existing code that uses entr and xlogy may break (as happened to me). If it broke just for weird outlier cases, I wouldn't sweat it, but it seems common to need to compute entropy over a distribution where one probability is 0 (i.e., xlogy(0,0)).

mattjj commented 1 year ago

Yes, I think directional derivatives exist (like along $(0, y)$ for $y > 0$, as you say), but there's not one linear map that works for all directions: e.g. the slope is different in the direction $(y, y)$ for $y > 0$. A single linear map is what we'd need for computing Jacobians or for reverse mode, or for forward mode if we're not willing to conditionally switch on the direction. (Since we use the same underlying JVP for Jacobians as well as reverse mode, we generally don't write JVP rules that depend on the direction of the tangent vector, though custom_jvp rules could.)
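
For concreteness, spelling out the two directions mentioned above: along $(0, 1)$ we have $f(0, t) = 0 \cdot \log t = 0$ for $t > 0$, so the directional slope at $(0, 0)$ is $0$; along $(1, 1)$ we have $f(t, t) = t \log t$, whose difference quotient $(t \log t)/t = \log t$ tends to $-\infty$ as $t \to 0^+$. No single linear map reproduces both slopes, so $f(x, y) = x \log y$ has no total derivative at $(0, 0)$.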

I think for the xlogy function we shouldn't define the derivative at $(0, 0)$ in general (i.e. we should keep it as a nan), though users can always apply their own custom_jvp to a wrapper to adopt whatever convention makes sense for their own applications.
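
As a sketch of what such a user-level wrapper could look like (the name safe_xlogy and the zero-tangent-at-(0, 0) convention below are illustrative choices, not an official JAX API):

# Illustrative sketch: keep JAX's xlogy for the primal, but define the tangent
# at x == 0 (including (0, 0)) to be 0 via a custom_jvp on a wrapper.
import jax
import jax.numpy as jnp
from jax.scipy import special

@jax.custom_jvp
def safe_xlogy(x, y):
  return special.xlogy(x, y)

@safe_xlogy.defjvp
def _safe_xlogy_jvp(primals, tangents):
  x, y = primals
  x_dot, y_dot = tangents
  primal_out = safe_xlogy(x, y)
  x_ok = x != 0.
  safe_y = jnp.where(x_ok, y, 1.)  # mask y so neither branch of the where below produces nan
  tangent_out = jnp.where(x_ok,
                          x_dot * jnp.log(safe_y) + y_dot * x / safe_y,
                          jnp.zeros_like(primal_out))
  return primal_out, tangent_out.astype(primal_out.dtype)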

I think the issue here may actually be with jax.scipy.special.entr: while it calls the more general function xlogy, it only needs the function $x \mapsto x \log x$, which is differentiable at $x=0$, though with derivative $-\infty$. Thinking about the plot of the binary entropy function, that seems to give the right answer of $\infty$ for grad(entr)(0.). I can't think of a good reason to make it 0 though...
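
Spelling out the one-sided derivative: with $\mathrm{entr}(x) = -x \log x$ for $x > 0$, we get $\mathrm{entr}'(x) = -(\log x + 1)$, which tends to $+\infty$ as $x \to 0^+$; equivalently, $\frac{d}{dx}(x \log x) = \log x + 1 \to -\infty$.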

So, what if we make grad(entr)(0.) == jnp.inf? Would that have worked in your application, or would it still have blown something up?

mcmozer-google commented 1 year ago

grad(entr)(0)==jnp.inf is the right thing to do. (I convinced myself with a finite difference approximation, but I appreciate your insight in just visualizing the entropy function.) I don't believe this solution would help for my application, where I am computing entr(softmax(logits)), and the argument to entr is exactly 0 only for the unusual case where floating point precision isn't adequate. I don't understand the magic of forward differentiation, but won't reverse mode end up with 0 * jnp.inf and result in nans anyhow? (In my application, I am happy to just zero out and lose the gradients in this unusual case. I just made a custom entr function that does this.) Thanks much for your patience, mattjj.
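
A sketch of that kind of workaround, assuming a custom_jvp wrapper around entr that zeroes the tangent wherever the input is exactly 0 (the names here are illustrative, not the reporter's actual code):

# Illustrative sketch: an entr wrapper whose custom_jvp zeroes the tangent at p == 0.
import jax
import jax.numpy as jnp
from jax.scipy import special

@jax.custom_jvp
def entr_zero_grad_at_zero(p):
  return special.entr(p)

@entr_zero_grad_at_zero.defjvp
def _entr_jvp(primals, tangents):
  (p,), (p_dot,) = primals, tangents
  primal_out = entr_zero_grad_at_zero(p)
  p_ok = p != 0.
  safe_p = jnp.where(p_ok, p, 1.)
  # d/dp [-p log p] = -(log p + 1); define the tangent to be 0 where p == 0.
  tangent_out = jnp.where(p_ok, -(jnp.log(safe_p) + 1.) * p_dot,
                          jnp.zeros_like(primal_out))
  return primal_out, tangent_out.astype(primal_out.dtype)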

jakevdp commented 1 year ago

#15737 makes it so that grad(entr)(0.0) returns infinity, via correct evaluation of the gradient of $f(x) = x\log(x)$.
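
If the change behaves as described, one would expect roughly the following after upgrading (a hedged check, not code taken from that PR):

# Expected behaviour after the fix: the gradient of entr at exactly 0 is +inf, not nan.
import jax
from jax.scipy.special import entr

print(jax.grad(entr)(0.0))  # expected: inf once the fix is in place
print(jax.grad(entr)(0.5))  # -(log(0.5) + 1) ≈ -0.307 away from zero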