google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

"NaNs Leakeage" - can not remove innconsistent NaN-grad elements #19089

Open agudym opened 8 months ago

agudym commented 8 months ago

Description

Hey, folks!

If I don't provide consistent data for the gradient calculation and instead try to filter out the mess at a later stage (pre_filter=False), I get a NaN gradient, while pre-filtering works fine. See the code below:

    import jax.numpy as jnp
    from jax import grad

    def f(x, pre_filter=False):
        delta = jnp.array((0, 1)) * x
        if pre_filter:
            return jnp.sum( delta[delta > 0]**0.5 )
        else:
            return jnp.sum( (delta**0.5)[delta > 0] )

    print(grad(f)(1.))
    # NaN if pre_filter == False
    # 0.5 if pre_filter == True

What jax/jaxlib version are you using?

0.4.23

Which accelerator(s) are you using?

CPU/GPU, doesn't matter

Additional system info?

Ubuntu 22.04, x86, Python 3.10.12

NVIDIA GPU info

No response

jakevdp commented 8 months ago

Hi - thanks for the question! It looks like you're hitting a variant of this issue: https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where

Expressed in terms of where rather than masking, your code is roughly equivalent to this:

import jax.numpy as jnp

def f(x, pre_filter=False):
    delta = jnp.array((0, 1)) * x
    if pre_filter:
        return jnp.sum(jnp.where(delta > 0, delta, 0) ** 0.5)
    else:
        return jnp.sum(jnp.where(delta > 0, delta**0.5, 0))

In the first case, you only apply the square root to the values that pass the filter. In the second case, you also apply the square root to the filtered-out values (here, zero), and those entries generate NaNs in the gradient. Outside autodiff, the filter works the same either way. But inside autodiff, the autodiff rule must consider the contributions of both filtered and non-filtered values to the gradient. You can read a more complete description of this at the link above.
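The workaround described in that FAQ entry is the "double where" trick: use a second where to keep the problematic values out of the branch you are going to discard. Roughly, applied to your example (a sketch; f_safe is just a name for the adjusted function, and the fill value 1.0 is arbitrary, any positive value works):

    import jax
    import jax.numpy as jnp

    def f_safe(x):
        delta = jnp.array((0., 1.)) * x
        # inner where: replace filtered-out entries with a harmless value so that
        # sqrt's derivative stays finite there
        safe = jnp.where(delta > 0, delta, 1.0)
        # outer where: do the actual filtering
        return jnp.sum(jnp.where(delta > 0, safe**0.5, 0.0))

    print(jax.grad(f_safe)(1.))   # 0.5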

Does that help answer your question?

agudym commented 8 months ago

Thanks so much for the swift response!

That's definitely a feature, not a bug, already described in detail (for TensorFlow too), my bad! The most relevant comprehensive info (just refreshing the topic :) ): https://github.com/google/jax/issues/1052#issuecomment-514083352 https://github.com/tensorflow/probability/blob/main/discussion/where-nan.pdf

To summarize all of this, my understanding of the problem is the following:

import jax
import jax.numpy as jnp

Const = 1.0  # some constant stub

def f(x): # our "intuition" of how it SHOULD work
  if x > 0:
    return x**0.5  # sqrt(x)
  else:
    return Const
jax.grad(f)(0.0)
# 0 - and it really works CORRECTLY, i.e. we get "d Const / d x = 0" at "0"

However, with jnp.where (or similar) we get "a surprise":

def f(x): # x.size = 1
    return jnp.where(x > 0, x**0.5, Const)
jax.grad(f)(0.0)
# nan - boom!

So the workaround is to avoid NaNs in ANY "execution branch" (even if it's supposed to be filtered out later by jnp.where or similar):

def f(x): # x.size = 1
    return jnp.where(x > 0, jnp.where(x > 0, x, 0)**0.5, Const)
jax.grad(f)(0.0)
# 0 - OK, i.e. we get "d Const / d x = 0" at "0"

Something like that...

p.s. Frankly speaking, even after reading the docs it's still not obvious WHY it is so. The mentioned fact that "0 * nan = nan" isn't really a "theoretical limitation" of the problem, imho. So I'm considering the problem to be an implementation "feature" (maybe a known issue?). Or maybe a special NOTE could appear with config.update("jax_debug_nans", True) if that happens right after a jnp.where?
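For reference, a minimal sketch of what the existing NaN checker does on the example above (just my illustration; the exact error message is whatever jax_debug_nans produces, and today there is no where-specific hint):

    import jax
    import jax.numpy as jnp

    jax.config.update("jax_debug_nans", True)

    def f(x):
        delta = jnp.array((0., 1.)) * x
        return jnp.sum(jnp.where(delta > 0, delta**0.5, 0.))

    # the nan appears in the backward pass, so this is expected to raise a
    # FloatingPointError there rather than silently returning nan
    try:
        jax.grad(f)(1.)
    except FloatingPointError as err:
        print("nan detected during grad:", err)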

mattjj commented 8 months ago

Or maybe a special NOTE could appear with config.update("jax_debug_nans", True) if that happens right after a jnp.where?

Interesting idea! @jakevdp wdyt?

The mentioned fact that "0 * nan = nan" isn't really a "theoretical limitation" of the problem, imho.

That is the fundamental root of the issue, but maybe we can connect the dots more concretely:

  1. the vjp of any function lambda x:T: ... (think of T as including the shape) must be a function which produces a value of type T
  2. the entry of a vjp value corresponding to an input or intermediate for which the value doesn't affect the output must be zero
  3. the vjp of f = lambda y:f32[2]: y[1] is lambda zbar:f32[]: jnp.zeros(2, 'f32').at[1].set(zbar) (has to be a dense array because of Claim 1, and has to be a dense array of zeros because of Claim 2)
  4. the vjp of g = lambda x:f32[]: c:f32[2] * x for any constant c is lambda ybar:f32[2]: (ybar * c:f32[2]).sum(), where the sum arises from the broadcast
  5. the vjp of the composed function lambda x: f(g(x)) is correct no matter the value of the constant c so long as 0 * x = 0 (zero scaling) for all possible array entries x and 0 + x = x (zero vector) for all possible array entries x, but if we had a nan value in c[0] then the vjp of the composition is incorrect

Indeed the vjp of just g by itself always produces a nan value if c[0] is nan, but it's not clear that that's a problem because there's a nan in the output of g. It's only when we compose it with f, which drops the nan from the output, that it's clear things are really going wrong: just having nans in intermediates, not outputs, can break VJPs.
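To make that concrete, here is a small numerical sketch of Claims 3-5 (just an illustration, with f, g, and c as above):

    import jax
    import jax.numpy as jnp

    c = jnp.array([jnp.nan, 1.0])   # constant with a nan in the entry that f never uses

    g = lambda x: c * x             # f32[] -> f32[2]
    f = lambda y: y[1]              # f32[2] -> f32[]

    # Claim 3: the vjp of f scatters the scalar cotangent into a dense array of zeros
    _, f_vjp = jax.vjp(f, jnp.zeros(2))
    print(f_vjp(jnp.float32(1.0)))             # ([0., 1.],)

    # Claim 4: the vjp of g computes (ybar * c).sum(), so it is nan whenever c
    # contains a nan, even though the cotangent entry hitting the nan is zero
    _, g_vjp = jax.vjp(g, 1.0)
    print(g_vjp(jnp.array([0.0, 1.0])))        # (nan,)

    # Claim 5: f(g(x)) never exposes c[0] in its output, yet the gradient is
    # still nan because 0 * nan != 0
    print(jax.grad(lambda x: f(g(x)))(1.0))    # nan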

The last claim is really what we mean when we say the root of the issue is allowing some value x for which x * 0 != 0. If we didn't have such values, then this problem couldn't arise!

We might be able to fix this by changing Claim 1, basically by letting cotangents be sparse array types even when the primals are dense. But we've never gone down that path.

What do you think?

agudym commented 8 months ago

Interesting, thanks so much for the clarifications!

Let's check that I've got your point :)

$f(\boldsymbol{g}(x)) \in \mathbb{R}$ - some functions composition, with scalar $x \in \mathbb{R}$ input,

$\boldsymbol{g}(x) = (g_1(x), g_2(x))^T \in \mathbb{R}^{2 \times 1}$ - intermediate vector value,

with following derivatives:

$\frac{\partial\ f}{\partial\ \boldsymbol{g}}=\left (\frac{\partial\ f}{\partial\ g_1}, \frac{\partial\ f}{\partial\ g_2}\right ) \in \mathbb{R}^{1 \times 2}$ - gradient of scalar $f$, with corresponding vector-jacobian-product of the form: $\upsilon_g \left (\frac{\partial\ f}{\partial\ g_1}, \frac{\partial\ f}{\partial\ g_2}\right ) \in \mathbb{R}^{1 \times 2}$ with $\upsilon_g \in \mathbb{R}$

$\frac{\partial\ \boldsymbol{g}}{\partial\ x}=\left (\frac{\partial\ g_1}{\partial\ x}, \frac{\partial\ g_2}{\partial\ x}\right )^T \in \mathbb{R}^{2 \times 1}$ - gradient of vector $\boldsymbol{g}$, with corresponding vector-jacobian-product of the form: $\boldsymbol{\upsilon_f}^T \left (\frac{\partial\ g_1}{\partial\ x}, \frac{\partial\ g_2}{\partial\ x}\right )^T \in \mathbb{R}$ with $\boldsymbol{\upsilon_f} = (\upsilon_f^1 , \upsilon_f^2)^T \in \mathbb{R}^{2 \times 1}$

and finally

$\frac{\partial\ f}{\partial\ x} = \frac{\partial\ f}{\partial\ \boldsymbol{g}} \frac{\partial\ \boldsymbol{g}}{\partial\ x} = \frac{\partial\ f}{\partial\ g_1} \frac{\partial\ g_1}{\partial\ x} + \frac{\partial\ f}{\partial\ g_2} \frac{\partial\ g_2}{\partial\ x} \in \mathbb{R}$

I hope the notation is OK and that I've understood your clarifications correctly, so going straight to the point: the problem in the ORIGINAL EXAMPLE is that

$\frac{\partial\ g_1}{\partial\ x} = \frac{\partial\ \sqrt{x \cdot 0}}{\partial\ x} = nan$

is being multiplied by the derivative of our "filter-function" $f(x)=g_2(x)$, or equivalently $f(x)=0 \cdot g_1(x) + g_2(x)$:

$\frac{\partial\ f}{\partial\ g_1} = 0$ (because $f$ doesn't depend on $g_1$), resulting in $\frac{\partial\ f}{\partial\ x} = 0 \cdot nan + ... = nan$. Is that correct?
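A quick numerical check of that decomposition (my own sketch, with $g_1$ and $g_2$ taken from the original example):

    import jax
    import jax.numpy as jnp

    g1 = lambda x: jnp.sqrt(x * 0.)   # delta[0] = 0 * x, then sqrt
    g2 = lambda x: jnp.sqrt(x * 1.)   # delta[1] = 1 * x, then sqrt

    print(jax.grad(g1)(1.))           # nan  (0.5 * 0**-0.5 * 0 = inf * 0)
    print(jax.grad(g2)(1.))           # 0.5

    # the filter contributes df/dg1 = 0 and df/dg2 = 1, but 0 * nan is still nan:
    print(0. * jax.grad(g1)(1.) + 1. * jax.grad(g2)(1.))   # nan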

In other words, simplifying all of the above, one may think of jnp.where as a kind of "ReLU" function, for example:

$where(y > 0,\ y,\ 0) =f(y) = (y\ if\ y > 0\ else\ 0) = (y\ if\ y > 0\ else\ 0 \cdot y)$

and when differentiating $f(y)$ at $y = \sqrt{x}$, i.e. $f(\sqrt{x}) = (\sqrt{x}\ \ if\ \sqrt{x} > 0\ \ else\ \ 0 \cdot \sqrt{x})$, we still end up with the "classic" chain rule:

$\frac{\partial\ f}{\partial\ x} = \frac{\partial\ f}{\partial\ y} \frac{\partial\ \sqrt{x}}{\partial\ x} = (1\ if\ y > 0\ else\ 0) \cdot \frac{\partial\ \sqrt{x}}{\partial\ x}$

and finally, evaluating the above expression, we get $0 \cdot nan$ if $x = 0$. Does this make sense?
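For example (a sketch; relu_like is just my name for this where-as-ReLU helper):

    import jax
    import jax.numpy as jnp

    relu_like = lambda y: jnp.where(y > 0, y, 0.)   # f(y) = y if y > 0 else 0
    f = lambda x: relu_like(jnp.sqrt(x))

    # chain rule: (1 if y > 0 else 0) * d sqrt(x)/dx  ->  0 * inf at x = 0
    print(jax.grad(f)(0.))   # nan
    print(jax.grad(f)(1.))   # 0.5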

So my NAIVE thought about a "theoretical limitation" was based on the assumption that the current implementation of jnp.where under the hood is something like the bound native code below:

for (size_t i = 0; i < y.size(); ++i) {
    if (y[i].value > 0) {  // y[i] is a sort of dual number
        // take the derivative of y[i]
    } else {
        // ignore y[i], use the other, constant value instead
    }
}

where we drop the "nan execution branch" with an if rather than a 0-multiplication.
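For what it's worth, that kind of per-element branching does exist at the user level for scalars: jax.lax.cond only differentiates the branch that actually runs (a sketch; note that under vmap cond is converted to a select, so the 0 * nan issue comes back for batched inputs):

    import jax
    import jax.numpy as jnp

    # only the executed branch participates in autodiff
    f_cond = lambda x: jax.lax.cond(x > 0, lambda t: t**0.5, lambda t: 0.0, x)

    print(jax.grad(f_cond)(0.))   # 0.0 - the sqrt branch is never differentiated here
    print(jax.grad(f_cond)(1.))   # 0.5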

p.s. Merry Xmas Everybody!