jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
Apache License 2.0

[jax 0.4.33] XlaRuntimeError: FAILED_PRECONDITION: Buffer Definition Event: Unsupported number of sorted inputs: 17 #23727

Open PhilipVinc opened 1 week ago

PhilipVinc commented 1 week ago


The following (very confusing) bug appeared in jax/jaxlib 0.4.33; jax/jaxlib 0.4.31 works fine.

It only appears for specific shapes.

I'm running it on a Mac, but I originally saw it on Linux as well, for shape (576, 16).

import jax
import jax.numpy as jnp
S1 = (1,16)
a=jax.random.randint(jax.random.key(1), S1, 0, 1)
# Array([0], dtype=int32)
S2 = (2, 15)
a=jax.random.randint(jax.random.key(1), S2, 0, 1)
# Array([0, 1], dtype=int32)

S3 = (2, 16)
a=jax.random.randint(jax.random.key(1), S3, 0, 1)

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/filippo.vicentini/Documents/pythonenvs/netket/python-3.11.2/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/filippo.vicentini/Documents/pythonenvs/netket/python-3.11.2/lib/python3.11/site-packages/jax/_src/pjit.py", line 332, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
  File "/Users/filippo.vicentini/Documents/pythonenvs/netket/python-3.11.2/lib/python3.11/site-packages/jax/_src/pjit.py", line 190, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **p.params)
  File "/Users/filippo.vicentini/Documents/pythonenvs/netket/python-3.11.2/lib/python3.11/site-packages/jax/_src/core.py", line 2782, in bind
    return self.bind_with_trace(top_trace, args, params)
  File "/Users/filippo.vicentini/Documents/pythonenvs/netket/python-3.11.2/lib/python3.11/site-packages/jax/_src/core.py", line 443, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/Users/filippo.vicentini/Documents/pythonenvs/netket/python-3.11.2/lib/python3.11/site-packages/jax/_src/core.py", line 949, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/Users/filippo.vicentini/Documents/pythonenvs/netket/python-3.11.2/lib/python3.11/site-packages/jax/_src/pjit.py", line 1739, in _pjit_call_impl
    return xc._xla.pjit(
  File "/Users/filippo.vicentini/Documents/pythonenvs/netket/python-3.11.2/lib/python3.11/site-packages/jax/_src/pjit.py", line 1721, in call_impl_cache_miss
    out_flat, compiled = _pjit_call_impl_python(
  File "/Users/filippo.vicentini/Documents/pythonenvs/netket/python-3.11.2/lib/python3.11/site-packages/jax/_src/pjit.py", line 1675, in _pjit_call_impl_python
    return compiled.unsafe_call(*args), compiled
  File "/Users/filippo.vicentini/Documents/pythonenvs/netket/python-3.11.2/lib/python3.11/site-packages/jax/_src/profiler.py", line 333, in wrapper
    return func(*args, **kwargs)
  File "/Users/filippo.vicentini/Documents/pythonenvs/netket/python-3.11.2/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 1286, in __call__
    results = self.xla_executable.execute_sharded(input_bufs)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Unsupported number of sorted inputs: 17

System info (python version, jaxlib version, accelerator, etc.)

In [23]: import jax; jax.print_environment_info()
jax:    0.4.33
jaxlib: 0.4.33
numpy:  2.0.2
python: 3.11.2 (main, Apr  7 2023, 16:35:55) [Clang 14.0.3 (clang-1403.]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='mba-10834270.local', release='23.6.0', version='Darwin Kernel Version 23.6.0: Mon Jul 29 21:14:30 PDT 2024; root:xnu-10063.141.2~1/RELEASE_ARM64_T6000', machine='arm64')

EDIT: It seems that sort_thunk only supports a hardcoded set of input counts.

But I can't find the commit in which the implementation was switched to this.

Shouldn't there be some sort of fallback?
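The error message counts the inputs of a single XLA sort operation. The sketch below is one way to probe which input counts a given install accepts on CPU, assuming that each array passed together to jax.lax.sort becomes one input of that sort. The helper names sort_with_n_inputs and failing_input_counts are hypothetical and not part of JAX.

```python
import jax.numpy as jnp
from jax import lax


def sort_with_n_inputs(n: int, length: int = 4):
    """Run one lax.sort over n same-shaped operands (n >= 1).

    Only the first operand is the sort key (num_keys=1); the remaining
    n - 1 operands are permuted alongside it, so XLA sees one sort with
    n inputs.
    """
    keys = jnp.arange(length, 0, -1, dtype=jnp.int32)  # [length, ..., 1]
    payload = tuple(jnp.full((length,), i, dtype=jnp.int32) for i in range(1, n))
    return lax.sort((keys,) + payload, num_keys=1)


def failing_input_counts(max_n: int = 25):
    """Return (n, message) for every operand count in [1, max_n] that raises."""
    failures = []
    for n in range(1, max_n + 1):
        try:
            outs = sort_with_n_inputs(n)
            outs[0].block_until_ready()  # force execution, not just dispatch
        except Exception as err:  # e.g. XlaRuntimeError on affected builds
            failures.append((n, str(err)))
    return failures


if __name__ == "__main__":
    print(failing_input_counts())
```

On a build whose sort kernel covers every count in the range, the probe should return an empty list; on an affected build with the thunk runtime enabled, the unsupported counts would be listed together with their error messages.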

hawkinsp commented 1 week ago

Try XLA_FLAGS=--xla_cpu_use_thunk_runtime=false as a workaround.

@ezhulenev @penpornk

I think I forgot to mention it in the release notes for 0.4.32, but that release included a major upgrade to the CPU backend. The flag above temporarily switches back to the old version.
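When editing the shell environment is inconvenient, the same workaround can be applied from Python. This is a minimal sketch, assuming XLA reads XLA_FLAGS when JAX initializes its backend, so the variable has to be set before jax is imported. with_xla_flag is a hypothetical helper, and because the flag was described as a temporary switch back to the old runtime, later jaxlib releases may no longer accept it.

```python
import os


def with_xla_flag(existing: str, flag: str) -> str:
    """Append `flag` to an XLA_FLAGS-style string unless it is already present."""
    parts = existing.split()
    if flag not in parts:
        parts.append(flag)
    return " ".join(parts)


# Workaround from this thread: fall back to the pre-thunk CPU runtime.
# Must run before `import jax`, so the flag is seen when XLA initializes.
os.environ["XLA_FLAGS"] = with_xla_flag(
    os.environ.get("XLA_FLAGS", ""), "--xla_cpu_use_thunk_runtime=false"
)

# import jax  # import only after XLA_FLAGS has been set
```

Appending rather than overwriting keeps any XLA flags the user already set.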

PhilipVinc commented 1 week ago

I have a more reduced MWE

import jax.numpy as jnp

penpornk commented 1 week ago

Thank you for reporting the issue! Yes, the SortThunk needs a fallback kernel. We have filed a bug tracking this work.

In the meantime, I've added a specialization for 17 inputs in https://github.com/openxla/xla/commit/f237cc3aa69b9b40721482ca5b6fb8cfa623231e. This should show up in JAX nightly once JAX updates its XLA commit.

PhilipVinc commented 1 week ago

Thanks @penpornk! Though 17 was just the value from the minimal reproducer; for our application we realistically need all values up to 25. Since you already added 25, would it be reasonable to add all values up to 25?

At first this was confusing, because we were seeing spurious failures only for some input sizes.

penpornk commented 1 week ago

As you already added 25 would it be reasonable to add all values until 25?

@PhilipVinc Sounds good. I've added support for up to 25 inputs in https://github.com/openxla/xla/commit/82deceb3ae7334357861146c0237bf8c51f244e3

PhilipVinc commented 1 week ago

Thank you!