vanbasten23 opened 4 days ago
The problematic line is `o_ref[:, q_head_idx, :] = acc_scratch_ref[:].astype(o_ref.dtype)`. I found a way to work around the problem (the code is in https://github.com/jax-ml/jax/issues/24415), but I'm still trying to figure out why the flash attention example does something similar yet works fine.
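For context, here is a minimal, self-contained sketch of the kind of store that trips this. The shapes and names below are hypothetical placeholders, not the actual paged attention kernel; the real kernel and repro are in the linked issue.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

block_q, num_heads, head_dim = 8, 4, 128  # hypothetical block shapes

def store_kernel(acc_ref, o_ref):
    # Unrolled loop over heads: each iteration writes a (block_q, head_dim)
    # slice into the middle dimension of the (block_q, num_heads, head_dim)
    # output. The written slice straddles o_ref's last two (tiled)
    # dimensions, which is where the Mosaic retiling error came from.
    for q_head_idx in range(num_heads):
        o_ref[:, q_head_idx, :] = acc_ref[:].astype(o_ref.dtype)

out = pl.pallas_call(
    store_kernel,
    out_shape=jax.ShapeDtypeStruct((block_q, num_heads, head_dim), jnp.float32),
    interpret=True,  # interpreter mode runs anywhere; compile on TPU to hit the Mosaic path
)(jnp.ones((block_q, head_dim), jnp.float32))
```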
It seems the assignee is not set when I use the link https://github.com/google/jax/issues/new?assignees=apaszke from the error message to create the issue, so manually cc'ing @apaszke.
https://github.com/jax-ml/jax/pull/22938, which was checked in on Sep 20 (newer than the version you are running), should in principle address this.
Some explanation of the error: the last two dimensions of an array are special because they are physically tiled into VREGs (this is also the reason for the special 8x128 block size, as noted here: https://jax.readthedocs.io/en/latest/pallas/tpu/details.html#blockspecs-and-grid-iteration). As a result, certain reshapes require additional work under the hood.
Because of the tiling, it's generally more efficient to leave singleton dimensions in front rather than in the last two dimensions if you can afford to do so. For example, reshaping from 4x128 to 4x1x128 requires 4 copy operations to move each row of the original VREG into the first row of 4 new VREGs, whereas reshaping from 4x128 to 1x4x128 is effectively "free": it just adds an extra logical dimension in front that can be handled at compile time.
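A quick sketch of that difference using plain jnp reshapes (the copy cost described above applies when these reshapes happen inside a Pallas TPU kernel):

```python
import jax.numpy as jnp

x = jnp.arange(4 * 128, dtype=jnp.float32).reshape(4, 128)

# Singleton inserted into the last two dims: each 1x128 row of the original
# 4x128 tile must be copied into the first row of its own new VREG (4 copies).
costly = x.reshape(4, 1, 128)

# Singleton inserted in front: the trailing 4x128 tile layout is unchanged,
# so the reshape is resolved at compile time with no data movement.
free = x.reshape(1, 4, 128)
```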
Thanks, Justin, for the explanation!
Description
Hi. I am extending the Pallas paged attention kernel. The case is MQA (multi-query attention). When I run my kernel, I encounter the following error, which says it is an internal error and that I should report it here.
Below are my Pallas kernel and the test code that calls it.
Please let me know if you need more info.
System info (python version, jaxlib version, accelerator, etc.)
cc: @miladm @WoosukKwon