johnjmolina closed this issue 2 years ago
Hey! Not a maintainer, but mathematically that result actually makes sense: the Euclidean norm is not differentiable at the origin, so the gradient does not exist there, and hence neither does the Hessian.
For a proof of this, see this thread on Math StackExchange (among others; this seems to be a frequently asked question): https://math.stackexchange.com/questions/1175277/why-isnt-the-l-2-norm-differentiable-at-x-0
An immediate corollary of that discussion is that the value you get for the partial derivative depends on the direction along which you take the limit $h \rightarrow 0$ (in math, this is usually called a "directional derivative"). But from a library developer's point of view, if you were to pick one particular limit here, say, out of a desire to avoid NaN values, which one would it be? Silently picking one poses more problems than it solves (e.g., how would people without a math background debug unexpected values?), so it is (imo) good behavior that the result is NaN.
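To make this concrete at the origin: the $i$-th partial derivative of the Euclidean norm would be

$$
\lim_{h \to 0} \frac{\lVert h e_i \rVert - \lVert 0 \rVert}{h} = \lim_{h \to 0} \frac{|h|}{h},
$$

which equals $+1$ for $h \to 0^+$ and $-1$ for $h \to 0^-$, so the two-sided limit (and hence the partial derivative) does not exist; only direction-dependent one-sided limits do.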
That being said, if you want to fix a desired zero-behavior in your function, I think your best bet might be defining a "piecewise" gradient using jnp.where with a conditional on the input being 0. Mathematically, this changes the gradient only on a null set, which preserves things like Lebesgue integrals, but it can conveniently eliminate numerical issues like this (see the sketch below). I think there is a section in the docs on that, as well as some previous discussions, since this problem comes up frequently, albeit in slightly different formulations (I think @jakevdp is quite the expert on it).
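For illustration, here is a minimal sketch of that pattern (the helper name safe_norm and the zero return value at the origin are my choices, not something prescribed by JAX):

```python
import jax
import jax.numpy as jnp

def safe_norm(x):
    # Guard the input: where x is exactly zero, substitute a harmless dummy value,
    # so jnp.linalg.norm (and its gradient) is never evaluated at the singular point.
    is_zero = jnp.all(x == 0.0)
    x_safe = jnp.where(is_zero, jnp.ones_like(x), x)
    # Compute the norm on the safe input, then patch the output back to the value
    # we actually want at the origin. Both branches now have finite gradients.
    return jnp.where(is_zero, 0.0, jnp.linalg.norm(x_safe))

print(jax.grad(safe_norm)(jnp.zeros(3)))  # [0. 0. 0.] instead of [nan nan nan]
```

Because the norm is never evaluated at the singular input, higher derivatives (e.g. the Hessian) should come out NaN-free as well.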
Let me know if anything is unclear; I'm happy to clarify :)
Thanks! Yes, of course, I should have noticed this before posting. I'm not sure if it's really a bug, but I still get NaNs when computing the Hessian of np.linalg.norm(v)**2 (at v=0), even though this one should be well-defined mathematically.
That is also expected. JAX sees your function not as a sum of squares, like the squared Euclidean norm actually is on paper, but really more as "compute the norm via jnp.linalg.norm and square the result". If you step through this with the gradient, applying the chain rule, you end up evaluating the gradient of jnp.linalg.norm again, and you're back at the original problem.
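Concretely, something like this (a small illustration of the point, not code from the original report):

```python
import jax
import jax.numpy as jnp

# Squared norm written by squaring jnp.linalg.norm: the backward pass still has to
# differentiate the norm itself, which produces a 0/0 at the origin.
f = lambda v: jnp.linalg.norm(v) ** 2

print(jax.hessian(f)(jnp.zeros(3)))  # an all-NaN 3x3 matrix
```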
If you know a priori that you'll be computing the squared Euclidean norm, then you could formulate it explicitly as a sum of squares:
def euclnormsquared(x):
    return jnp.sum(x * x)
This has none of the kinks of the norm described above (quick check below), and should be quite fast, since it is just one elementwise multiply and one (vectorizable) sum; it even parallelizes trivially. Let me know if that helps.
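As a quick check at the origin (a self-contained version of the snippet above):

```python
import jax
import jax.numpy as jnp

def euclnormsquared(x):
    return jnp.sum(x * x)

# The Hessian of a sum of squares is 2 * identity everywhere, including at x = 0,
# so no NaNs appear at the origin.
print(jax.hessian(euclnormsquared)(jnp.zeros(3)))
# [[2. 0. 0.]
#  [0. 2. 0.]
#  [0. 0. 2.]]
```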
Thanks again for the prompt reply; yes, this is what I ended up doing.
Hello,
I'd like to report a possible bug I encountered when computing Hessians of functions involving jax.numpy.linalg.norm. If the input is different from zero, the results are fine; however, if the Hessian is evaluated at zero, it returns NaNs.
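A minimal reproduction along these lines (an illustrative sketch, not the exact code from my script):

```python
import jax
import jax.numpy as jnp

f = lambda v: jnp.linalg.norm(v)

print(jax.hessian(f)(jnp.ones(3)))   # finite entries away from zero
print(jax.hessian(f)(jnp.zeros(3)))  # all NaN when evaluated at zero
```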
Thanks in advance!