jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

`debug_nans` error always says the de-optimized function did not produce NaNs #24955

Open emilyfertig opened 4 days ago

emilyfertig commented 4 days ago

Description

I'm working on documentation for `debug_nans`, and I wrote the following function, which for certain input values calls `jnp.log` on a negative number, producing a NaN:

import jax
import jax.numpy as jnp
jax.config.update("jax_debug_nans", True)

@jax.jit
def f(x, y):
  w = jnp.sin(x) - y**2
  z = jnp.log(w)
  return z*2

print(f(0.5, 0))
print(f(-2., 5))  # ==> FloatingPointError with note that the NaN doesn't appear without jit

print(jnp.log(-5.))  # ==> same error with note

It fails with an error indicating that a NaN was returned from the compiled function but not from `fun.call_wrapped`. The behavior is the same if I replace `log` with `sqrt`, if I remove the `jit` decorator, or if I just call `jnp.log` on a negative value without `jit`.

The error message is misleading because the de-optimized function returns NaNs as well, since it takes the log of a negative value. I think something is going wrong in the code path taken in `_pjit_call_impl_python`, but I can't tell what.
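To make the claim concrete, here's a sketch (not from the original report) that compares the eager path and the jit path with `jit` disabled; `jax_debug_nans` is deliberately left off so the NaNs can be inspected instead of raising:

```python
import jax
import jax.numpy as jnp

def f(x, y):
    w = jnp.sin(x) - y**2
    z = jnp.log(w)
    return z * 2

# De-optimized (eager) call: log(sin(-2) - 25) is log of a negative, so NaN.
eager = f(-2., 5)

# Same call with jit disabled via the disable_jit context manager: also NaN.
with jax.disable_jit():
    deopt = jax.jit(f)(-2., 5)

print(bool(jnp.isnan(eager)), bool(jnp.isnan(deopt)))  # True True
```

Both paths produce NaN, which is why the "de-optimized function did not produce invalid values" note in the error is inaccurate here.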

cc @yashk2810 since it looks like you've worked on this area of the code a fair amount.

System info (python version, jaxlib version, accelerator, etc.)

Reproducible across a few different environments, but e.g.:

jax:    0.4.36
jaxlib: 0.4.36
numpy:  2.1.3
python: 3.11.8 (stable, redacted, redacted) [Clang google3-trunk (f58ce1152703ca753794b8cef36da30bd2668d0f)]
device info: Tesla V100-SXM2-16GB-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='b6e5614622812f47-3e7e1adbbf9.borgtask.google.com', release='5.10.0-smp-1104.53.0.0', version='#1 [v5.10.0-1104.53.0.0] SMP @1727505643', machine='x86_64')

$ nvidia-smi
Mon Nov 18 12:42:04 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03             Driver Version: 535.129.03   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  Tesla V100-SXM2-16GB          Off  | 00000000:B3:00.0 Off |                    0 |
| N/A   41C    P0             72W / 300W  |  12433MiB / 16384MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A    829280      C   ...fb3717c109/mount/server/ml_notebook    12430MiB |
+---------------------------------------------------------------------------------------+

mattjj commented 4 days ago

Thanks, @emilyfertig ! I noticed this a few months ago, started a PR to fix it, and then let it languish. This regressed when we did the jit/pjit merge more than a year ago. Let me see if I can revive the PR...

emilyfertig commented 3 days ago

I noticed something else that might be related: the error message with `debug_nans` used to say which line inside a jitted function produced a NaN, and now it only reports the call site. Here's an example from the NaN Debugging section of The Sharp Bits:

In [4]: from jax import jit

In [5]: @jit
   ...: def f(x, y):
   ...:     a = x * y
   ...:     b = (x + y) / (x - y)
   ...:     c = a + 2
   ...:     return a + b * c
   ...:

In [6]: x = jnp.array([2., 0.])

In [7]: y = jnp.array([3., 0.])

In [8]: f(x, y)
Invalid value encountered in the output of a jit function. Calling the de-optimized version.
---------------------------------------------------------------------------
FloatingPointError                        Traceback (most recent call last)
<ipython-input-8-811b7ddb3300> in <module>()
----> 1 f(x, y)

 ... stack trace ...

<ipython-input-5-619b39acbaac> in f(x, y)
      2 def f(x, y):
      3     a = x * y
----> 4     b = (x + y) / (x - y)
      5     c = a + 2
      6     return a + b * c

.../jax/jax/numpy/lax_numpy.pyc in divide(x1, x2)
    343     return floor_divide(x1, x2)
    344   else:
--> 345     return true_divide(x1, x2)
    346
    347

.../jax/jax/numpy/lax_numpy.pyc in true_divide(x1, x2)
    332   x1, x2 = _promote_shapes(x1, x2)
    333   return lax.div(lax.convert_element_type(x1, result_dtype),
--> 334                  lax.convert_element_type(x2, result_dtype))
    335
    336

.../jax/jax/lax.pyc in div(x, y)
    244 def div(x, y):
    245   r"""Elementwise division: :math:`x \over y`."""
--> 246   return div_p.bind(x, y)
    247
    248 def rem(x, y):

 ... stack trace ...

And here's the same code, run with 0.4.36:

---------------------------------------------------------------------------
FloatingPointError                        Traceback (most recent call last)
    [... skipping hidden 1 frame]

[google3/third_party/py/jax/_src/profiler.py](https://colab.corp.google.com/drive/1wc-i4K-fZ5PeewAtC-Ojge-BVraWWmOI#) in wrapper(*args, **kwargs)
    332     with TraceAnnotation(name, **decorator_kwargs):
--> 333       return func(*args, **kwargs)
    334     return wrapper

4 frames
[google3/third_party/py/jax/_src/interpreters/pxla.py](https://colab.corp.google.com/drive/1wc-i4K-fZ5PeewAtC-Ojge-BVraWWmOI#) in __call__(self, *args)
   1302         for arrays in out_arrays:
-> 1303           dispatch.check_special(self.name, arrays)
   1304         out = self.out_handler(out_arrays)

[google3/third_party/py/jax/_src/dispatch.py](https://colab.corp.google.com/drive/1wc-i4K-fZ5PeewAtC-Ojge-BVraWWmOI#) in check_special(name, bufs)
    315     for buf in bufs:
--> 316       _check_special(name, buf.dtype, buf)
    317 

[google3/third_party/py/jax/_src/dispatch.py](https://colab.corp.google.com/drive/1wc-i4K-fZ5PeewAtC-Ojge-BVraWWmOI#) in _check_special(name, dtype, buf)
    320     if config.debug_nans.value and np.any(np.isnan(np.asarray(buf))):
--> 321       raise FloatingPointError(f"invalid value (nan) encountered in {name}")
    322     if config.debug_infs.value and np.any(np.isinf(np.asarray(buf))):

FloatingPointError: invalid value (nan) encountered in jit(f)

During handling of the above exception, another exception occurred:

FloatingPointError                        Traceback (most recent call last)
[<ipython-input-18-9911e10902e9>](https://colab.corp.google.com/drive/1wc-i4K-fZ5PeewAtC-Ojge-BVraWWmOI#) in <cell line: 0>()
     15 b = jnp.array([3., 9])
     16 
---> 17 print(f(x, y))

    [... skipping hidden 3 frame]

[google3/third_party/py/jax/_src/pjit.py](https://colab.corp.google.com/drive/1wc-i4K-fZ5PeewAtC-Ojge-BVraWWmOI#) in _pjit_call_impl_python(jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline, compiler_options_kvs, *args)
   1692            "If you see this error, consider opening a bug report at "
   1693            "https://github.com/jax-ml/jax.")
-> 1694     raise FloatingPointError(msg)
   1695 
   1696 

FloatingPointError: invalid value (nan) encountered in jit(f). Because jax_config.debug_nans.value and/or config.jax_debug_infs is set, the de-optimized function (i.e., the function as if the `jit` decorator were removed) was called in an attempt to get a more precise error message. However, the de-optimized function did not produce invalid values during its execution. This behavior can result from `jit` optimizations causing the invalid value to be produced. It may also arise from having nan/inf constants as outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`. 

It may be possible to avoid the invalid value by removing the `jit` decorator, at the cost of losing optimizations. 

If you see this error, consider opening a bug report at https://github.com/jax-ml/jax.

Also, the current version no longer prints "Invalid value encountered in the output of a jit function. Calling the de-optimized version." (Sometimes it does, but I haven't figured out how to reproduce it consistently; I tried flushing the log buffer, so I don't think buffering is the cause.)
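As an aside, while the traceback regression stands, one possible workaround (a sketch, not part of `debug_nans` itself) is `jax.experimental.checkify`, which instruments a jitted function with NaN checks and records where inside the function the invalid value arose:

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify

@jax.jit
def f(x, y):
    a = x * y
    b = (x + y) / (x - y)  # 0/0 at index 1 produces the NaN
    c = a + 2
    return a + b * c

# Wrap the jitted function with NaN checks; calling it returns an error
# object alongside the output instead of raising.
checked_f = checkify.checkify(f, errors=checkify.nan_checks)
err, out = checked_f(jnp.array([2., 0.]), jnp.array([3., 0.]))
print(err.get())  # message describes the NaN and points into f
```

Unlike `jax_debug_nans`, this doesn't re-run a de-optimized version; the checks are compiled into the jitted computation itself.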

@mattjj If you have a start at a PR I'd be happy to take it over (especially if you think it'd be a good way to learn about this part of the code and wouldn't be too much to bite off as I'm getting ramped up).

emilyfertig commented 3 days ago

The above behavior (printing the call site only and not the line in the function where the NaN occurred) is more recent. 0.4.35 (released 10/22) still prints the exact line.

emilyfertig commented 3 days ago

For now #24989 comments out parts of the docs/error message that aren't consistent with how the code behaves.

emilyfertig commented 3 days ago

The culprit for the second issue appears to be 32bf19ac6f52a0f6776c730dd352d0530cc5bc9f.