jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Pallas+scan: NotImplementedError when num_extensive is True #21190

Open AllanYangZhou opened 5 months ago

AllanYangZhou commented 5 months ago

Description

I tried to write a simple TPU Pallas kernel that implements cumsum. The input is chunked along the summed dimension, and each kernel invocation processes one chunk. I get the error below. It is triggered because the num_extensive variable is nonzero (truthy), though I don't understand what "extensive" means here.
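In jax.lax.scan terminology, the operands that are sliced along the leading axis (the xs) and the results that are stacked along a new leading axis (the ys) are called "extensive". Consts and the carry are not. In the failing jaxpr, num_consts=0 and num_carry=1, so the second operand (k:f32[128,256], i.e. x_Cxd) is extensive. The Mosaic rule in the traceback appears to compute num_extensive as the number of operands beyond the consts and carry, and to raise NotImplementedError whenever it is nonzero. The following is a minimal CPU-only illustration of carry versus extensive values. It uses only jax.lax.scan, and step is an illustrative name:

```python
import jax
import jax.numpy as jnp


def step(carry, x_row):
    # carry: the running total, threaded from one iteration to the next.
    # x_row: one row of xs; xs is an "extensive" input, sliced along axis 0.
    carry = carry + x_row
    # The second return value is stacked along a new leading axis into ys,
    # so ys is an "extensive" output of shape (length, 3).
    return carry, carry


xs = jnp.arange(12.0).reshape(4, 3)
final, ys = jax.lax.scan(step, jnp.zeros(3), xs)
```

Here final holds the column sums [18., 22., 26.] and ys equals jnp.cumsum(xs, axis=0). By contrast, jax.lax.fori_loop with static bounds is built on a scan that has no xs, so it has no extensive operands. That looks like the "fori_loop-like" case the Mosaic lowering handles.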

Traceback (most recent call last):
  File "/home/allanzhou/mid/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py", line 549, in jaxpr_subcomp
    ans = lowering_rules[eqn.primitive](
  File "/home/allanzhou/mid/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py", line 1516, in _scan_lowering_rule
    if num_extensive: raise NotImplementedError
NotImplementedError

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/allanzhou/pallas/simple_pallas_scan.py", line 41, in <module>
    carry_d, cs_Txd = cumsum(x_Txd)
  File "/home/allanzhou/pallas/simple_pallas_scan.py", line 24, in cumsum
    return pl.pallas_call(
  File "/home/allanzhou/mid/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py", line 452, in wrapped
    grid_mapping, jaxpr, consts, _ = _trace_to_jaxpr(
  File "/home/allanzhou/mid/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py", line 385, in _trace_to_jaxpr
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, jaxpr_flat_avals,
  File "/home/allanzhou/pallas/simple_pallas_scan.py", line 15, in cumsum_kernel
    carry_d, cs_Cxd = jax.lax.scan(inner, carry_d, x_Cxd, length=x_Cxd.shape[0])
jax._src.source_info_util.JaxStackTraceBeforeTransformation: jax._src.pallas.mosaic.lowering.LoweringException: Exception while lowering eqn:
  a:f32[256] b:f32[128,256] = scan[
  jaxpr={ lambda ; c:f32[256] d:f32[256]. let e:f32[256] = add c d in (e, e) }
  length=128
  linear=(False, False)
  num_carry=1
  num_consts=0
  reverse=False
  unroll=1
] f g
With context:
  LoweringRuleContext(lowering_context=LoweringContext(ir_context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7f971e902f70>, grid_indices=(<jaxlib.mlir._mlir_libs._mlir.ir.BlockArgument object at 0x7f971e8e1b70>,), block_shapes=[(128, 256), (256,), (128, 256)], name_stack=NameStack(stack=()), mesh_context=None), avals_in=[ShapedArray(float32[256]), ShapedArray(float32[128,256])], avals_out=[ShapedArray(float32[256]), ShapedArray(float32[128,256])], block_shapes=[None, None])
With inval shapes=[None, None]
With inval types=[VectorType(vector<256xf32>), VectorType(vector<128x256xf32>)]
In jaxpr:
{ lambda ; a:Ref{float32[128,256]} b:Ref{float32[256]} c:Ref{float32[128,256]}. let
    d:i32[] = program_id[axis=0] 
    e:bool[] = eq d 0
    f:i32[] = convert_element_type[new_dtype=int32 weak_type=False] e
    cond[
      branches=(
        { lambda ; g_:Ref{float32[256]}. let  in () }
        { lambda ; h:Ref{float32[256]}. let
            i:f32[256] = broadcast_in_dim[broadcast_dimensions=() shape=(256,)] 0.0
            h[:] <- i
          in () }
      )
      linear=(False,)
    ] f b
    j:f32[256] <- b[:]
    k:f32[128,256] <- a[:,:]
    l:f32[256] m:f32[128,256] = scan[
      jaxpr={ lambda ; n:f32[256] o:f32[256]. let
          p:f32[256] = add n o
        in (p, p) }
      length=128
      linear=(False, False)
      num_carry=1
      num_consts=0
      reverse=False
      unroll=1
    ] j k
    b[:] <- l
    c[:,:] <- m
  in () }

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/allanzhou/pallas/simple_pallas_scan.py", line 41, in <module>
    carry_d, cs_Txd = cumsum(x_Txd)
  File "/home/allanzhou/mid/lib/python3.10/site-packages/jax/_src/pallas/mosaic/pallas_call_registration.py", line 75, in pallas_call_tpu_lowering_rule
    mosaic_module, extra_args = lowering.lower_jaxpr_to_module(
  File "/home/allanzhou/mid/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py", line 340, in lower_jaxpr_to_module
    func_op = lower_jaxpr_to_func(ctx, jaxpr, mosaic_grid_mapping=mosaic_grid_mapping,
  File "/home/allanzhou/mid/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py", line 478, in lower_jaxpr_to_func
    body = func.FuncOp.from_py_func(*arg_types, name=name)(body_func)
  File "/home/allanzhou/mid/lib/python3.10/site-packages/jaxlib/mlir/dialects/func.py", line 195, in decorator
    return_values = f(*func_args, **func_kwargs)
  File "/home/allanzhou/mid/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py", line 474, in body_func
    return jaxpr_subcomp(
  File "/home/allanzhou/mid/lib/python3.10/site-packages/jax/_src/pallas/mosaic/lowering.py", line 555, in jaxpr_subcomp
    raise LoweringException(
jax._src.pallas.mosaic.lowering.LoweringException: Exception while lowering eqn:
  a:f32[256] b:f32[128,256] = scan[
  jaxpr={ lambda ; c:f32[256] d:f32[256]. let e:f32[256] = add c d in (e, e) }
  length=128
  linear=(False, False)
  num_carry=1
  num_consts=0
  reverse=False
  unroll=1
] f g
With context:
  LoweringRuleContext(lowering_context=LoweringContext(ir_context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7f971e902f70>, grid_indices=(<jaxlib.mlir._mlir_libs._mlir.ir.BlockArgument object at 0x7f971e8e1b70>,), block_shapes=[(128, 256), (256,), (128, 256)], name_stack=NameStack(stack=()), mesh_context=None), avals_in=[ShapedArray(float32[256]), ShapedArray(float32[128,256])], avals_out=[ShapedArray(float32[256]), ShapedArray(float32[128,256])], block_shapes=[None, None])
With inval shapes=[None, None]
With inval types=[VectorType(vector<256xf32>), VectorType(vector<128x256xf32>)]
In jaxpr:
{ lambda ; a:Ref{float32[128,256]} b:Ref{float32[256]} c:Ref{float32[128,256]}. let
    d:i32[] = program_id[axis=0] 
    e:bool[] = eq d 0
    f:i32[] = convert_element_type[new_dtype=int32 weak_type=False] e
    cond[
      branches=(
        { lambda ; g_:Ref{float32[256]}. let  in () }
        { lambda ; h:Ref{float32[256]}. let
            i:f32[256] = broadcast_in_dim[broadcast_dimensions=() shape=(256,)] 0.0
            h[:] <- i
          in () }
      )
      linear=(False,)
    ] f b
    j:f32[256] <- b[:]
    k:f32[128,256] <- a[:,:]
    l:f32[256] m:f32[128,256] = scan[
      jaxpr={ lambda ; n:f32[256] o:f32[256]. let
          p:f32[256] = add n o
        in (p, p) }
      length=128
      linear=(False, False)
      num_carry=1
      num_consts=0
      reverse=False
      unroll=1
    ] j k
    b[:] <- l
    c[:,:] <- m
  in () }

The code to reproduce is below. With interpret=True, the same code runs without error.

import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp

def cumsum_kernel(xref_Cxd, carryref_d, oref_Cxd):
    # Zero the running total on the first grid step only.
    @pl.when(pl.program_id(axis=0) == 0)
    def _():
        carryref_d[...] = jnp.zeros_like(carryref_d)
    carry_d = carryref_d[...]  # running total from previous chunks
    x_Cxd = xref_Cxd[...]  # current (C, d) chunk

    def inner(_carry_d, x_d):
        _carry_d = _carry_d + x_d
        return _carry_d, _carry_d

    # Scan over the C rows of the chunk; x_Cxd is the scanned-over (xs) operand.
    carry_d, cs_Cxd = jax.lax.scan(inner, carry_d, x_Cxd, length=x_Cxd.shape[0])
    carryref_d[...] = carry_d
    oref_Cxd[...] = cs_Cxd

@jax.jit
def cumsum(x_Txd):
    T, d = x_Txd.shape
    C = 128  # chunk size along the summed axis
    return pl.pallas_call(
        cumsum_kernel,
        grid=T // C,
        in_specs=[pl.BlockSpec(lambda i: (i, 0), (C, d))],
        out_specs=[
            pl.BlockSpec(lambda i: (0,), (d,)),  # carry: same block on every grid step
            pl.BlockSpec(lambda i: (i, 0), (C, d))  # out
        ],
        out_shape=[
            jax.ShapeDtypeStruct((d,), x_Txd.dtype),
            jax.ShapeDtypeStruct(x_Txd.shape, x_Txd.dtype),
        ],
        interpret=False
    )(x_Txd)

key = jax.random.PRNGKey(0)
x_Txd = jax.random.normal(key, (1280, 256))
carry_d, cs_Txd = cumsum(x_Txd)
realcs_Txd = jnp.cumsum(x_Txd, axis=0)
print(f"Error is {jnp.max(jnp.abs(cs_Txd - realcs_Txd))}")
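The Mosaic rule rejects scans that have extensive operands. One way around that is to avoid scan inside the kernel and compute the within-chunk prefix sum as a matrix product with a lower-triangular matrix of ones. This is a swapped-in technique, not the approach from the issue. The sketch has only been checked in interpret mode. Whether the iota/compare/dot sequence and an f32 matmul at Precision.HIGHEST compile on a v3 TPU with JAX 0.4.23 is an assumption.

The names CHUNK, cumsum_matmul_kernel, and chunked_cumsum are hypothetical. BlockSpec is called with keyword arguments so that the positional order does not matter. The carry is kept as a (1, d) block rather than (d,), on the assumption that 2-D blocks are better supported on TPU. The carry is also updated with a reduction rather than by slicing the last row, because the second error in this thread comes from a nonzero-offset slice.

```python
import functools

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

CHUNK = 128  # rows per grid step (C in the issue)


def cumsum_matmul_kernel(x_ref, o_ref, carry_ref):
    # carry_ref is a (1, d) output block whose index_map always returns (0, 0),
    # so every grid step reads and writes the same block (the running total).
    @pl.when(pl.program_id(0) == 0)
    def _():
        carry_ref[...] = jnp.zeros(carry_ref.shape, carry_ref.dtype)

    x = x_ref[...]                       # (C, d) chunk
    c = x.shape[0]
    row = jax.lax.broadcasted_iota(jnp.int32, (c, c), 0)
    col = jax.lax.broadcasted_iota(jnp.int32, (c, c), 1)
    tril = (col <= row).astype(x.dtype)  # tril[i, j] = 1 if j <= i else 0
    carry = carry_ref[...]               # (1, d), broadcasts over the rows
    # Within-chunk inclusive prefix sum as a matmul, plus earlier chunks' total.
    o_ref[...] = jnp.dot(tril, x, precision=jax.lax.Precision.HIGHEST) + carry
    # Update the running total with a reduction instead of slicing the last row.
    carry_ref[...] = carry + jnp.sum(x, axis=0, keepdims=True)


@functools.partial(jax.jit, static_argnames="interpret")
def chunked_cumsum(x, interpret=True):
    t, d = x.shape
    assert t % CHUNK == 0, "sketch assumes T is a multiple of CHUNK"
    out, _ = pl.pallas_call(
        cumsum_matmul_kernel,
        grid=(t // CHUNK,),
        in_specs=[pl.BlockSpec(block_shape=(CHUNK, d), index_map=lambda i: (i, 0))],
        out_specs=[
            pl.BlockSpec(block_shape=(CHUNK, d), index_map=lambda i: (i, 0)),
            pl.BlockSpec(block_shape=(1, d), index_map=lambda i: (0, 0)),
        ],
        out_shape=[
            jax.ShapeDtypeStruct(x.shape, x.dtype),
            jax.ShapeDtypeStruct((1, d), x.dtype),
        ],
        interpret=interpret,
    )(x)
    return out
```

On hardware you would call chunked_cumsum(x, interpret=False) and compare against jnp.cumsum(x, axis=0), as the repro does.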

System info (python version, jaxlib version, accelerator, etc.)

I am using a Cloud TPU v3-8 with JAX version 0.4.23.

AllanYangZhou commented 5 months ago

I also tried a simpler implementation that replaces scan with a plain Python loop:

import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp

def cumsum_kernel(xref_Cxd, carryref_d, oref_Cxd):
    carry_d = carryref_d[...]  # running total (passed in as an input here)
    x_Cxd = xref_Cxd[...]  # load the whole (C, d) chunk
    C, d = x_Cxd.shape
    # Plain Python loop: fully unrolled at trace time (C iterations).
    for i in range(C):
        carry_d += x_Cxd[i]  # row i of the already-loaded chunk
        oref_Cxd[i, :] = carry_d
    carryref_d[...] = carry_d

@jax.jit
def cumsum(x_Txd):
    T, d = x_Txd.shape
    C = 128  # chunk size along the summed axis
    carry_d = jnp.zeros((d,), dtype=x_Txd.dtype)
    return pl.pallas_call(
        cumsum_kernel,
        grid=T // C,
        in_specs=[
            pl.BlockSpec(lambda i: (i, 0), (C, d)),
            pl.BlockSpec(lambda i: (0,), (d,)),  # carry: same block on every grid step
        ],
        out_specs=pl.BlockSpec(lambda i: (i, 0), (C, d)),
        out_shape=jax.ShapeDtypeStruct(x_Txd.shape, x_Txd.dtype),
        interpret=False,
    )(x_Txd, carry_d)

key = jax.random.PRNGKey(0)
x_Txd = jax.random.normal(key, (1280, 64))
cs_Txd = cumsum(x_Txd)
realcs_Txd = jnp.cumsum(x_Txd, axis=0)
print(f"Error is {jnp.max(jnp.abs(cs_Txd - realcs_Txd))}")

This produces a different error:

Traceback (most recent call last):
  File "/home/allanzhou/pallas/simple_pallas_cumsum.py", line 35, in <module>
    cs_Txd = cumsum(x_Txd)
  File "/home/allanzhou/pallas/simple_pallas_cumsum.py", line 21, in cumsum
    return pl.pallas_call(
  File "/home/allanzhou/mid/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py", line 456, in wrapped
    out_flat = pallas_call_p.bind(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: RuntimeError: Internal TPU kernel compiler error: Only zero-offset slices supported.

The MLIR operation involved:
  %137 = "vector.extract_strided_slice"(%132) <{offsets = [1, 0], sizes = [1, 64], strides = [1, 1]}> : (vector<128x64xf32>) -> vector<1x64xf32>

Please report a bug at: https://github.com/google/jax/issues/new?assignees=apaszke

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/allanzhou/pallas/simple_pallas_cumsum.py", line 35, in <module>
    cs_Txd = cumsum(x_Txd)
  File "/home/allanzhou/mid/lib/python3.10/site-packages/jax/_src/pallas/mosaic/pallas_call_registration.py", line 95, in pallas_call_tpu_lowering_rule
    return mlir.lower_fun(_lower_fun, multiple_results=True)(
  File "/home/allanzhou/mid/lib/python3.10/site-packages/jax/_src/pallas/mosaic/pallas_call_registration.py", line 82, in _lower_fun
    return mosaic.as_tpu_kernel(
  File "/home/allanzhou/mid/lib/python3.10/site-packages/jax/_src/tpu_custom_call.py", line 406, in as_tpu_kernel
    lowered_module_asm, constants = _lower_tpu_kernel(
  File "/home/allanzhou/mid/lib/python3.10/site-packages/jax/_src/tpu_custom_call.py", line 333, in _lower_tpu_kernel
    _run_pass_pipeline(pipeline, module, "post-infer-vector-layout")
  File "/home/allanzhou/mid/lib/python3.10/site-packages/jax/_src/tpu_custom_call.py", line 260, in _run_pass_pipeline
    raise RuntimeError("\n".join(msg)) from None
RuntimeError: Internal TPU kernel compiler error: Only zero-offset slices supported.

The MLIR operation involved:
  %137 = "vector.extract_strided_slice"(%132) <{offsets = [1, 0], sizes = [1, 64], strides = [1, 1]}> : (vector<128x64xf32>) -> vector<1x64xf32>

Please report a bug at: https://github.com/google/jax/issues/new?assignees=apaszke
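The MLIR operation in this error slices a 1x64 row at offsets = [1, 0] out of the loaded vector<128x64xf32>. That matches x_Cxd[i] at i = 1 in the unrolled Python loop, since row 0 would be a zero-offset slice. An alternative worth trying is to index the ref rather than the already-loaded value, inside a jax.lax.fori_loop. With static bounds, fori_loop builds a scan with only a carry and no extensive operands. This has not been verified on a v3 TPU or with JAX 0.4.23. Whether Mosaic accepts the dynamic row offset pl.ds(i, 1) on this hardware is an assumption, and the sketch checks numerics only in interpret mode. The names CHUNK, cumsum_rowloop_kernel, and chunked_cumsum_rows are hypothetical.

```python
import functools

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

CHUNK = 128  # rows per grid step (C in the issue)


def cumsum_rowloop_kernel(x_ref, o_ref, carry_ref):
    # carry_ref: (1, d) output block revisited on every grid step.
    @pl.when(pl.program_id(0) == 0)
    def _():
        carry_ref[...] = jnp.zeros(carry_ref.shape, carry_ref.dtype)

    def body(i, carry):
        # Load row i as a (1, d) block directly from the ref instead of
        # slicing a (C, d) value that has already been loaded.
        row = x_ref[pl.ds(i, 1), :]
        carry = carry + row
        o_ref[pl.ds(i, 1), :] = carry
        return carry

    # fori_loop carries only the running total; there are no xs/ys operands.
    carry_ref[...] = jax.lax.fori_loop(0, x_ref.shape[0], body, carry_ref[...])


@functools.partial(jax.jit, static_argnames="interpret")
def chunked_cumsum_rows(x, interpret=True):
    t, d = x.shape
    assert t % CHUNK == 0, "sketch assumes T is a multiple of CHUNK"
    out, _ = pl.pallas_call(
        cumsum_rowloop_kernel,
        grid=(t // CHUNK,),
        in_specs=[pl.BlockSpec(block_shape=(CHUNK, d), index_map=lambda i: (i, 0))],
        out_specs=[
            pl.BlockSpec(block_shape=(CHUNK, d), index_map=lambda i: (i, 0)),
            pl.BlockSpec(block_shape=(1, d), index_map=lambda i: (0, 0)),
        ],
        out_shape=[
            jax.ShapeDtypeStruct(x.shape, x.dtype),
            jax.ShapeDtypeStruct((1, d), x.dtype),
        ],
        interpret=interpret,
    )(x)
    return out
```

Unlike the Python loop, fori_loop keeps the body rolled, so the kernel does not grow with C. The trade-off is one small row load and store per iteration instead of a single whole-chunk load.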