To enable FP8 attention, we first need to introduce a variant of the 16x16x32xf8 intrinsic with kWidth=4. This is required so that we can use it for the second GEMM/contraction in attention: the layouts between the two GEMMs will then match, and we won't have to do a round trip through shared memory.
Since the native 16x16x32xf8 intrinsic has kWidth=8, this new virtual intrinsic with kWidth=4 needs to do two interleaved reads of 4xf8 instead of a single 8xf8 read. The interleaved reads will look something like:
[0, 0, 0, 0, 16, 16, 16, 16, 32, 32, 32, 32, 48, 48, 48, 48, 0, 0, 0, 0, 16, 16, 16, 16, 32, 32, 32, 32, 48, 48, 48, 48]
where each entry is the lane that owns the element at that index.
We can see that in this case, since lane 0 needs data from indices [0:4) and [16:20), we need to do two separate reads (into two "VGPR"s) and combine them later with insert_slices. We call each of these contiguous chunks a VGPR chunk.
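The ownership pattern above can be reproduced with a few lines of Python. This is only an illustrative sketch, not code from this PR: the helper names `owner_lanes` and `chunks_for_lane` are hypothetical, and the parameters (32 elements along K, four lane groups spaced 16 lanes apart) are assumptions read off the example array.

```python
def owner_lanes(k_size=32, k_width=4, num_k_groups=4, lane_stride=16):
    """Return, for each index along K, the first lane that owns it.

    Assumption: consecutive runs of k_width elements are handed out
    round-robin to num_k_groups lane groups, lane_stride lanes apart.
    """
    return [((i // k_width) % num_k_groups) * lane_stride for i in range(k_size)]


def chunks_for_lane(owners, lane):
    """Return the contiguous [start, end) ranges owned by `lane`."""
    chunks = []
    start = None
    for i, owner in enumerate(owners):
        if owner == lane and start is None:
            start = i  # a new contiguous chunk begins here
        elif owner != lane and start is not None:
            chunks.append((start, i))  # the chunk ended just before i
            start = None
    if start is not None:
        chunks.append((start, len(owners)))
    return chunks


virtual = owner_lanes(k_width=4)   # the new virtual intrinsic
native = owner_lanes(k_width=8)    # the native intrinsic

print(chunks_for_lane(virtual, 0))  # two chunks, so two separate reads
print(chunks_for_lane(native, 0))   # one chunk, so a single read
```

With kWidth=4, each lane ends up with two non-adjacent chunks of four elements, which is why the reads have to be split and recombined; with kWidth=8, the same lane owns one chunk of eight.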
Below is the list of features we need to implement to support the above use case:
Introduce a new virtual 16x16x32_F8 intrinsic/layout that has the strided/interleaved VGPR reads
Spin out the VGPR splitting from partition_strided_operators into a standalone partition_ops_with_gpr_offsets
Add support for handling Read (originally only Write was handled) in partition_ops_with_gpr_offsets
Implement a canonicalization pattern to handle chained extract slices in remove_chained_extractslice. This gives cleaner code and preserves the structure of the output IR, which is affected by the split of partition_strided_operators.
Modify minimize_global_loads to also overwrite indices with GPR_OFFSET.
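To illustrate the idea behind folding chained extract slices, here is a minimal sketch on plain Python lists. It is not the implementation of remove_chained_extractslice: the helpers `apply_slice` and `fold_chained_slices` are hypothetical, and the model assumes 1-D slices with unit stride, each described by an (offset, size) pair.

```python
def apply_slice(values, offset, size):
    """Extract `size` elements starting at `offset` (unit stride)."""
    return values[offset:offset + size]


def fold_chained_slices(slices):
    """Fold a non-empty chain of (offset, size) slices into one slice.

    Each slice is taken relative to the result of the previous one, so the
    offsets add up and the size of the last slice is what remains.
    """
    if not slices:
        raise ValueError("expected at least one slice")
    total_offset = sum(offset for offset, _ in slices)
    final_size = slices[-1][1]
    return total_offset, final_size


data = list(range(32))
chain = [(16, 8), (4, 4)]  # take [16:24), then [4:8) of that result

# Apply the chain one slice at a time.
step_by_step = data
for offset, size in chain:
    step_by_step = apply_slice(step_by_step, offset, size)

# Apply the single folded slice to the original data.
folded = fold_chained_slices(chain)
assert apply_slice(data, *folded) == step_by_step
```

The folded form reads directly from the original source, which is the property that keeps the output IR free of intermediate extract operations.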