jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

[Pallas TPU] Error when negating a boolean value #24243

Open ayaka14732 opened 1 week ago

ayaka14732 commented 1 week ago

Description

import functools
import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp

@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct((2,), jnp.bool_),
)
def kernel(x_ref, o_ref):
    # Element-wise boolean negation; on TPU, Mosaic fails to compile this kernel.
    o_ref[...] = jnp.logical_not(x_ref[...])

def main() -> None:
    x = jnp.array([False, True], dtype=jnp.bool_)
    out = kernel(x)
    print(out)

if __name__ == '__main__':
    main()

Error:

Traceback (most recent call last):
  File "/home/ayx/jax/2.py", line 19, in <module>
    main()
    ~~~~^^
  File "/home/ayx/jax/2.py", line 15, in main
    out = kernel(x)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Mosaic failed to compile TPU kernel: Can't change bitwidth during a relayout

at location: loc(unknown)

Please report a bug at: https://github.com/google/jax/issues/new?assignees=apaszke

--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.35.dev20241010+3bd8ca480
jaxlib: 0.4.34
numpy:  2.1.2
python: 3.13.0 (main, Oct  8 2024, 01:04:00) [Clang 18.1.8 ]
device info: TPU v5 lite-8, 8 local devices"
process_count: 1
platform: uname_result(system='Linux', node='t1v-n-ab2ce832-w-0', release='5.19.0-1027-gcp', version='#29~22.04.1-Ubuntu SMP Thu Jun 22 05:13:17 UTC 2023', machine='x86_64')

justinjfu commented 4 days ago

This is caused by Mosaic not supporting i1 splats (i1 is the 1-bit integer type MLIR uses for booleans). Logical not is lowered to xor(ones, x), but Mosaic fails to materialize the all-ones constant correctly when the element type is boolean.

The fix belongs in Mosaic, not in Pallas.
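The identity behind the lowering described above can be checked on the host with plain jax.numpy. This is a minimal sketch, not the actual Mosaic lowering code: it only shows that logical_not(x) equals xor(ones, x) when ones is an all-True boolean array. It runs on any backend (CPU included) and never touches Mosaic, so it says nothing about whether the TPU kernel compiles.

```python
import jax.numpy as jnp

# Boolean input, matching the reproducer in the issue.
x = jnp.array([False, True], dtype=jnp.bool_)

# All-True boolean constant: the "splat of ones" the lowering needs.
# Per the comment above, materializing this i1 constant is where Mosaic fails.
ones = jnp.ones_like(x)  # dtype bool, every element True

# logical_not expressed as xor(ones, x), mirroring the described lowering.
via_xor = jnp.logical_xor(ones, x)
direct = jnp.logical_not(x)

print(via_xor)                                  # [ True False]
print(bool(jnp.array_equal(via_xor, direct)))   # True
```

Because xor with True flips each bit and xor with False leaves it unchanged, the rewrite is exact for booleans; the only extra work it introduces is building the constant, which is the part that breaks on TPU.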