Description
I'm hitting the following runtime error when using flash attention (a Pallas TPU kernel) in MaxText on nightly JAX:
  File "/home/mohitkhatwani/maxtext/MaxText/train.py", line 285, in train_loop
    state, metrics, nextrng = p_train_step(
  File "/home/mohitkhatwani/.local/lib/python3.10/site-packages/jax/_src/pallas/mosaic/pallas_call_registration.py", line 95, in pallas_call_tpu_lowering_rule
    return mlir.lower_fun(_lower_fun, multiple_results=True)(
  File "/home/mohitkhatwani/.local/lib/python3.10/site-packages/jax/_src/pallas/mosaic/pallas_call_registration.py", line 82, in _lower_fun
    return mosaic.as_tpu_kernel(
  File "/home/mohitkhatwani/.local/lib/python3.10/site-packages/jax/_src/tpu_custom_call.py", line 406, in as_tpu_kernel
    lowered_module_asm, constants = _lower_tpu_kernel(
  File "/home/mohitkhatwani/.local/lib/python3.10/site-packages/jax/_src/tpu_custom_call.py", line 333, in _lower_tpu_kernel
    _run_pass_pipeline(pipeline, module, "post-infer-vector-layout")
  File "/home/mohitkhatwani/.local/lib/python3.10/site-packages/jax/_src/tpu_custom_call.py", line 260, in _run_pass_pipeline
    raise RuntimeError("\n".join(msg)) from None
RuntimeError: Internal TPU kernel compiler error: Dynamic indices are not supported in the last two dimensions
The MLIR operation involved:
  %33 = "vector.load"(%12, %9, %9, %32, %9) : (memref<1x1x512x256xbf16, #tpu.memory_space<vmem>>, index, index, index, index) -> vector<1x1x256x256xbf16>
... additional diagnostics were skipped.
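For reference, the failing op reads a 256-row window out of a 512x256 bf16 block in VMEM at a dynamic row offset (`%32`, the index for the second-to-last dimension). The sketch below is a hypothetical, much smaller Pallas kernel with the same access pattern; `window_copy_kernel`, `copy_windows` and `BLOCK_ROWS` are illustrative names, not code from MaxText or from JAX's flash attention kernel. It runs with `interpret=True` on CPU. I am assuming, but have not confirmed, that compiling it for TPU (`interpret=False`) on this nightly hits the same Mosaic error.

```python
import functools

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

# Hypothetical window size, chosen to match vector<1x1x256x256xbf16> above.
BLOCK_ROWS = 256


def window_copy_kernel(x_ref, o_ref):
    """Copy x_ref into o_ref one BLOCK_ROWS-row window at a time.

    The row offset is derived from the loop counter, so it is a dynamic
    index in the second-to-last dimension, like %32 in the failing load.
    """
    num_windows = x_ref.shape[2] // BLOCK_ROWS

    def body(i, carry):
        start = i * BLOCK_ROWS
        window = x_ref[:, :, pl.ds(start, BLOCK_ROWS), :]
        o_ref[:, :, pl.ds(start, BLOCK_ROWS), :] = window
        return carry

    jax.lax.fori_loop(0, num_windows, body, 0)


@functools.partial(jax.jit, static_argnames=("interpret",))
def copy_windows(x, interpret=True):
    # No grid or BlockSpecs: the whole (1, 1, 512, 256) array is one block.
    return pl.pallas_call(
        window_copy_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=interpret,
    )(x)


# Interpret mode (works on CPU); pass interpret=False on a TPU to go through Mosaic.
x = jnp.arange(512 * 256, dtype=jnp.float32).reshape(1, 1, 512, 256).astype(jnp.bfloat16)
y = copy_windows(x)
```

Replacing `pl.ds(start, BLOCK_ROWS)` with a static slice such as `:BLOCK_ROWS` removes the dynamic index, which should help isolate whether the dynamic row offset alone is what Mosaic rejects.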
What jax/jaxlib version are you using?
jax 0.4.22.dev20231208+e686ed7e9, jaxlib 0.4.22.dev20231208
Which accelerator(s) are you using?
TPU
Additional system info?
1.24.3, Python 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0], uname_result(system='Linux', node='t1v-n-a573a025-w-0', release='5.19.0-1022-gcp', version='#24~22.04.1-Ubuntu SMP Sun Apr 23 09:51:08 UTC 2023', machine='x86_64')
NVIDIA GPU info
No response