import functools
import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp

@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct((2,), jnp.bool_),
)
def kernel(x_ref, o_ref):
    o_ref[...] = jnp.logical_not(x_ref[...])

def main() -> None:
    x = jnp.array([False, True], dtype=jnp.bool_)
    out = kernel(x)
    print(out)

if __name__ == '__main__':
    main()
Error:
Traceback (most recent call last):
  File "/home/ayx/jax/2.py", line 19, in <module>
    main()
    ~~~~^^
  File "/home/ayx/jax/2.py", line 15, in main
    out = kernel(x)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Mosaic failed to compile TPU kernel: Can't change bitwidth during a relayout
at location: loc(unknown)
Please report a bug at: https://github.com/google/jax/issues/new?assignees=apaszke
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.35.dev20241010+3bd8ca480
jaxlib: 0.4.34
numpy: 2.1.2
python: 3.13.0 (main, Oct 8 2024, 01:04:00) [Clang 18.1.8 ]
device info: TPU v5 lite-8, 8 local devices
process_count: 1
platform: uname_result(system='Linux', node='t1v-n-ab2ce832-w-0', release='5.19.0-1027-gcp', version='#29~22.04.1-Ubuntu SMP Thu Jun 22 05:13:17 UTC 2023', machine='x86_64')
This is caused by Mosaic not supporting splats of i1 (boolean) values. logical_not is lowered to xor(ones, x), but Mosaic fails to materialize the array of ones correctly when the element type is boolean.
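A minimal sketch of the lowering identity and of a possible workaround follows. It assumes the failure is specific to boolean (i1) values inside the kernel, so it converts the input to int32 outside the kernel and flips the 0/1 encoding with 1 - x, which avoids building a boolean array of ones. The names _not_kernel_i32 and logical_not_via_i32 are hypothetical, interpret=True is used only so the sketch also runs without a TPU, and the workaround has not been verified against Mosaic on TPU hardware.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

# The lowering described above: for booleans, logical_not(x) == xor(ones, x).
x = jnp.array([False, True], dtype=jnp.bool_)
assert bool(jnp.all(jnp.logical_xor(jnp.ones_like(x), x) == jnp.logical_not(x)))


def _not_kernel_i32(x_ref, o_ref):
    # Hypothetical kernel: values are 0/1 int32, so 1 - x flips them
    # without materializing an i1 (boolean) constant inside the kernel.
    o_ref[...] = 1 - x_ref[...]


def logical_not_via_i32(x, *, interpret=True):
    # Hypothetical wrapper: bool -> int32 before the kernel, int32 -> bool after.
    # interpret=True runs the kernel in Pallas interpret mode (no TPU needed);
    # pass interpret=False to compile it with Mosaic on a TPU.
    call = pl.pallas_call(
        _not_kernel_i32,
        out_shape=jax.ShapeDtypeStruct(x.shape, jnp.int32),
        interpret=interpret,
    )
    return call(x.astype(jnp.int32)).astype(jnp.bool_)


print(logical_not_via_i32(x))
```

Keeping the bool/int32 conversions outside the kernel means Mosaic only ever sees 32-bit integers, which sidesteps the i1 splat rather than fixing it.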