emilyfertig opened this issue 4 days ago
Thanks, @emilyfertig! I noticed this a few months ago, started a PR to fix it, and then let it languish. This regressed when we did the jit/pjit merge more than a year ago. Let me see if I can revive the PR...
I noticed something else that might be related: the error message with debug_nans used to say which line inside a jitted function produced a NaN, and now it just reports the call site. Here's an example from the NaN Debugging section of The Sharp Bits (a quick sketch of the flag-enabling setup it assumes comes first):
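(Assumed setup, not part of the quoted doc excerpt; `jax.config.update` is the standard way to turn the check on.)

import jax  # assumed setup for the excerpt below
jax.config.update("jax_debug_nans", True)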
In [4]: from jax import jit

In [5]: @jit
   ...: def f(x, y):
   ...:     a = x * y
   ...:     b = (x + y) / (x - y)
   ...:     c = a + 2
   ...:     return a + b * c
   ...:

In [6]: x = jnp.array([2., 0.])
In [7]: y = jnp.array([3., 0.])
In [8]: f(x, y)
Invalid value encountered in the output of a jit function. Calling the de-optimized version.
---------------------------------------------------------------------------
FloatingPointError Traceback (most recent call last)
<ipython-input-8-811b7ddb3300> in <module>()
----> 1 f(x, y)
... stack trace ...
<ipython-input-5-619b39acbaac> in f(x, y)
      2 def f(x, y):
      3     a = x * y
----> 4     b = (x + y) / (x - y)
      5     c = a + 2
      6     return a + b * c
.../jax/jax/numpy/lax_numpy.pyc in divide(x1, x2)
343 return floor_divide(x1, x2)
344 else:
--> 345 return true_divide(x1, x2)
346
347
.../jax/jax/numpy/lax_numpy.pyc in true_divide(x1, x2)
332 x1, x2 = _promote_shapes(x1, x2)
333 return lax.div(lax.convert_element_type(x1, result_dtype),
--> 334 lax.convert_element_type(x2, result_dtype))
335
336
.../jax/jax/lax.pyc in div(x, y)
244 def div(x, y):
245 r"""Elementwise division: :math:`x \over y`."""
--> 246 return div_p.bind(x, y)
247
248 def rem(x, y):
... stack trace ...
And here's the same code, run with 0.4.36:
---------------------------------------------------------------------------
FloatingPointError Traceback (most recent call last)
[... skipping hidden 1 frame]
google3/third_party/py/jax/_src/profiler.py in wrapper(*args, **kwargs)
332 with TraceAnnotation(name, **decorator_kwargs):
--> 333 return func(*args, **kwargs)
334 return wrapper
4 frames
google3/third_party/py/jax/_src/interpreters/pxla.py in __call__(self, *args)
1302 for arrays in out_arrays:
-> 1303 dispatch.check_special(self.name, arrays)
1304 out = self.out_handler(out_arrays)
google3/third_party/py/jax/_src/dispatch.py in check_special(name, bufs)
315 for buf in bufs:
--> 316 _check_special(name, buf.dtype, buf)
317
google3/third_party/py/jax/_src/dispatch.py in _check_special(name, dtype, buf)
320 if config.debug_nans.value and np.any(np.isnan(np.asarray(buf))):
--> 321 raise FloatingPointError(f"invalid value (nan) encountered in {name}")
322 if config.debug_infs.value and np.any(np.isinf(np.asarray(buf))):
FloatingPointError: invalid value (nan) encountered in jit(f)
During handling of the above exception, another exception occurred:
FloatingPointError Traceback (most recent call last)
<ipython-input-18-9911e10902e9> in <cell line: 0>()
15 b = jnp.array([3., 9])
16
---> 17 print(f(x, y))
[... skipping hidden 3 frame]
google3/third_party/py/jax/_src/pjit.py in _pjit_call_impl_python(jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline, compiler_options_kvs, *args)
1692 "If you see this error, consider opening a bug report at "
1693 "https://github.com/jax-ml/jax.")
-> 1694 raise FloatingPointError(msg)
1695
1696
FloatingPointError: invalid value (nan) encountered in jit(f). Because jax_config.debug_nans.value and/or config.jax_debug_infs is set, the de-optimized function (i.e., the function as if the `jit` decorator were removed) was called in an attempt to get a more precise error message. However, the de-optimized function did not produce invalid values during its execution. This behavior can result from `jit` optimizations causing the invalid value to be produced. It may also arise from having nan/inf constants as outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`.
It may be possible to avoid the invalid value by removing the `jit` decorator, at the cost of losing optimizations.
If you see this error, consider opening a bug report at https://github.com/jax-ml/jax.
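For completeness, here's roughly how the 0.4.36 run above can be reproduced (a minimal sketch, not the exact notebook cell from the traceback; the array values there differ slightly):

import jax
import jax.numpy as jnp
from jax import jit

jax.config.update("jax_debug_nans", True)  # enable the NaN check

@jit
def f(x, y):
    a = x * y
    b = (x + y) / (x - y)  # 0/0 when x and y have equal entries -> nan
    c = a + 2
    return a + b * c

x = jnp.array([2., 0.])
y = jnp.array([3., 0.])
f(x, y)  # on 0.4.36 this raises FloatingPointError naming only jit(f), not the offending line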
Also, the current version no longer prints "Invalid value encountered in the output of a jit function. Calling the de-optimized version." (Sometimes it does, but I haven't figured out how to consistently reproduce it; I tried flushing the log buffer, so I don't think it's that.)
@mattjj, if you have a start at a PR, I'd be happy to take it over (especially if you think it'd be a good way to learn about this part of the code and wouldn't be too much to bite off as I'm getting ramped up).
The above behavior (printing the call site only and not the line in the function where the NaN occurred) is more recent. 0.4.35 (released 10/22) still prints the exact line.
For now #24989 comments out parts of the docs/error message that aren't consistent with how the code behaves.
The culprit for the second issue appears to be commit 32bf19ac6f52a0f6776c730dd352d0530cc5bc9f.
Description
I'm working on documentation for `debug_nans`, and I wrote the following function, which for certain input values calls `jnp.log` on a negative number, producing a `nan` value. It fails with this error, indicating that a NaN was returned from the compiled function but not from `fun.call_wrapped`. It's the same if I replace `log` with `sqrt`, if I remove the `jit` decorator, or if I just call `jnp.log` on a negative value without `jit`.

The error message is misleading because NaNs are returned from the de-optimized function as well, since it's taking the log of a negative value. I think something is going wrong in the code path taken in `_pjit_call_impl_python`, but I can't tell what.

cc @yashk2810 since it looks like you've worked on this area of the code a fair amount.
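(The original repro function isn't included in this excerpt; the sketch below is a hypothetical reconstruction of the kind of function described, with illustrative names and values rather than the actual code from the report.)

import jax
import jax.numpy as jnp

jax.config.update("jax_debug_nans", True)

@jax.jit
def g(x):  # hypothetical stand-in for the function described above
    # For entries with x < 1, x - 1 is negative, so jnp.log produces nan.
    return jnp.log(x - 1.0)

g(jnp.array([0.5, 2.0]))  # per the report, this raises the misleading FloatingPointError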
System info (python version, jaxlib version, accelerator, etc.)
Reproducible across a few different environments, but e.g.:
jax:    0.4.36
jaxlib: 0.4.36
numpy:  2.1.3
python: 3.11.8 (stable, redacted, redacted) [Clang google3-trunk (f58ce1152703ca753794b8cef36da30bd2668d0f)]
device info: Tesla V100-SXM2-16GB-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='b6e5614622812f47-3e7e1adbbf9.borgtask.google.com', release='5.10.0-smp-1104.53.0.0', version='#1 [v5.10.0-1104.53.0.0] SMP @1727505643', machine='x86_64')
$ nvidia-smi
Mon Nov 18 12:42:04 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03             Driver Version: 535.129.03   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  Tesla V100-SXM2-16GB           Off | 00000000:B3:00.0 Off |                    0 |
| N/A   41C    P0              72W / 300W |  12433MiB / 16384MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                             |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                              Usage      |
|=======================================================================================|
|    0   N/A  N/A    829280      C   ...fb3717c109/mount/server/ml_notebook    12430MiB |
+---------------------------------------------------------------------------------------+