jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

numerical issues with 4th deriv of sinc near zero #5094

Open mattjj opened 3 years ago

mattjj commented 3 years ago

Even after #5077 fixed a bug with higher-order derivatives of jnp.sinc at zero, there remains a numerical stability issue with 4th and higher derivatives near zero:

import matplotlib.pyplot as plt
from jax import grad, vmap
import jax.numpy as jnp

xs = jnp.linspace(-5, 5, 1001)
plt.plot(xs, vmap(grad(jnp.sinc))(xs))
plt.plot(xs, vmap(grad(grad(jnp.sinc)))(xs))
plt.plot(xs, vmap(grad(grad(grad(jnp.sinc))))(xs))
plt.plot(xs, vmap(grad(grad(grad(grad(jnp.sinc)))))(xs))
plt.savefig('sinc_4th.png')

[figure sinc_4th.png: 1st through 4th derivatives of jnp.sinc; the 4th derivative is visibly noisy near x = 0]

jakevdp commented 3 years ago

Interestingly, the numerical issues appear unrelated to the Maclaurin trick added in #5077:

import jax.numpy as jnp
from jax import vmap, grad
import matplotlib.pyplot as plt

def sinc(x):
  x = jnp.pi * jnp.asarray(x)
  safe_x = jnp.where(x == 0, 1, x)
  return jnp.where(x == 0, 1, jnp.sin(safe_x) / safe_x)

xs = jnp.linspace(-5, 5, 1001)
plt.plot(xs, vmap(grad(sinc))(xs))
plt.plot(xs, vmap(grad(grad(sinc)))(xs))
plt.plot(xs, vmap(grad(grad(grad(sinc))))(xs))
plt.plot(xs, vmap(grad(grad(grad(grad(sinc)))))(xs))

[figure: the same derivative plots for the hand-written sinc, showing the same noise near zero]

mattjj commented 3 years ago

Oh yeah, I didn't mean to suggest that these instabilities were due to #5077; #5077 can only affect the values at an input of exactly zero, whereas these instabilities are happening near zero. (We only see them with linspace(-5, 5, 1001) but not with linspace(-5, 5, 101); both include zero as an input, but the former also includes smaller nonzero values.)

Rather, I think the issue is that differentiating the sin(pi * x) / (pi * x) expression is itself leading to numerically unstable expressions for small-ish input values. It feels plausible that would happen since we're dividing by x, and so it seems reasonable that replacing the function with some kind of series expansion near zero (rather than just at zero) could have better numerical behavior. But I'm not sure of what expression to use, or how to rig it up nicely...
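One way that could look (a minimal sketch, not a proposal for jnp.sinc itself; the cutoff of 0.1 on pi*x is an arbitrary illustrative choice): select a truncated Maclaurin polynomial whenever |pi*x| is small, and keep the safe-where trick on the sin branch so that branch stays NaN-free under repeated differentiation.

```python
import jax.numpy as jnp
from jax import grad

def sinc_stable(x, cutoff=0.1):
    # Hypothetical sketch: truncated Maclaurin series for |pi*x| < cutoff,
    # sin(pi*x)/(pi*x) elsewhere. cutoff = 0.1 is an arbitrary choice.
    y = jnp.pi * jnp.asarray(x)
    near_zero = jnp.abs(y) < cutoff
    y2 = y * y
    # sinc = 1 - y^2/6 + y^4/120 - y^6/5040 + ..., evaluated in Horner form.
    series = 1. - y2 / 6. * (1. - y2 / 20. * (1. - y2 / 42.))
    safe_y = jnp.where(near_zero, 1.0, y)  # avoid 0/0 in the unselected branch
    return jnp.where(near_zero, series, jnp.sin(safe_y) / safe_y)
```

The branch actually selected near zero is a plain polynomial, so each grad just differentiates the polynomial there; the derivatives of the sin(safe_y)/safe_y branch are still computed but discarded by the outer where.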

gnool commented 3 years ago

I wonder how one would solve such a numerical instability issue even outside the context of auto-differentiation. For example, the analytical form of the 4th-order derivative of sinc(x) also suffers from numerical instability when x is small (perhaps when x < 0.001 for double precision), due to the large exponent on x in the denominator.

Outside the context of auto-differentiation, if a smooth function is required for small x, I would imagine writing a function that first checks whether the input x could trigger numerical issues (I'm not sure such a heuristic even exists, but assuming it does), and if so, replaces the original expression with its Taylor expansion about x = 0. It should be possible to determine where to truncate the series based on machine epsilon, and since sinc(x) has a relatively simple expansion, the coefficients are readily available too.
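Picking that truncation point can indeed be mechanized: add terms until the first neglected Maclaurin term of sinc, (pi*x0)**(2n)/(2n+1)!, falls below machine epsilon at the switch-over point x0. A sketch under that assumption (nothing here is JAX API; the function name is made up):

```python
import math

def sinc_series_terms(x0, eps=2.0 ** -52):
    # Smallest n such that the first neglected Maclaurin term of
    # sinc, (pi*x0)**(2n) / (2n+1)!, is below eps for all |x| <= x0.
    y = math.pi * x0
    n = 1
    while y ** (2 * n) / math.factorial(2 * n + 1) >= eps:
        n += 1
    return n

print(sinc_series_terms(0.1))   # a handful of terms suffice near zero
print(sinc_series_terms(0.01))  # fewer still for a tighter cutoff
```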

Very eager to see what tricks you guys will eventually pull for JAX!

(Never mind my comment, just saw that @mattjj has said something similar earlier)

GJBoth commented 3 years ago

> Rather, I think the issue is that differentiating the sin(pi * x) / (pi * x) expression is itself leading to numerically unstable expressions for small-ish input values. It feels plausible that would happen since we're dividing by x, and so it seems reasonable that replacing the function with some kind of series expansion near zero (rather than just at zero) could have better numerical behavior. But I'm not sure of what expression to use, or how to rig it up nicely...

I checked with Mathematica: evaluating the sinc derivatives at zero leads to a 1/0 error, and I have to take the limit to get a value (which incidentally is 1/5, which seems very far off from what you're plotting?), so it seems you're right.
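For the record, the gap between 1/5 and the plotted scale is just normalization: that limit is for the unnormalized sin(x)/x, while jnp.sinc is the normalized sin(pi*x)/(pi*x), so the chain rule multiplies the 4th derivative at zero by pi**4, giving pi**4/5 ≈ 19.48. Since #5077 fixed the values at exactly zero, evaluating there should reproduce that limit:

```python
import math
import jax.numpy as jnp
from jax import grad

# 4th derivative of the normalized sinc at exactly x = 0; limit is pi**4/5.
f4 = grad(grad(grad(grad(jnp.sinc))))
print(float(f4(0.0)), math.pi ** 4 / 5)
```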