
INTERNAL: Mosaic failed to compile TPU kernel: unsupported shape cast #24404

Open vanbasten23 opened 4 days ago

vanbasten23 commented 4 days ago

Description

Hi. I am extending the Pallas paged attention kernel; the case is MQA (multi-query attention). When I run my kernel, I hit the following error, which says it is an internal error and that I should report it here.

======================================================================
ERROR: test_extended_paged_attention_v1_multiple_queries (__main__.PallasTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/site-packages/jax/_src/compiler.py", line 266, in backend_compile
    return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Mosaic failed to compile TPU kernel: unsupported shape cast

at location: loc("/swap"(callsite("_flash_attention"("/workspaces/persist/pytorch/xla/torch_xla/experimental/pallas_kernels/extended_paged_attention_kernel1.py":188:0) at callsite("paged_flash_attention_kernel"("/workspaces/persist/pytorch/xla/torch_xla/experimental/pallas_kernels/extended_paged_attention_kernel1.py":331:0) at callsite("paged_attention"("/workspaces/persist/pytorch/xla/torch_xla/experimental/pallas_kernels/extended_paged_attention_kernel1.py":547:0) at callsite("test_extended_paged_attention_v1_multiple_queries"("/workspaces/persist/pytorch/xla/test/test_pallas.py":773:0) at "<module>"("/workspaces/persist/pytorch/xla/test/test_pallas.py":1669:0)))))))

The MLIR operation involved:
  %61 = "vector.shape_cast"(%60) : (vector<4x128xf32>) -> vector<1x4x1x128xf32>

Please report a bug at: https://github.com/google/jax/issues/new?assignees=apaszke

The above exception was the direct cause of the following exception:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/workspaces/persist/pytorch/xla/test/test_pallas.py", line 773, in test_extended_paged_attention_v1_multiple_queries
    out = jax_extended_paged_attention1(
  File "/workspaces/persist/pytorch/xla/test/test_pallas.py", line 1669, in <module>
    test = unittest.main()
  File "/workspaces/persist/pytorch/xla/test/test_pallas.py", line 773, in test_extended_paged_attention_v1_multiple_queries
    out = jax_extended_paged_attention1(
  File "/workspaces/persist/pytorch/xla/torch_xla/experimental/pallas_kernels/extended_paged_attention_kernel1.py", line 547, in paged_attention
    out = pl.pallas_call(
  File "/workspaces/persist/pytorch/xla/torch_xla/experimental/pallas_kernels/extended_paged_attention_kernel1.py", line 331, in paged_flash_attention_kernel
    _flash_attention(
  File "/workspaces/persist/pytorch/xla/torch_xla/experimental/pallas_kernels/extended_paged_attention_kernel1.py", line 188, in _flash_attention
    o_ref[:, q_head_idx, :] = acc_scratch_ref[:].astype(o_ref.dtype)
jax._src.pallas.mosaic.error_handling.MosaicError: INTERNAL: Mosaic failed to compile TPU kernel: unsupported shape cast

The MLIR operation involved:
  %61 = "vector.shape_cast"(%60) : (vector<4x128xf32>) -> vector<1x4x1x128xf32>

Please report a bug at: https://github.com/google/jax/issues/new?assignees=apaszke

----------------------------------------------------------------------
Ran 1 test in 0.607s

FAILED (errors=1)

Here is my Pallas kernel and the test code that calls it.

Please let me know if you need more info.

System info (python version, jaxlib version, accelerator, etc.)

>>> import jax; jax.print_environment_info()
jax:    0.4.33.dev20240913
jaxlib: 0.4.33.dev20240913
numpy:  2.1.1
python: 3.10.15 (main, Sep 27 2024, 06:06:16) [GCC 10.2.1 20210110]
jax.devices (8 total, 8 local): [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0) TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0) ... TpuDevice(id=6, process_index=0, coords=(2,1,0), core_on_chip=0) TpuDevice(id=7, process_index=0, coords=(3,1,0), core_on_chip=0)]
process_count: 1
platform: uname_result(system='Linux', node='t1v-n-f3643994-w-0', release='5.19.0-1030-gcp', version='#32~22.04.1-Ubuntu SMP Thu Jul 13 09:36:23 UTC 2023', machine='x86_64')

cc: @miladm @WoosukKwon

vanbasten23 commented 2 days ago

The problematic line is o_ref[:, q_head_idx, :] = acc_scratch_ref[:].astype(o_ref.dtype). I found a way to work around the problem (the code is in https://github.com/jax-ml/jax/issues/24415), but I'm still trying to figure out why the flash attention example, which does something similar, works fine.
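
Roughly, the pattern that trips it up looks like the stripped-down kernel below (made-up shapes and names, not my actual paged attention kernel; depending on the jaxlib version it may or may not hit the same shape-cast error). The point is just that a 2D VMEM scratch block is written into an indexed slice of a higher-rank output ref:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def _toy_kernel(x_ref, o_ref, acc_scratch_ref):
  # Hypothetical head index; in the real kernel this comes from a loop over heads.
  q_head_idx = 0
  # Fill the (block, head_dim) scratch, standing in for the attention accumulator.
  acc_scratch_ref[...] = x_ref[...]
  # Mirrors the failing line: write the 2D scratch into a slice of the
  # 3D output ref, selecting one query head in the middle dimension.
  o_ref[:, q_head_idx, :] = acc_scratch_ref[...].astype(o_ref.dtype)

x = jnp.ones((8, 128), jnp.float32)  # (block, head_dim)
out = pl.pallas_call(
    _toy_kernel,
    out_shape=jax.ShapeDtypeStruct((8, 2, 128), jnp.float32),  # (block, num_q_heads, head_dim)
    scratch_shapes=[pltpu.VMEM((8, 128), jnp.float32)],
)(x)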

vanbasten23 commented 2 days ago

It seems the assignee is not set when I use the link https://github.com/google/jax/issues/new?assignees=apaszke from the error message to create the issue, so manually cc'ing @apaszke.

justinjfu commented 1 day ago

https://github.com/jax-ml/jax/pull/22938, which was checked in on Sep 20 (newer than the version you are running), should in principle address this.

Some explanation of the error: the last two dimensions of an array are special because they are physically tiled into VREGs (this is also the reason for the special 8x128 block size noted here: https://jax.readthedocs.io/en/latest/pallas/tpu/details.html#blockspecs-and-grid-iteration). So certain reshapes require additional work under the hood.

Because of the tiling, it is generally more efficient to keep singleton dimensions in front rather than in the last two dimensions if you can afford to do so. For example, reshaping from 4x128 to 4x1x128 would require 4 copy operations to copy each row of the original VREG into the first row of 4 new VREGs, whereas reshaping from 4x128 to 1x4x128 is effectively "free" since it just adds an extra logical dimension in front that can be handled at compile time.
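
As a small illustration of the logical shapes involved (plain jax.numpy outside of any kernel, so it runs anywhere; the cost difference only shows up when Mosaic lowers such reshapes inside a Pallas TPU kernel):

import jax.numpy as jnp

x = jnp.zeros((4, 128), dtype=jnp.float32)

# The last two (physically tiled) dimensions change from (4, 128) to (1, 128):
# lowered inside a kernel, this needs one copy per row, i.e. 4 copies into the
# first row of 4 new VREGs.
costly = x.reshape(4, 1, 128)

# The last two dimensions stay (4, 128); the leading singleton is purely
# logical, so inside a kernel this reshape is effectively free at compile time.
free = x.reshape(1, 4, 128)

print(costly.shape, free.shape)  # (4, 1, 128) (1, 4, 128)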

vanbasten23 commented 1 day ago

Thanks Justin for the explanation!