iree-org / iree-turbine

IREE's PyTorch Frontend, based on Torch Dynamo.
Apache License 2.0

[TKW] Implement MFMA F8 intrinsic/layout with interleaved reads based on VGPR #269

Open raikonenfnu opened 2 days ago

raikonenfnu commented 2 days ago

In order to enable FP8 attention, we first need to introduce the same 16x16x32xf8 intrinsic with kWidth=4. This is required so that we can use it for the second gemm/contraction in attention: the layouts between the two gemms will then match, and we won't have to do a round trip to shared memory.

Since the native size of 16x16x32xf8 is kWidth=8, this new virtual intrinsic with kWidth=4 will need to do two interleaved reads of 4xf8 instead of a single 8xf8 read. These interleaved reads will look something like:

 [0, 0, 0, 0, 16, 16, 16, 16, 32, 32, 32, 32, 48, 48, 48, 48, 0, 0, 0, 0, 16, 16, 16, 16, 32, 32, 32, 32, 48, 48, 48, 48]

where each entry represents the lane that owns that element.
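The lane-ownership pattern above can be generated with a short sketch (illustrative only, not the actual TKW implementation; the parameter names are assumptions):

```python
# Sketch: per-element lane ownership along K for the virtual 16x16x32_F8
# intrinsic with kWidth=4. Each lane owns k_width contiguous f8 elements,
# lane groups are 16 lanes apart, and the pattern repeats along K.
def lane_ownership(num_elems=32, k_width=4, lane_stride=16, num_groups=4):
    """Return the lane that owns each element index along K."""
    return [((i // k_width) % num_groups) * lane_stride for i in range(num_elems)]

pattern = lane_ownership()
# First half: [0,0,0,0, 16,16,16,16, 32,32,32,32, 48,48,48,48], then it repeats.
```

With kWidth=8 the same formula (k_width=8, num_groups=2) would instead give each lane a single contiguous 8-element run, which is why the native intrinsic needs only one read.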

We can see that in this case, since lane 0 needs data from indices [0:4) and [16:20), we'd need to do two separate reads (into two "VGPR"s) and combine them later with insert_slices. We call each of these contiguous chunks a VGPR chunk.
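To make the two-chunk read concrete, here is a sketch (names are illustrative, not TKW APIs) that derives which element indices a lane owns and splits them into the two 4xf8 VGPR chunks that later get combined:

```python
# Sketch: element indices owned by a given lane under the interleaved
# kWidth=4 layout, and the two VGPR chunks they split into.
def indices_owned_by(lane, num_elems=32, k_width=4, lane_stride=16, num_groups=4):
    """Element indices along K whose owning lane is `lane`."""
    return [i for i in range(num_elems)
            if ((i // k_width) % num_groups) * lane_stride == lane]

chunks = indices_owned_by(0)          # lane 0 owns [0:4) and [16:20)
vgpr0, vgpr1 = chunks[:4], chunks[4:]  # two separate strided reads
combined = vgpr0 + vgpr1               # analogous to combining via insert_slice
```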

Below is a feature list that we need to implement to support the above use cases:

  1. Introduce a new virtual 16x16x32_F8 intrinsic/layout that has the strided/interleaved VGPR reads
  2. Spin out the VGPR splitting from partition_strided_operators into a standalone partition_ops_with_gpr_offsets
  3. Add support for handling Read ops (previously only Write was handled) in partition_ops_with_gpr_offsets
  4. Implement a canonicalization pattern in remove_chained_extractslice to fold chains of extract slices. This keeps the code cleaner and preserves the structure of the output IR affected by the split of partition_strided_operators.
  5. Modify minimize_global_loads to also overwrite indices with GPR_OFFSET.
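As a rough illustration of the chunking that a pass like partition_ops_with_gpr_offsets has to compute (this is a hypothetical sketch, not the real pass), the core step is splitting a lane's owned indices into contiguous runs, each tagged with its GPR offset:

```python
# Hypothetical sketch: split a lane's owned element indices into contiguous
# runs, tagging each run with its GPR offset (its position within the lane's
# registers). Not the actual partition_ops_with_gpr_offsets implementation.
def split_into_gpr_chunks(indices):
    """Return (gpr_offset, start_index, length) for each contiguous run."""
    chunks = []
    gpr_offset, start, length = 0, indices[0], 1
    for prev, cur in zip(indices, indices[1:]):
        if cur == prev + 1:
            length += 1
        else:
            chunks.append((gpr_offset, start, length))
            gpr_offset += length
            start, length = cur, 1
    chunks.append((gpr_offset, start, length))
    return chunks

# Lane 0's indices under the kWidth=4 layout split into two 4-wide chunks:
#   split_into_gpr_chunks([0, 1, 2, 3, 16, 17, 18, 19])
#   -> [(0, 0, 4), (4, 16, 4)]
```

Item 5 then corresponds to propagating the resulting GPR offsets back into the indices that minimize_global_loads produces.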
harsh-nod commented 1 day ago

Please rebase on top of master to get the builders to pass.