harwiltz closed this issue 3 years ago.
Thanks for the question!
Actually, this is working as intended. The mathematical function represented by `jnp.linalg.norm` is not differentiable at zero, just like `jnp.abs`. In fact, `jnp.linalg.norm` applied to a real scalar is the same function as `jnp.abs`! As a result, differentiating it at zero should produce `nan`s as an error value.
The squared norm is indeed a different function, which is mathematically differentiable at zero. Yet, as you've shown, `l2_dist_sq` is not correctly autodiff-able at zero, while `oracle_l2_dist_sq` is. That's also expected, and it's part of a more general phenomenon: different programming-language denotations of the same mathematical function might have different derivatives as programs. Such cases are failures of autodiff, but they're unavoidable for a NumPy-like language and an autodiff system which works 'locally' on programs (i.e. by structural induction). (It's possible to make programming languages in which autodiff always works, but only by making problematic programs impossible to write.)
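For concreteness, here's a minimal sketch of the kind of pair being discussed (the exact definitions of `l2_dist_sq` and `oracle_l2_dist_sq` in the original report may differ): one writes the squared distance through `jnp.linalg.norm`, the other avoids the square root entirely, and only the latter has a usable gradient at zero.

```python
import jax
import jax.numpy as jnp

def l2_dist_sq(x, y):
    # Squared distance written via the norm; the intermediate sqrt makes
    # the gradient nan when x == y.
    return jnp.linalg.norm(x - y) ** 2

def oracle_l2_dist_sq(x, y):
    # The same mathematical function, written without the sqrt.
    return jnp.sum((x - y) ** 2)

x = jnp.zeros(3)
print(jax.grad(l2_dist_sq)(x, x))         # [nan nan nan]
print(jax.grad(oracle_l2_dist_sq)(x, x))  # [0. 0. 0.]
```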
Even though `l2_dist_sq` and `oracle_l2_dist_sq` represent the same mathematical function on real vectors, their autodiff-computed gradients differ. I still consider this relatively good behavior, because you don't get silently incorrect numbers. Instead, you get a clear error value.
One issue with setting a numerical threshold, as in the `norm` function you wrote, is that higher-order derivatives are wrong. That is, the second derivative of the norm-squared function is nonzero at zero, while the second derivative you get by squaring that thresholded `norm` is zero there.
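For instance (a sketch with an assumed `jnp.where`-based truncation; the actual function in this thread may be written differently), the Hessian of the true squared norm at zero is `2 * I`, but squaring a truncated norm gives an all-zeros Hessian there:

```python
import jax
import jax.numpy as jnp

def truncated_norm(x, eps=1e-6):
    # Hypothetical thresholded norm: exactly zero (and flat) near the origin.
    sq = jnp.sum(x ** 2)
    safe_sq = jnp.where(sq > eps, sq, 1.0)  # keep the sqrt away from zero
    return jnp.where(sq > eps, jnp.sqrt(safe_sq), 0.0)

x0 = jnp.zeros(3)
print(jax.hessian(lambda x: jnp.sum(x ** 2))(x0))         # 2 * identity
print(jax.hessian(lambda x: truncated_norm(x) ** 2)(x0))  # all zeros
```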
Luckily, if you want to opt into that behavior, you can always write such a truncated-near-zero `norm` function yourself (as you did!). There are also other tools for controlling JAX's autodiff behavior to get exactly what you want. But we don't want to build that behavior into JAX itself.
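One such tool is `jax.custom_jvp`. Here's a sketch (my own illustration, not something prescribed in this thread) of a norm whose derivative at the origin is declared to be zero:

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def norm_with_zero_grad_at_origin(x):
    return jnp.linalg.norm(x)

@norm_with_zero_grad_at_origin.defjvp
def _norm_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    n = jnp.linalg.norm(x)
    # Use x / ||x|| away from the origin, and 0 at the origin by convention.
    direction = jnp.where(n > 0, x / jnp.where(n > 0, n, 1.0), 0.0)
    return n, jnp.vdot(direction, x_dot)

print(jax.grad(norm_with_zero_grad_at_origin)(jnp.zeros(3)))             # [0. 0. 0.]
print(jax.grad(norm_with_zero_grad_at_origin)(jnp.array([3.0, 4.0])))    # [0.6 0.8]
```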
WDYT?
Thanks for this super detailed and interesting response! Admittedly I did not consider what would happen with higher order derivatives. Anyhow, that last link you posted seems to answer my remaining questions, and I agree with your conclusions.
I'm not sure if there's an elegant way to fix this. I noticed that when I differentiate `jnp.linalg.norm` at an input of 0, the gradient is `nan`. I believe this is due to the vanishing denominator in the derivative. However, the numerator compensates for this in reality, and the gradient should be 0. For example:
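(A minimal reproduction along these lines; this is a sketch, not necessarily the exact snippet from the original report.)

```python
import jax
import jax.numpy as jnp

# Differentiating the norm at the origin produces nan instead of 0.
print(jax.grad(jnp.linalg.norm)(jnp.zeros(3)))           # [nan nan nan]
print(jax.grad(jnp.linalg.norm)(jnp.array([3.0, 4.0])))  # [0.6 0.8] away from 0
```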
Of course, I highly doubt `jnp.linalg.norm` is the only instance of this "bug", but I suspect it is a very commonly used function. It took me fairly long to deduce that my `nan`s were coming from this gradient. I suppose a simple fix in this case would be something like the thresholded norm sketched below. Having said that, perhaps there's a more conventional way of achieving this?
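(A sketch of one possible such fix; the exact code from the report may differ. It uses the usual double-`jnp.where` trick so the `sqrt` is never differentiated at zero.)

```python
import jax
import jax.numpy as jnp

def safe_norm(x, eps=1e-12):
    # Return 0 (with a 0 gradient) when the squared norm is below eps,
    # and the ordinary norm otherwise. The inner where keeps the sqrt
    # away from zero so no nan shows up in the backward pass.
    sq = jnp.sum(x ** 2)
    safe_sq = jnp.where(sq > eps, sq, 1.0)
    return jnp.where(sq > eps, jnp.sqrt(safe_sq), 0.0)

print(jax.grad(safe_norm)(jnp.zeros(3)))           # [0. 0. 0.]
print(jax.grad(safe_norm)(jnp.array([3.0, 4.0])))  # [0.6 0.8]
```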