agudym opened this issue 8 months ago (status: Open)
Hi - thanks for the question! It looks like you're hitting a variant of this issue: https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where
Expressed in terms of `where` rather than masking, your code is roughly equivalent to this:
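import jax.numpy as jnp

def f(x, pre_filter=False):
    delta = jnp.array((0, 1)) * x
    if pre_filter:
        # Filter first, then take the square root of the filtered values
        return jnp.sum(jnp.where(delta > 0, delta, 0) ** 0.5)
    else:
        # Take the square root of every entry first, then filter
        return jnp.sum(jnp.where(delta > 0, delta**0.5, 0))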
In the first case, you only apply the square root to values that have already been filtered. In the second case, you also apply it to the non-positive entries (here `delta[0]`, which is always 0): the square root is NaN for negative inputs and its derivative is infinite at 0, so those entries generate NaNs in the gradient. Outside autodiff, the filter works the same either way. But inside autodiff, the autodiff rule must consider the contributions of both filtered and non-filtered values to the gradient. You can read a more complete description of this at the link above.
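A minimal sketch of the mechanism, assuming a recent JAX install: the derivative of the square root at 0 is infinite, and `jnp.where` hands the branch it did not select a cotangent of exactly 0, so that branch's VJP ends up computing `0 * inf`, which is NaN in floating point.

```python
import jax
import jax.numpy as jnp

# The derivative of sqrt at 0 is +inf.
d_sqrt_at_zero = jax.grad(jnp.sqrt)(0.0)

# jnp.where sends a zero cotangent into the branch it did not select.
# Feeding that zero into sqrt's VJP at 0 multiplies 0 by inf.
out, sqrt_vjp = jax.vjp(jnp.sqrt, 0.0)
(masked_ct,) = sqrt_vjp(jnp.zeros_like(out))

print(d_sqrt_at_zero)  # inf
print(masked_ct)       # nan
```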
Does that help answer your question?
Thanks so much for the swift response!
That's definitely a feature, not a bug, and it's already described in detail (for TensorFlow too), my bad! The most relevant comprehensive info (just refreshing the topic :) ): https://github.com/google/jax/issues/1052#issuecomment-514083352 https://github.com/tensorflow/probability/blob/main/discussion/where-nan.pdf
Summarizing all of this, my understanding of the problem is the following:
import jax
import jax.numpy as jnp

Const = 1.0  # some stub (any constant value works)

def f(x):  # Our "intuition" on how it SHOULD work
    if x > 0:
        return x**0.5  # sqrt(x)
    else:
        return Const  # some stub

jax.grad(f)(0.0)
# 0 - and it really works CORRECTLY, i.e. we get "d Const / d x = 0" at "0"
However, with `jnp.where` (or similar), we get "a surprise":
def f(x):  # x.size = 1
    return jnp.where(x > 0, x**0.5, Const)

jax.grad(f)(0.0)
# nan - booom!
So the workaround is to avoid NaNs in ANY "execution branch", even one that is supposed to be filtered out later by `jnp.where` (or similar):
def f(x):  # x.size = 1
    return jnp.where(x > 0, jnp.where(x > 0, x, 0)**0.5, Const)

jax.grad(f)(0.0)
# 0 - OK, i.e. we get "d Const / d x = 0" at "0"
Something like that...
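The same double-`where` pattern can be wrapped in a small helper that works elementwise on arrays. This is a sketch: the name `sqrt_where_positive` is hypothetical, and using 1.0 rather than 0 as the placeholder for filtered-out inputs is a design choice that keeps the discarded branch's derivative finite, instead of relying on the inner `where` to drop a NaN.

```python
import jax
import jax.numpy as jnp

def sqrt_where_positive(x, fill):
    """Hypothetical helper: sqrt(x) where x > 0, else `fill`, with finite gradients."""
    positive = x > 0
    # Inner where: swap non-positive inputs for 1.0, where sqrt is smooth,
    # so even the discarded branch never sees an infinite derivative.
    safe_x = jnp.where(positive, x, 1.0)
    # Outer where: the actual filtering.
    return jnp.where(positive, jnp.sqrt(safe_x), fill)

x = jnp.array([-1.0, 0.0, 4.0])
values = sqrt_where_positive(x, 0.0)  # [0, 0, 2]
grads = jax.grad(lambda v: sqrt_where_positive(v, 0.0).sum())(x)  # [0, 0, 0.25]
```

Any placeholder inside the region where the wrapped function is smooth would do; 1.0 is just a convenient choice for the square root.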
p.s. Frankly speaking, after reading the docs it's still not obvious to me WHY it is so. The mentioned fact that "0 * nan = nan" isn't really a "theoretical limitation" for this problem imho, so I consider the problem to be an implementation "feature" (maybe a known issue?). Or maybe a special NOTE could be shown with `config.update("jax_debug_nans", True)` if the NaN appears after `jnp.where`?
> Or maybe a special NOTE with config.update("jax_debug_nans", True) can appear if that's happened after jnp.where ?
Interesting idea! @jakevdp wdyt?
> Mentioned thing that "0 * nan = nan" isn't really a "theoretical limitation" to the problem imho.
That is the fundamental root of the issue, but maybe we can connect the dots more concretely:
Claim 1: the VJP of a function `lambda x:T: ...` (think of `T` as including the shape) must be a function which produces a value of type `T`.
The VJP of `f = lambda y:f32[2]: y[1]` is `lambda zbar:f32[]: jnp.zeros(2, 'f32').at[1].set(zbar)` (it has to be a dense array because of Claim 1, and it has to be a dense array of zeros because of Claim 2).
The VJP of `g = lambda x:f32[]: c:f32[2] * x`, for any constant `c`, is `lambda ybar:f32[2]: (ybar * c:f32[2]).sum()`, where the sum arises from the broadcast.
The VJP of `lambda x: f(g(x))` is correct no matter the value of the constant `c`, so long as `0 * x = 0` (zero scaling) and `0 + x = x` (zero vector) for all possible array entries `x`; but if we had a `nan` value in `c[0]`, then the VJP of the composition is incorrect.
Indeed, the VJP of just `g` by itself always produces a `nan` value if `c[0]` is `nan`, but it's not clear that that's a problem, because there's a `nan` in the output of `g`. It's only when we compose it with `f`, which drops the `nan` from the output, that it's clear things are really going wrong: just having `nan`s in intermediates, not outputs, can break VJPs.
The last claim is really what we mean when we say the root of the issue is allowing some value `x` for which `x * 0 != 0`. If we didn't have such values, then this problem couldn't arise!
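The composition argument can be checked directly with `jax.vjp` and `jax.grad`. A sketch: `f`, `g` (built here by a hypothetical helper `make_g`) and `c` mirror the claims, and the concrete values of `c` are arbitrary.

```python
import jax
import jax.numpy as jnp

def f(y):
    # f32[2] -> f32[]: keeps y[1] and drops y[0]
    return y[1]

def make_g(c):
    # f32[] -> f32[2]: scales the constant vector c by a scalar x
    return lambda x: c * x

# VJP of f alone: a dense f32[2] cotangent with a zero in the dropped slot.
out, f_vjp = jax.vjp(f, jnp.array([3.0, 4.0]))
(f_ct,) = f_vjp(jnp.ones_like(out))  # [0., 1.]

c_ok = jnp.array([2.0, 5.0])
c_nan = jnp.array([jnp.nan, 5.0])

# The composition's output never contains the NaN ...
value_nan = f(make_g(c_nan)(1.0))  # 5.0
# ... but its VJP sums 0 * c[0] + 1 * c[1], and 0 * nan = nan.
grad_ok = jax.grad(lambda x: f(make_g(c_ok)(x)))(1.0)    # 5.0
grad_nan = jax.grad(lambda x: f(make_g(c_nan)(x)))(1.0)  # nan
```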
We might be able to fix this by changing Claim 1, basically by letting cotangents be sparse array types even when the primals are dense. But we've never gone down that path.
What do you think?
Interesting, thanks so much for the clarifications!
Let's check that I've got your point :)
$f(\boldsymbol{g}(x)) \in \mathbb{R}$ - a composition of functions with a scalar input $x \in \mathbb{R}$,
$\boldsymbol{g}(x) = (g_1(x), g_2(x))^T \in \mathbb{R}^{2 \times 1}$ - intermediate vector value,
with the following derivatives:
$\frac{\partial\ f}{\partial\ \boldsymbol{g}}=\left (\frac{\partial\ f}{\partial\ g_1}, \frac{\partial\ f}{\partial\ g_2}\right ) \in \mathbb{R}^{1 \times 2}$ - the gradient of the scalar $f$, with the corresponding vector-Jacobian product of the form $\upsilon_g \left (\frac{\partial\ f}{\partial\ g_1}, \frac{\partial\ f}{\partial\ g_2}\right ) \in \mathbb{R}^{1 \times 2}$ with $\upsilon_g \in \mathbb{R}$;
$\frac{\partial\ \boldsymbol{g}}{\partial\ x}=\left (\frac{\partial\ g_1}{\partial\ x}, \frac{\partial\ g_2}{\partial\ x}\right )^T \in \mathbb{R}^{2 \times 1}$ - the Jacobian of the vector $\boldsymbol{g}$, with the corresponding vector-Jacobian product of the form $\boldsymbol{\upsilon_f}^T \left (\frac{\partial\ g_1}{\partial\ x}, \frac{\partial\ g_2}{\partial\ x}\right )^T \in \mathbb{R}$ with $\boldsymbol{\upsilon_f} = (\upsilon_f^1 , \upsilon_f^2)^T \in \mathbb{R}^{2 \times 1}$
and finally
$\frac{\partial\ f}{\partial\ x} = \frac{\partial\ f}{\partial\ \boldsymbol{g}} \frac{\partial\ \boldsymbol{g}}{\partial\ x} = \frac{\partial\ f}{\partial\ g_1} \frac{\partial\ g_1}{\partial\ x} + \frac{\partial\ f}{\partial\ g_2} \frac{\partial\ g_2}{\partial\ x} \in \mathbb{R}$
I hope the notation is OK and that I've understood your clarifications correctly, so, going straight to the point: the problem in the ORIGINAL EXAMPLE is that
$\frac{\partial\ g_1}{\partial\ x} = \frac{\partial\ \sqrt{x \cdot 0}}{\partial\ x} = nan$ (in floating point it evaluates as $0 \cdot \frac{1}{2\sqrt{0}} = 0 \cdot \infty$)
is multiplied by the derivative of our "filter function" $f(\boldsymbol{g}) = g_2$ (or even $f(\boldsymbol{g}) = 0 \cdot g_1 + g_2$):
$\frac{\partial\ f}{\partial\ g_1} = 0$ (because $f$ doesn't depend on $g_1$), resulting in $\frac{\partial\ f}{\partial\ x} = 0 \cdot nan + ... = nan$. Is that correct?
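The chain-rule reading above can be checked numerically. A sketch assuming $x = 1$ and the original $\boldsymbol{g}(x) = (\sqrt{0 \cdot x}, \sqrt{1 \cdot x})$; `jax.jacfwd` is used for $\partial \boldsymbol{g} / \partial x$ so that each Jacobian entry is computed on its own.

```python
import jax
import jax.numpy as jnp

c = jnp.array([0.0, 1.0])

def g(x):
    # g(x) = (sqrt(0 * x), sqrt(1 * x)), as in the original example
    return jnp.sqrt(c * x)

def f(gv):
    # "filter" function f(g) = g_2, i.e. the weight on g_1 is 0
    return gv[1]

x0 = 1.0
dg_dx = jax.jacfwd(g)(x0)      # [nan, 0.5]: d sqrt(0 * x)/dx evaluates as 0 * inf
df_dg = jax.grad(f)(g(x0))     # [0., 1.]
df_dx = jnp.dot(df_dg, dg_dx)  # 0 * nan + 1 * 0.5 = nan
df_dx_jax = jax.grad(lambda x: f(g(x)))(x0)  # reverse mode gives nan as well
```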
In other words, simplifying all of the above, one may assume that `jnp.where` is kind of a "ReLU" function, for example:
$where(y > 0,\ y,\ 0) =f(y) = (y\ if\ y > 0\ else\ 0) = (y\ if\ y > 0\ else\ 0 \cdot y)$
and when differentiating $f(y) = f(\sqrt{x}) = (\sqrt{x}\ \ if\ \sqrt{x} > 0\ \ else\ 0) = (\sqrt{x}\ \ if\ \sqrt{x} > 0\ \ else\ \ 0 \cdot \sqrt{x})$, we still end up with the "classic" chain rule:
$\frac{\partial\ f}{\partial\ x} = \frac{\partial\ f}{\partial\ y} \frac{\partial\ \sqrt{x}}{\partial\ x} = (1\ if\ y > 0\ else\ 0) \cdot \frac{\partial\ \sqrt{x}}{\partial\ x}$
and finally, evaluating the above expression at $x = 0$ (where $y = 0$ and $\frac{\partial\ \sqrt{x}}{\partial\ x} = \infty$), we get $0 \cdot \infty$, which is $nan$ in floating point. Does this make sense?
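The ReLU analogy can be checked directly with `jax.nn.relu`, whose derivative, like the selection in `where(y > 0, y, 0)`, is 0 for non-positive inputs (including 0 itself). A minimal sketch; `h` is just an illustrative name.

```python
import jax
import jax.numpy as jnp

# where(y > 0, y, 0) written as jax.nn.relu, applied to y = sqrt(x)
def h(x):
    return jax.nn.relu(jnp.sqrt(x))

grad_at_zero = jax.grad(h)(0.0)  # relu'(0) * sqrt'(0) = 0 * inf = nan
grad_at_four = jax.grad(h)(4.0)  # 1 * 0.5 / sqrt(4) = 0.25
```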
So my NAIVE thought about a "theoretical limitation" was based on the assumption that `jnp.where` is implemented under the hood with something like the following native loop bound to Python:
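for (size_t i = 0; i < y.size(); ++i) {
    if (y[i].value > 0) {     // y[i] is a sort of dual number (value + derivative)
        out[i] = y[i];        // keep y[i] together with its derivative
    } else {
        out[i] = Dual{c, 0};  // ignore y[i]; use a constant with zero derivative
    }
}                             // out, Dual and c are illustrative names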
where the NaN-producing branch is dropped by an `if` rather than by multiplying it by 0.
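For contrast, a genuine branch does behave like the `if` in this loop. A minimal sketch with `jax.lax.cond`, which for a scalar predicate executes (and differentiates) only the selected branch; `CONST`, `f_cond` and `f_where` are illustrative names. One caveat: under `jax.vmap` with a batched predicate, `cond` is turned into a select that evaluates both branches, so it is not a drop-in replacement for elementwise `jnp.where`.

```python
import jax
import jax.numpy as jnp
from jax import lax

CONST = 1.0  # illustrative stub value

def f_cond(x):
    # Scalar predicate: only the selected branch is executed and differentiated.
    return lax.cond(x > 0, jnp.sqrt, lambda v: jnp.full_like(v, CONST), x)

def f_where(x):
    # Both branches are evaluated; the unselected one still enters the VJP.
    return jnp.where(x > 0, jnp.sqrt(x), CONST)

x0 = jnp.float32(0.0)
grad_cond = jax.grad(f_cond)(x0)    # 0.0
grad_where = jax.grad(f_where)(x0)  # nan
```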
p.s. Merry Xmas Everybody!
Description
Hey, folks!
If I don't provide consistent data for the gradient calculation and try to filter out the mess at a later stage (pre_filter=False), I get a NaN gradient, while pre-filtering works OK. See the code below:
What jax/jaxlib version are you using?
0.4.23
Which accelerator(s) are you using?
CPU/GPU, doesn't matter
Additional system info?
Ubuntu 22.04, x86, Python 3.10.12
NVIDIA GPU info
No response