demon2036 opened this issue 1 year ago (status: Open)
Sorry this escaped our attention! Yes, it is expected that the kernel does not work for small values of block_k. Anything below 128 will also be very inefficient on hardware, as it would have to be padded to 128 anyway. Is that an important use case for you, or did you just want to report the failure? I guess we should have some more checks before the kernel that raise a prettier error message.
Thank you for your response. This is not an important use case for me; I just wanted to report the issue. However, I believe that in this situation a warning should be raised to inform the user that the TPU may not be utilized efficiently, rather than outright rejecting execution with an error.
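A rough illustration of the kind of pre-check being suggested (purely hypothetical, not existing JAX/Pallas behavior; the `check_block_k` helper and the 128 threshold are assumptions taken from the comment above):

```python
# Hypothetical pre-check, not part of the JAX API: warn before lowering when
# block_k is below the TPU tile width of 128 instead of failing deep inside
# the Mosaic compiler.
import warnings

TPU_MIN_LANE_DIM = 128  # assumed minimum efficient tile size, per the comment above


def check_block_k(block_k: int) -> None:
    """Warn when block_k would have to be padded up to 128 on TPU (illustrative only)."""
    if block_k < TPU_MIN_LANE_DIM:
        warnings.warn(
            f"block_k={block_k} is smaller than {TPU_MIN_LANE_DIM}; the kernel may be "
            "rejected by the TPU compiler or padded and run inefficiently."
        )
```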
Description
When using a TPU, I attempted to implement a tiling operation with Pallas, where q is tiled as (block_q, block_d) and k is tiled as (block_k, block_d). I use pl.dot(q, k, trans_b=True) to compute q @ k.T; however, I noticed that if block_k is less than 128, it results in an error. In my case, q is a vector<256x128xf32>, k is a vector<16x128xf32>, and the result should be a vector<256x16xf32>.
Error Information
Traceback (most recent call last):
  File "/root/test_multihead.py", line 101, in <module>
    res = matrix_mul(x, x, x, )
  File "/root/test_multihead.py", line 73, in matrix_mul
    return pl.pallas_call(
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/pallas/pallas_call.py", line 383, in wrapped
    out_flat = pallas_call_p.bind(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: RuntimeError: Internal TPU kernel compiler error: unsupported operand shapes or layouts

The MLIR operation involved: %26 = "tpu.matmul"(%23, %25, %17) {transpose_lhs = false, transpose_rhs = true} : (vector<256x128xf32>, vector<16x128xf32>, vector<256x16xf32>) -> vector<256x16xf32>

Please report a bug at: https://github.com/google/jax/issues/new?assignees=apaszke

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

The above exception was the direct cause of the following exception:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "/root/test_multihead.py", line 101, in <module>
    res = matrix_mul(x, x, x, )
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/pallas/mosaic/pallas_call_registration.py", line 88, in pallas_call_tpu_lowering_rule
    return mlir.lower_fun(_lower_fun, multiple_results=True)(
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/pallas/mosaic/pallas_call_registration.py", line 79, in _lower_fun
    return mosaic.as_tpu_kernel(
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/tpu_custom_call.py", line 408, in as_tpu_kernel
    lowered_module_asm, constants = _lower_tpu_kernel(
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/tpu_custom_call.py", line 335, in _lower_tpu_kernel
    _run_pass_pipeline(pipeline, module, "infer vector layout")
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/tpu_custom_call.py", line 264, in _run_pass_pipeline
    raise RuntimeError("\n".join(msg)) from None
RuntimeError: Internal TPU kernel compiler error: unsupported operand shapes or layouts

The MLIR operation involved: %26 = "tpu.matmul"(%23, %25, %17) {transpose_lhs = false, transpose_rhs = true} : (vector<256x128xf32>, vector<16x128xf32>, vector<256x16xf32>) -> vector<256x16xf32>

Please report a bug at: https://github.com/google/jax/issues/new?assignees=apaszke
Here is my code
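The original test_multihead.py is not reproduced here; the following is a minimal sketch along the lines of the description (the kernel and wrapper names, and the use of two operands instead of q/k/v, are assumptions, not the reporter's exact script; it assumes jax/jaxlib 0.4.20 on TPU):

```python
# Minimal sketch of the failing pattern: pl.dot(q, k, trans_b=True) with
# block_k < 128 triggers "unsupported operand shapes or layouts" on TPU.
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

block_q, block_k, block_d = 256, 16, 128  # block_k < 128 is the problematic case


def matmul_kernel(q_ref, k_ref, o_ref):
    q = q_ref[...]  # (block_q, block_d) == (256, 128)
    k = k_ref[...]  # (block_k, block_d) == (16, 128)
    o_ref[...] = pl.dot(q, k, trans_b=True)  # q @ k.T -> (256, 16)


def matrix_mul(q, k):
    return pl.pallas_call(
        matmul_kernel,
        out_shape=jax.ShapeDtypeStruct((block_q, block_k), jnp.float32),
    )(q, k)


q = jnp.ones((block_q, block_d), jnp.float32)
k = jnp.ones((block_k, block_d), jnp.float32)
res = matrix_mul(q, k)  # raises the TPU kernel compiler error shown above
```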
What jax/jaxlib version are you using?
jax and jaxlib 0.4.20
Which accelerator(s) are you using?
TPU
Additional system info
Python 3.10, Ubuntu 20.04
NVIDIA GPU info
No response