Open PhilipVinc opened 1 week ago
Try XLA_FLAGS=--xla_cpu_use_thunk_runtime=false as a workaround.
@ezhulenev @penpornk
I think I forgot to add it to the release notes for 0.4.32, but there was a major upgrade to the CPU backend. The flag above temporarily switches back to the old version.
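For anyone who would rather set the flag from inside a script than in the shell, here is a minimal sketch. The helper name `add_xla_flag` is made up for illustration; the only thing taken from this thread is the flag itself. It assumes the variable is set before `jax` is first imported, so that XLA sees it.

```python
import os


def add_xla_flag(flag: str) -> str:
    """Append `flag` to XLA_FLAGS unless it is already there; return the new value."""
    # Hypothetical helper, not part of JAX or XLA.
    current = os.environ.get("XLA_FLAGS", "").split()
    if flag not in current:
        current.append(flag)
    os.environ["XLA_FLAGS"] = " ".join(current)
    return os.environ["XLA_FLAGS"]


# Set the workaround flag first ...
add_xla_flag("--xla_cpu_use_thunk_runtime=false")

# ... and only import jax afterwards, e.g.:
# import jax.numpy as jnp
```

Appending rather than overwriting keeps any other XLA flags that are already set in the environment.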
I have a more reduced MWE:
import jax.numpy as jnp
jnp.lexsort(jnp.zeros((16,2)))
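To make the reproducer easier to read: a `(16, 2)` array passed to `lexsort` is treated as 16 keys of length 2, so the number of sort inputs grows with the first dimension. The sketch below uses NumPy, on the assumption that `jnp.lexsort` follows `numpy.lexsort` semantics, and also shows a possible alternative that sorts one key at a time with a stable argsort. The helper name `lexsort_one_key_at_a_time` is hypothetical, it only handles 1-D keys, and whether the equivalent JAX code avoids the failure has not been tested here.

```python
import numpy as np

keys = np.zeros((16, 2))
# Each row is one sort key: 16 keys, each of length 2.
print(len(keys), keys.shape[-1])


def lexsort_one_key_at_a_time(keys):
    """Reproduce np.lexsort for a stack of 1-D keys with chained stable argsorts."""
    keys = np.asarray(keys)
    order = np.arange(keys.shape[-1])
    # The first key is the least significant one and the last key is the primary one,
    # so sort by each key in turn; stability preserves the earlier orderings on ties.
    for key in keys:
        order = order[np.argsort(key[order], kind="stable")]
    return order


rng = np.random.default_rng(0)
random_keys = rng.integers(0, 3, size=(16, 50))
print(np.array_equal(lexsort_one_key_at_a_time(random_keys), np.lexsort(random_keys)))
```

Each step here sorts a single key, at the cost of one sort per key instead of one multi-key sort.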
Thank you for reporting the issue! Yes, the SortThunk needs a fallback kernel. We have filed a bug tracking this work.
In the meantime, I've added the specialization for 17 inputs in https://github.com/openxla/xla/commit/f237cc3aa69b9b40721482ca5b6fb8cfa623231e. This should be reflected in the JAX nightly once JAX updates its XLA commit.
Thanks @penpornk! Though technically 17 was just the MWE for a reproducer. For our application we reasonably need all values up to 25. As you already added 25, would it be reasonable to add all values up to 25?
At first this was confusing because we were seeing spurious failures only for some input sizes.
> As you already added 25, would it be reasonable to add all values up to 25?
@PhilipVinc Sounds good. I've added support for up to 25 inputs in https://github.com/openxla/xla/commit/82deceb3ae7334357861146c0237bf8c51f244e3
Thank you!
Description
The following (very confusing) bug appeared in jax/lib 0.4.33 (jax/lib 0.4.31 works fine)
It appears for specific shapes.
I'm running it on a Mac, but I originally saw it on Linux for shape (576, 16) as well.
System info (python version, jaxlib version, accelerator, etc.)
EDIT: It seems that sort_thunk only supports some hardcoded numbers of inputs:
https://github.com/openxla/xla/blame/edd8e7f610ba15f4b0b3ae87c8e7b7d8f5c3dc9f/xla/backends/cpu/runtime/sort_thunk.cc#L478
But I can't find the commit at which the implementation was switched to this.
Shouldn't there be some sort of fallback?
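To illustrate what "hardcoded sizes plus a fallback" could look like, here is a small Python sketch of the dispatch pattern. It is not XLA's code (which is C++), and every name in it is hypothetical: specialized routines are registered for a few input counts, and any other count goes to a slower generic routine instead of failing.

```python
from typing import Callable, Dict, List, Sequence

Columns = Sequence[Sequence[float]]
SortFn = Callable[[Columns], List[int]]


def _sort_one(columns: Columns) -> List[int]:
    # Specialized path for exactly one input.
    (col,) = columns
    return sorted(range(len(col)), key=col.__getitem__)


def _sort_two(columns: Columns) -> List[int]:
    # Specialized path for exactly two inputs.
    a, b = columns
    return sorted(range(len(a)), key=lambda i: (a[i], b[i]))


# Only these input counts have a dedicated implementation.
SPECIALIZED: Dict[int, SortFn] = {1: _sort_one, 2: _sort_two}


def _sort_generic(columns: Columns) -> List[int]:
    # Generic path: works for any number of inputs, builds a tuple per element.
    n = len(columns[0])
    return sorted(range(n), key=lambda i: tuple(c[i] for c in columns))


def sort_permutation(columns: Columns, allow_fallback: bool = True) -> List[int]:
    """Return the sorting permutation; the first column is the primary key."""
    fn = SPECIALIZED.get(len(columns))
    if fn is None:
        if not allow_fallback:
            raise NotImplementedError(f"unsupported number of inputs: {len(columns)}")
        fn = _sort_generic
    return fn(columns)


print(sort_permutation([[2, 1, 2], [0, 5, -1], [9, 9, 9]]))
```

With `allow_fallback=False` the sketch mimics the behaviour reported here, where an unlisted input count is an error. With the default it degrades to the generic path.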