jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

[Pallas TPU] Unable to create a boolean array inside kernel due to Mosaic relayout error #24464

Open ayaka14732 opened 1 month ago

ayaka14732 commented 1 month ago

Description

This kernel can run successfully:

import functools
import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp

@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct((4,), jnp.bool_),
    debug=True,
)
def kernel(x_ref, y_ref, o_ref):
    x = x_ref[...]
    y = y_ref[...]
    o_ref[...] = jnp.logical_xor(x, y)

def main():
    x = jnp.array([True, False, False, True], dtype=jnp.bool_)
    y = jnp.array([True, False, True, False], dtype=jnp.bool_)
    out = kernel(x, y)
    print(out)

if __name__ == '__main__':
    main()

However, if the kernel is changed to:

def kernel(x_ref, y_ref, o_ref):
    x = x_ref[...]
    z = jnp.full((4,), True)  # creating a boolean array inside the kernel triggers the failure
    o_ref[...] = jnp.logical_xor(x, z)

the code fails with a Mosaic compilation error:

The kernel jaxpr for pallas_call kernel at /home/ayx/jax/2.py:6:
{ lambda ; a:MemRef<None>{bool[4]} b:MemRef<None>{bool[4]} c:MemRef<None>{bool[4]}. let
    d:bool[4] <- a[:]
    e:bool[4] = broadcast_in_dim[broadcast_dimensions=() shape=(4,)] True
    f:bool[4] = xor d e
    c[:] <- f
  in () }

The Mosaic module for pallas_call kernel at /home/ayx/jax/2.py:6:
module @kernel {
  func.func @main(%arg0: memref<4xi32, #tpu.memory_space<vmem>>, %arg1: memref<4xi32, #tpu.memory_space<vmem>>, %arg2: memref<4xi32, #tpu.memory_space<vmem>>) attributes {dimension_semantics = [], scalar_prefetch = 0 : i64, scratch_operands = 0 : i64} {
    %c0 = arith.constant 0 : index
    %0 = vector.load %arg0[%c0] : memref<4xi32, #tpu.memory_space<vmem>>, vector<4xi32>
    %cst = arith.constant dense<0> : vector<4xi32>
    %1 = arith.cmpi ne, %0, %cst : vector<4xi32>
    %true = arith.constant true
    %c1_i32 = arith.constant 1 : i32
    %c0_i32 = arith.constant 0 : i32
    %2 = arith.select %true, %c1_i32, %c0_i32 : i32
    %3 = vector.broadcast %2 : i32 to vector<4xi32>
    %c1_i32_0 = arith.constant 1 : i32
    %4 = vector.broadcast %c1_i32_0 : i32 to vector<4xi32>
    %5 = arith.cmpi eq, %3, %4 : vector<4xi32>
    %6 = arith.xori %1, %5 : vector<4xi1>
    %c0_1 = arith.constant 0 : index
    %7 = vector.load %arg2[%c0_1] : memref<4xi32, #tpu.memory_space<vmem>>, vector<4xi32>
    %8 = arith.extui %6 : vector<4xi1> to vector<4xi32>
    %cst_2 = arith.constant dense<0> : vector<4xi32>
    %9 = arith.cmpi ne, %7, %cst_2 : vector<4xi32>
    vector.store %8, %arg2[%c0_1] : memref<4xi32, #tpu.memory_space<vmem>>, vector<4xi32>
    return
  }
}

Traceback (most recent call last):
  File "/home/ayx/jax/2.py", line 23, in <module>
    main()
    ~~~~^^
  File "/home/ayx/jax/2.py", line 19, in main
    out = kernel(x, y)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Mosaic failed to compile TPU kernel: Can't change bitwidth during a relayout

at location: loc(unknown)

Please report a bug at: https://github.com/google/jax/issues/new?assignees=apaszke

--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
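As a point of reference (not from the issue itself, and unverified on TPU), one way to sidestep creating a boolean array inside the kernel is to keep the constant in int32, do the XOR on integers, and only produce the boolean mask when storing the result. The kernel below is only a sketch reusing the shapes from the repro:

def kernel(x_ref, y_ref, o_ref):
    # Sketch of a possible workaround (unverified): avoid materializing a
    # bool constant inside the kernel by working in int32 and comparing
    # against zero only at the final store.
    x = x_ref[...].astype(jnp.int32)        # bool input as 0/1 int32
    z = jnp.full((4,), 1, dtype=jnp.int32)  # int32 stand-in for jnp.full((4,), True)
    o_ref[...] = (x ^ z) != 0               # boolean result produced only at the store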

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.35.dev20241022+587832f29
jaxlib: 0.4.34
numpy:  2.1.2
python: 3.13.0 (main, Oct  8 2024, 01:04:00) [Clang 18.1.8 ]
device info: TPU v5 lite-8, 8 local devices
process_count: 1
platform: uname_result(system='Linux', node='t1v-n-ab2ce832-w-0', release='5.19.0-1027-gcp', version='#29~22.04.1-Ubuntu SMP Thu Jun 22 05:13:17 UTC 2023', machine='x86_64')
ayaka14732 commented 4 weeks ago

Reassigned internally.