
Autodiff bug with jnp.linalg.norm #6484

Closed by harwiltz 3 years ago

harwiltz commented 3 years ago

I'm not sure if there's an elegant way to fix this. I noticed that when I differentiate jnp.linalg.norm at an input of 0, the gradient is nan. I believe this is due to the vanishing denominator in the derivative (the gradient of the norm is x / ||x||). However, the numerator vanishes as well, so in practice the gradient should be 0. For example:

import jax
import jax.numpy as jnp

x = jnp.array([1., 2., 3.])
jax.grad(jnp.linalg.norm)(x - x) #[nan nan nan]

l2_dist_sq = lambda x, y: jnp.linalg.norm(x - y) ** 2
jax.grad(l2_dist_sq)(x, x) #[nan nan nan]

# This should be the same as l2_dist_sq
oracle_l2_dist_sq = lambda x, y: (x - y).dot(x - y)
jax.grad(oracle_l2_dist_sq)(x, x) #[0 0 0]

Of course, I highly doubt jnp.linalg.norm is the only instance of this "bug", but it is a very commonly used function, and it took me a fairly long time to deduce that my nans were coming from this gradient. I suppose a simple fix in this case would be something like:

def norm(x, eps=1e-9):
    norm_sq = x.dot(x)
    # guard the sqrt argument: a bare jnp.where would still let the nan
    # gradient from the untaken sqrt branch through, so clamp the input too
    safe_norm_sq = jnp.where(norm_sq < eps, 1.0, norm_sq)
    return jnp.where(norm_sq < eps, 0.0, jnp.sqrt(safe_norm_sq))

Having said that, perhaps there's a more conventional way of achieving this?

mattjj commented 3 years ago

Thanks for the question!

Actually, this is working as intended. The mathematical function represented by jnp.linalg.norm is not differentiable at zero, just like jnp.abs. In fact, jnp.linalg.norm applied to a real scalar is the same function as jnp.abs! As a result, differentiating it at zero should produce nans as an error value.
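
To see concretely where the nan comes from, here's a quick sketch (away from zero the gradient of the norm is x / ||x||, which becomes 0 / 0 at the origin):

import jax
import jax.numpy as jnp

# away from zero the gradient of the norm is x / ||x||
print(jax.grad(jnp.linalg.norm)(jnp.array([3., 4.])))  # [0.6 0.8]
# at zero that formula is 0 / 0, so autodiff produces nan as an error value
print(jax.grad(jnp.linalg.norm)(jnp.zeros(3)))         # [nan nan nan]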

The squared norm is indeed a different function, which is mathematically differentiable at zero. Yet, as you've shown, l2_dist_sq is not correctly autodiff-able at zero, while oracle_l2_dist_sq is. That's also expected, and it's part of a more general phenomenon: different programming-language denotations of the same mathematical function might have different derivatives as programs. Such cases are failures of autodiff, but they're unavoidable for a NumPy-like language and an autodiff system which works 'locally' on programs (i.e. by structural induction). (It's possible to make programming languages in which autodiff always works, but only by making problematic programs impossible to write.)
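
Here's an even smaller scalar illustration of that phenomenon, just as a sketch (the exact value JAX returns for the derivative of jnp.abs at zero is a convention, but it is not nan):

import jax
import jax.numpy as jnp

f = lambda x: jnp.sqrt(x ** 2)  # |x| written via square and square root
g = jnp.abs                     # |x| written via the built-in abs

print(jax.grad(f)(0.0))  # nan: the sqrt derivative blows up and 0 * inf = nan
print(jax.grad(g)(0.0))  # a finite value, since abs has its own derivative rule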

Even though l2_dist_sq and oracle_l2_dist_sq represent the same mathematical function on real vectors, their autodiff-computed gradients differ. Still, I consider this relatively good behavior, because you don't get silently incorrect numbers; instead, you get a clear error value.

One issue with setting a numerical threshold, as in the norm function you wrote, is that higher-order derivatives come out wrong. That is, the second derivative of the true norm-squared function is nonzero at zero, while the second derivative of the thresholded norm (and of anything built from it) is zero there, since the thresholded function is constant in a neighborhood of zero.
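
To make that concrete, here's a small self-contained sketch; truncated_sq_norm below is a made-up example of the same kind of thresholding, applied directly to the squared norm to keep things simple:

import jax
import jax.numpy as jnp

eps = 1e-9

def truncated_sq_norm(x):
    # zero inside a small ball around the origin, x.dot(x) outside it
    s = x.dot(x)
    return jnp.where(s < eps, 0.0, s)

x0 = jnp.zeros(3)
print(jax.hessian(lambda x: x.dot(x))(x0))  # 2 * identity: the true second derivative
print(jax.hessian(truncated_sq_norm)(x0))   # all zeros: the thresholding erases the curvature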

Luckily, if you want to opt into that behavior, you can always write such a truncated-near-zero norm function yourself (as you did!). There are also other tools for controlling JAX's autodiff behavior to get exactly what you want. But we don't want to build that behavior into JAX.
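
For example, one such tool is jax.custom_jvp, which lets you keep the values of jnp.linalg.norm while declaring the derivative you want at zero. This is just an illustrative sketch with a made-up name (safe_norm), not an official recipe:

import jax
import jax.numpy as jnp

@jax.custom_jvp
def safe_norm(x):
    return jnp.linalg.norm(x)

@safe_norm.defjvp
def safe_norm_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    n = jnp.linalg.norm(x)
    # use x / ||x|| away from zero, and declare the derivative to be zero at zero
    d = jnp.where(n > 0, x / jnp.where(n > 0, n, 1.0), 0.0)
    return n, jnp.vdot(d, t)

x = jnp.array([1., 2., 3.])
print(jax.grad(safe_norm)(x - x))  # [0. 0. 0.]
print(jax.grad(safe_norm)(x))      # x / ||x||

This keeps the primal values identical to jnp.linalg.norm and makes the choice of subgradient at zero explicit, rather than baking it into the library.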

WDYT?

harwiltz commented 3 years ago

Thanks for this super detailed and interesting response! Admittedly I did not consider what would happen with higher order derivatives. Anyhow, that last link you posted seems to answer my remaining questions, and I agree with your conclusions.