jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

RuntimeError: Internal TPU kernel compiler error: Dynamic indices are not supported in the last two dimensions #18887

Open khatwanimohit opened 12 months ago

khatwanimohit commented 12 months ago

Description

I am hitting the following runtime error when using flash attention on a nightly JAX build. The error is raised while the Pallas TPU kernel is being lowered:

  File "/home/mohitkhatwani/maxtext/MaxText/train.py", line 285, in train_loop
    state, metrics, nextrng = p_train_step(
  File "/home/mohitkhatwani/.local/lib/python3.10/site-packages/jax/_src/pallas/mosaic/pallas_call_registration.py", line 95, in pallas_call_tpu_lowering_rule
    return mlir.lower_fun(_lower_fun, multiple_results=True)(
  File "/home/mohitkhatwani/.local/lib/python3.10/site-packages/jax/_src/pallas/mosaic/pallas_call_registration.py", line 82, in _lower_fun
    return mosaic.as_tpu_kernel(
  File "/home/mohitkhatwani/.local/lib/python3.10/site-packages/jax/_src/tpu_custom_call.py", line 406, in as_tpu_kernel
    lowered_module_asm, constants = _lower_tpu_kernel(
  File "/home/mohitkhatwani/.local/lib/python3.10/site-packages/jax/_src/tpu_custom_call.py", line 333, in _lower_tpu_kernel
    _run_pass_pipeline(pipeline, module, "post-infer-vector-layout")
  File "/home/mohitkhatwani/.local/lib/python3.10/site-packages/jax/_src/tpu_custom_call.py", line 260, in _run_pass_pipeline
    raise RuntimeError("\n".join(msg)) from None
RuntimeError: Internal TPU kernel compiler error: Dynamic indices are not supported in the last two dimensions

The MLIR operation involved:
  %33 = "vector.load"(%12, %9, %9, %32, %9) : (memref<1x1x512x256xbf16, #tpu.memory_space<vmem>>, index, index, index, index) -> vector<1x1x256x256xbf16>
... additional diagnostics were skipped.
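
To make the failing access pattern concrete: the `vector.load` above reads a 1x1x256x256 window out of a 1x1x512x256 buffer, and the index for the second-to-last dimension (`%32`) is a different SSA value from the other three (`%9`). The sketch below shows the same access shape at the JAX level using `jax.lax.dynamic_slice`. It is only an illustration, not a reproducer: it goes through regular XLA rather than the Mosaic TPU compiler, so it will not raise this error. It also assumes that `%9` is a constant zero and that `%32` is computed at run time, which the truncated dump does not confirm. The names `load_window` and `row_start` are made up for the example.

```python
import jax
import jax.numpy as jnp

# Shapes copied from the vector.load in the error message.
BLOCK_SHAPE = (1, 1, 512, 256)  # memref<1x1x512x256xbf16>
LOAD_SHAPE = (1, 1, 256, 256)   # vector<1x1x256x256xbf16>


@jax.jit
def load_window(block, row_start):
    # Under jit, row_start is a traced value, so the offset into the
    # second-to-last dimension is only known at run time. This mirrors the
    # assumed role of %32 in the MLIR op; the zeros mirror the assumed %9.
    return jax.lax.dynamic_slice(block, (0, 0, row_start, 0), LOAD_SHAPE)


# Hypothetical input data, only used to check which rows are read.
block = (
    jnp.arange(512 * 256, dtype=jnp.float32)
    .reshape(BLOCK_SHAPE)
    .astype(jnp.bfloat16)
)

# Read the second half of the rows (rows 256..511).
window = load_window(block, 256)
print(window.shape)
```

If `row_start` were a Python constant baked into the kernel instead of a traced value, the offset would be static. Whether that is a workable change for the flash attention kernel here is not something this thread establishes.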

What jax/jaxlib version are you using?

jax 0.4.22.dev20231208+e686ed7e9, jaxlib 0.4.22.dev20231208

Which accelerator(s) are you using?

TPU

Additional system info?

numpy 1.24.3, Python 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0], uname_result(system='Linux', node='t1v-n-a573a025-w-0', release='5.19.0-1022-gcp', version='#24~22.04.1-Ubuntu SMP Sun Apr 23 09:51:08 UTC 2023', machine='x86_64')

NVIDIA GPU info

No response

apaszke commented 10 months ago

Is that still happening? How can I reproduce this?