jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

[Pallas] Internal TPU kernel compiler error: unsupported operand shapes or layouts tpu.matmul #18388

Open demon2036 opened 1 year ago

demon2036 commented 1 year ago

Description

When I use a TPU, I attempt to implement a tiled matrix-multiplication kernel with Pallas, where q is tiled as (block_q, block_d) and k and v are tiled as (block_k, block_d). I use pl.dot(q, k, trans_b=True) to compute q @ k.T; however, I've noticed that if block_k is less than 128, it results in an error. In my case q is a vector<256x128xf32>, k is a vector<16x128xf32>, and the result should be a vector<256x16xf32>.
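For reference, here is a minimal sketch that isolates just the failing dot with the same operand shapes (no grid or block specs; I have not verified this exact snippet, but it exercises the same tpu.matmul as the full repro below):

import jax
import jax.numpy as jnp
import jax.experimental.pallas as pl

def tiny_kernel(q_ref, k_ref, o_ref):
    # q tile: (256, 128), k tile: (16, 128) -> q @ k.T: (256, 16)
    o_ref[...] = pl.dot(q_ref[...], k_ref[...], trans_b=True)

q = jnp.ones((256, 128), jnp.float32)
k = jnp.ones((16, 128), jnp.float32)

out = pl.pallas_call(
    tiny_kernel,
    out_shape=jax.ShapeDtypeStruct((256, 16), jnp.float32),
)(q, k)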

Error Information

Traceback (most recent call last):
  File "/root/test_multihead.py", line 101, in <module>
    res = matrix_mul(x, x, x, )
  File "/root/test_multihead.py", line 73, in matrix_mul
    return pl.pallas_call(
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/pallas/pallas_call.py", line 383, in wrapped
    out_flat = pallas_call_p.bind(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: RuntimeError: Internal TPU kernel compiler error: unsupported operand shapes or layouts

The MLIR operation involved: %26 = "tpu.matmul"(%23, %25, %17) {transpose_lhs = false, transpose_rhs = true} : (vector<256x128xf32>, vector<16x128xf32>, vector<256x16xf32>) -> vector<256x16xf32>

Please report a bug at: https://github.com/google/jax/issues/new?assignees=apaszke

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.


The above exception was the direct cause of the following exception:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/root/test_multihead.py", line 101, in <module>
    res = matrix_mul(x, x, x, )
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/pallas/mosaic/pallas_call_registration.py", line 88, in pallas_call_tpu_lowering_rule
    return mlir.lower_fun(_lower_fun, multiple_results=True)(
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/pallas/mosaic/pallas_call_registration.py", line 79, in _lower_fun
    return mosaic.as_tpu_kernel(
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/tpu_custom_call.py", line 408, in as_tpu_kernel
    lowered_module_asm, constants = _lower_tpu_kernel(
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/tpu_custom_call.py", line 335, in _lower_tpu_kernel
    _run_pass_pipeline(pipeline, module, "infer vector layout")
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/tpu_custom_call.py", line 264, in _run_pass_pipeline
    raise RuntimeError("\n".join(msg)) from None
RuntimeError: Internal TPU kernel compiler error: unsupported operand shapes or layouts

The MLIR operation involved: %26 = "tpu.matmul"(%23, %25, %17) {transpose_lhs = false, transpose_rhs = true} : (vector<256x128xf32>, vector<16x128xf32>, vector<256x16xf32>) -> vector<256x16xf32>

Please report a bug at: https://github.com/google/jax/issues/new?assignees=apaszke

Here is my code

import functools

import jax
import jax.numpy as jnp
import jax.experimental.pallas as pl

@jax.jit
def naive(q, k, v):
    # Reference implementation: per-head (q @ k.T) @ v, softmax omitted.
    temp = jnp.einsum('...hnd,...hmd->...hnm', q, k)
    # temp = flax.linen.softmax(temp)
    return jnp.einsum('...hnm,...hmd->...hnd', temp, v)

def mul_kernel(q_ref, k_ref, v_ref, out_ref, block_q, block_k, block_d):
    seq_length = k_ref.shape[0]
    q = q_ref[...]  # (block_q, block_d) tile of q

    acc = jnp.zeros((block_q, block_d), dtype=q_ref.dtype)

    def loop_body(start_idx, carry):
        acc = carry
        # (block_k, block_d) tile of k; this is the operand that fails when block_k < 128.
        k = pl.load(k_ref, (pl.dslice(start_idx * block_k, block_k), pl.dslice(0, block_d)))
        kv = pl.dot(q, k, trans_a=False, trans_b=True)  # (block_q, block_k)
        v = pl.load(v_ref, (pl.dslice(start_idx * block_k, block_k), pl.dslice(None)))
        kv = kv @ v  # (block_q, block_d)
        acc += kv
        return acc

    upper = seq_length // block_k
    acc = jax.lax.fori_loop(0, upper, loop_body, acc)
    out_ref[...] = acc

@jax.jit
def matrix_mul(q, k, v, block_q=256, block_k=16):
    batch_size, heads, n, d = q.shape

    block_q = min(block_q, n)
    block_k = min(block_k, n)
    kernel = functools.partial(mul_kernel, block_q=block_q, block_k=block_k, block_d=d)
    return pl.pallas_call(
        kernel,
        interpret=False,
        grid=(n // block_q, batch_size, heads),
        out_shape=jax.ShapeDtypeStruct(q.shape, q.dtype),
        in_specs=[
            # q: one (block_q, d) tile per sequence-grid step, per (batch, head).
            pl.BlockSpec(lambda i, b, h: (b, h, i, 0), (None, None, block_q, d)),
            # k and v: the full (n, d) slice for the current (batch, head).
            pl.BlockSpec(lambda i, b, h: (b, h, 0, 0), (None, None, n, d)),
            pl.BlockSpec(lambda i, b, h: (b, h, 0, 0), (None, None, n, d)),
        ],
        out_specs=pl.BlockSpec(lambda i, b, h: (b, h, i, 0), (None, None, block_q, d)),
    )(q, k, v)

# shape = (1, 8, 1024 * 16, 32)
shape = (1, 8, 256, 128)
key = jax.random.PRNGKey(42)
x = jax.random.normal(key, shape, dtype=jnp.float32)

res = matrix_mul(x, x, x)

What jax/jaxlib version are you using?

jax and jaxlib 0.4.20

Which accelerator(s) are you using?

TPU

Additional system info

Python 3.10, Ubuntu 20.04

NVIDIA GPU info

No response

apaszke commented 1 year ago

Sorry this escaped our attention! Yes, it is expected that the kernel does not work for small values of block_k. Anything below 128 will also be very inefficient on hardware, as it would have to be padded to 128 anyway. Is that an important use case for you, or did you just want to report the failure? I guess we should have some more checks before the kernel that raise a prettier error message.
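As a rough sketch of the kind of check meant here (not actual Pallas/Mosaic code, just an illustration assuming a 128-wide minimum for the f32 matmul operand):

MIN_MATMUL_BLOCK = 128  # assumed minimum supported width for the k tile

def check_block_k(block_k):
    # Raise a readable Python error before lowering, instead of the
    # "unsupported operand shapes or layouts" Mosaic compiler error.
    if block_k < MIN_MATMUL_BLOCK:
        raise ValueError(
            f"block_k={block_k} is smaller than {MIN_MATMUL_BLOCK}; the "
            "tpu.matmul operand would have to be padded to 128, which this "
            "kernel does not support and which would be inefficient anyway.")

In the meantime, raising block_k in the repro (e.g. matrix_mul(x, x, x, block_k=128)) should avoid the sub-128 operand entirely.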

demon2036 commented 1 year ago

> Sorry this escaped our attention! Yes, it is expected that the kernel does not work for small values of block_k. Anything below 128 will also be very inefficient on hardware, as it would have to be padded to 128 anyway. Is that an important use case for you, or did you just want to report the failure? I guess we should have some more checks before the kernel that raise a prettier error message.

Thank you for your response. This is not an important use case for me; I just wanted to report the failure. However, I believe that in this situation a warning should be raised to inform the user that the block size may not use the TPU efficiently, instead of outright rejecting execution with an error.
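For example, something along the lines of this user-side wrapper is what I have in mind (only a sketch; matrix_mul is the repro function above, and the 128 minimum is taken from your reply):

import warnings

MIN_TPU_BLOCK = 128  # minimum supported k-tile width, per the reply above

def matrix_mul_checked(q, k, v, block_q=256, block_k=128):
    # Warn and round block_k up instead of failing later in the Mosaic
    # compiler. Assumes the sequence length is a multiple of 128 so the
    # rounded block size still divides it evenly.
    if block_k < MIN_TPU_BLOCK:
        warnings.warn(
            f"block_k={block_k} is below {MIN_TPU_BLOCK}; it would not use the "
            f"TPU efficiently and is not supported, rounding up to {MIN_TPU_BLOCK}.")
        block_k = MIN_TPU_BLOCK
    return matrix_mul(q, k, v, block_q=block_q, block_k=block_k)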