jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/

Hessian of linalg.norm #10965

Closed · johnjmolina closed this issue 2 years ago

johnjmolina commented 2 years ago

Hello,

I'd like to report a possible bug I encountered when computing Hessians of functions involving jax.numpy.linalg.norm. If the input is nonzero the results are fine; however, if the Hessian is evaluated at zero it returns NaN.

Thanks in advance!

import jax
T = lambda v : jax.numpy.linalg.norm(v)
jax.hessian(T)(0.0)

# returns DeviceArray(nan, dtype=float64, weak_type=True)
nicholasjng commented 2 years ago

Hey! Not a maintainer, but mathematically, that result actually makes sense: the Euclidean norm is not differentiable at the origin, so the gradient does not exist there, and hence neither does the Hessian.

For a proof of this, see this thread on SO (among others, this seems to be a frequently asked question): https://math.stackexchange.com/questions/1175277/why-isnt-the-l-2-norm-differentiable-at-x-0

An immediate corollary of that discussion is that the direction along which the sequence $h \rightarrow 0$ approaches the origin determines the value of the resulting derivative (in math, this is usually called a "directional derivative"). But from a library developer's point of view, if you were to pick one of those limits, say, to avoid NaN values, which one would it be? Making that choice would create more problems than it solves (e.g., how would people without a math background debug unexpected values?), so it is (imo) good behavior that the result is NaN.
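To make this concrete in one dimension (a small worked example, not from the original discussion): for $f(x) = \lVert x \rVert = |x|$, the two one-sided difference quotients at the origin disagree,

$$\lim_{h \to 0^+} \frac{|0+h| - |0|}{h} = 1, \qquad \lim_{h \to 0^-} \frac{|0+h| - |0|}{h} = -1,$$

so no single value can honestly be reported as "the" derivative there.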

That being said, if you want to fix a desired zero-behavior in your function, I think your best bet is to define a "piecewise" gradient using jnp.where, conditioned on the input being 0. Mathematically, this changes the gradient only on a null set, which preserves things like Lebesgue integrals, but it conveniently eliminates numerical issues like this one. I think there is a section in the docs on that, as well as some previous discussions, since this problem comes up frequently, albeit in slightly different formulations (I think @jakevdp is quite the expert on it).
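A minimal sketch of that jnp.where pattern (the name safe_norm and the convention of returning 0 at the origin are my own choices, not something JAX prescribes):

import jax
import jax.numpy as jnp

def safe_norm(x):
    # Detect an all-zero input, swap in a dummy nonzero value so the
    # backward pass never differentiates the norm at exactly 0, then
    # mask the primal result back to 0.
    is_zero = jnp.all(x == 0)
    safe_x = jnp.where(is_zero, jnp.ones_like(x), x)
    return jnp.where(is_zero, 0.0, jnp.linalg.norm(safe_x))

jax.hessian(safe_norm)(0.0)  # 0.0 instead of nan

The inner jnp.where is the important one: it ensures the norm is never differentiated at exactly zero, so no NaN is produced in the first place. Simply masking the output with the outer where alone is not enough, because 0 * nan is still nan in the backward pass.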

Let me know if anything is unclear, then I am happy to clarify :)

johnjmolina commented 2 years ago

Thanks! Yes, of course, I should have noticed this before posting. I'm not sure if it's really a bug, but I still get NaNs when computing the Hessian of np.linalg.norm(v)**2 (at v=0), even though that one should be mathematically well-defined.

nicholasjng commented 2 years ago

That is also expected. JAX sees your function not as a sum of squares, which is what the squared Euclidean norm is on paper, but rather as "compute the norm via jnp.linalg.norm, then square the result". Stepping through that with the chain rule, you end up evaluating the gradient of jnp.linalg.norm at zero again; even though the outer factor $2\lVert v \rVert$ vanishes there, $0 \cdot \mathrm{NaN}$ is still NaN, so you're back at the original problem.
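A quick check of that chain-rule explanation (my snippet, not from the thread):

import jax
import jax.numpy as jnp

# The intermediate gradient of the norm at 0 is already nan, and the
# chain rule's 0 * nan stays nan.
jax.grad(jnp.linalg.norm)(0.0)                    # nan
jax.grad(lambda v: jnp.linalg.norm(v) ** 2)(0.0)  # nan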

If you know a priori that you'll be computing the squared Euclidean norm, you could formulate it explicitly as a sum of squares:

import jax.numpy as jnp

def euclnormsquared(x):
    return jnp.sum(x * x)

This has none of the kinks of the norm described above, and it should be quite fast, since it is just one elementwise multiply and one (vectorizable) reduction; it even parallelizes trivially. Let me know if that helps.
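As a quick sanity check (this snippet is mine, not from the thread), the Hessian of that formulation is finite at the origin:

import jax
import jax.numpy as jnp

def euclnormsquared(x):
    return jnp.sum(x * x)

jax.hessian(euclnormsquared)(0.0)           # 2.0
jax.hessian(euclnormsquared)(jnp.zeros(3))  # 2 * 3x3 identity matrix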

johnjmolina commented 2 years ago

Thanks again for the prompt reply; yes, this is what I ended up doing.