AllanYangZhou opened this issue 5 months ago
I also tried a simpler implementation that uses a plain Python loop instead of `scan`, which produces a different error:
```python
import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp


def cumsum_kernel(xref_Cxd, carryref_d, oref_Cxd):
    carry_d = carryref_d[...]
    x_Cxd = xref_Cxd[...]
    C, d = x_Cxd.shape
    for i in range(C):
        carry_d += x_Cxd[i]
        oref_Cxd[i, :] = carry_d
    carryref_d[...] = carry_d


@jax.jit
def cumsum(x_Txd):
    T, d = x_Txd.shape
    C = 128
    carry_d = jnp.zeros((d,), dtype=x_Txd.dtype)
    return pl.pallas_call(
        cumsum_kernel,
        grid=T // C,
        in_specs=[
            pl.BlockSpec(lambda i: (i, 0), (C, d)),
            pl.BlockSpec(lambda i: 0, (d,)),
        ],
        out_specs=pl.BlockSpec(lambda i: (i, 0), (C, d)),
        out_shape=jax.ShapeDtypeStruct(x_Txd.shape, x_Txd.dtype),
        interpret=False,
    )(x_Txd, carry_d)


key = jax.random.PRNGKey(0)
x_Txd = jax.random.normal(key, (1280, 64))
cs_Txd = cumsum(x_Txd)
realcs_Txd = jnp.cumsum(x_Txd, axis=0)
print(f"Error is {jnp.max(jnp.abs(cs_Txd - realcs_Txd))}")
```
This produces a different error:
```
Traceback (most recent call last):
  File "/home/allanzhou/pallas/simple_pallas_cumsum.py", line 35, in <module>
    cs_Txd = cumsum(x_Txd)
  File "/home/allanzhou/pallas/simple_pallas_cumsum.py", line 21, in cumsum
    return pl.pallas_call(
  File "/home/allanzhou/mid/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py", line 456, in wrapped
    out_flat = pallas_call_p.bind(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: RuntimeError: Internal TPU kernel compiler error: Only zero-offset slices supported.

The MLIR operation involved:
  %137 = "vector.extract_strided_slice"(%132) <{offsets = [1, 0], sizes = [1, 64], strides = [1, 1]}> : (vector<128x64xf32>) -> vector<1x64xf32>

Please report a bug at: https://github.com/google/jax/issues/new?assignees=apaszke

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/allanzhou/pallas/simple_pallas_cumsum.py", line 35, in <module>
    cs_Txd = cumsum(x_Txd)
  File "/home/allanzhou/mid/lib/python3.10/site-packages/jax/_src/pallas/mosaic/pallas_call_registration.py", line 95, in pallas_call_tpu_lowering_rule
    return mlir.lower_fun(_lower_fun, multiple_results=True)(
  File "/home/allanzhou/mid/lib/python3.10/site-packages/jax/_src/pallas/mosaic/pallas_call_registration.py", line 82, in _lower_fun
    return mosaic.as_tpu_kernel(
  File "/home/allanzhou/mid/lib/python3.10/site-packages/jax/_src/tpu_custom_call.py", line 406, in as_tpu_kernel
    lowered_module_asm, constants = _lower_tpu_kernel(
  File "/home/allanzhou/mid/lib/python3.10/site-packages/jax/_src/tpu_custom_call.py", line 333, in _lower_tpu_kernel
    _run_pass_pipeline(pipeline, module, "post-infer-vector-layout")
  File "/home/allanzhou/mid/lib/python3.10/site-packages/jax/_src/tpu_custom_call.py", line 260, in _run_pass_pipeline
    raise RuntimeError("\n".join(msg)) from None
RuntimeError: Internal TPU kernel compiler error: Only zero-offset slices supported.

The MLIR operation involved:
  %137 = "vector.extract_strided_slice"(%132) <{offsets = [1, 0], sizes = [1, 64], strides = [1, 1]}> : (vector<128x64xf32>) -> vector<1x64xf32>

Please report a bug at: https://github.com/google/jax/issues/new?assignees=apaszke
```
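The offending `vector.extract_strided_slice` corresponds to `x_Cxd[i]` in the Python loop, which slices the materialized `vector<128x64xf32>` at a nonzero row offset. As a possible workaround (a sketch only, not verified on a TPU), the loop over rows can be replaced with a fully vectorized kernel body that never slices individual rows:

```python
import jax.numpy as jnp


def cumsum_kernel_vectorized(xref_Cxd, carryref_d, oref_Cxd):
    # Same signature as cumsum_kernel above, but with no Python loop:
    # the whole block is cumsummed at once and the carry is broadcast in.
    carry_d = carryref_d[...]
    x_Cxd = xref_Cxd[...]
    oref_Cxd[...] = jnp.cumsum(x_Cxd, axis=0) + carry_d[None, :]
    # Update the carry with a reduction instead of slicing out the last
    # row of the block cumsum, so no nonzero-offset slice is generated.
    carryref_d[...] = carry_d + jnp.sum(x_Cxd, axis=0)
```

This should drop into the same `pallas_call` unchanged; whether Mosaic accepts the resulting lowering is an assumption on my part.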
Description
I tried to write a simple TPU Pallas kernel implementing cumsum, where the input is chunked along the summed dimension and the kernel computes one chunk at a time. I currently get the error below. It is triggered because the `num_extensive` variable evaluates to True, though I don't understand what "extensive" means here. The code to reproduce is below. Note that if I set `interpret=True`, the code runs without error.

System info (python version, jaxlib version, accelerator, etc.)

I am using a Cloud TPU v3-8 with JAX version 0.4.23.
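For reference, the chunked recurrence the kernel is meant to implement can be written in plain JAX with `lax.scan` (no Pallas involved); this is only a sanity check on the math, assuming `T` is a multiple of the chunk size `C`:

```python
import jax
import jax.numpy as jnp


def chunked_cumsum_reference(x_Txd, C=128):
    # Each chunk's local cumsum is offset by the running carry, and the
    # carry advances by the chunk's column sums -- the same recurrence
    # the Pallas kernel performs one grid step at a time.
    T, d = x_Txd.shape
    x_GxCxd = x_Txd.reshape(T // C, C, d)

    def step(carry_d, chunk_Cxd):
        out_Cxd = jnp.cumsum(chunk_Cxd, axis=0) + carry_d
        return carry_d + chunk_Cxd.sum(axis=0), out_Cxd

    _, out_GxCxd = jax.lax.scan(step, jnp.zeros((d,), x_Txd.dtype), x_GxCxd)
    return out_GxCxd.reshape(T, d)
```

This matches `jnp.cumsum(x, axis=0)` up to float32 reassociation error.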