chengjunlu opened this issue 2 months ago
@chengjunlu what about #950? That is probably needed to reduce register pressure. Can we track it here?
Forgot that one. Yes, let's track it here too.
On agama 996 with latest main (commit https://github.com/intel/intel-xpu-backend-for-triton/commit/6f89dbecf78149f5af725455f81d09adf5db9004), Triton performance is 40% of XeTLA.
Note: `TRITON_INTEL_ENABLE_ADDRESS_PAYLOAD_OPT=1` is removed.
Measurements are done on PVC 1100.
Z | H | N_CTX | D_HEAD | CAUSAL | Triton-TFlops | XeTLA-TFlops | Triton/XeTLA |
---|---|---|---|---|---|---|---|
1 | 16 | 16384 | 128 | FALSE | 28.63831 | 80.13746 | 36% |
1 | 16 | 16384 | 128 | TRUE | 61.08398 | 156.2801 | 39% |
1 | 32 | 16384 | 64 | FALSE | 45.27705 | 103.8121 | 44% |
1 | 32 | 16384 | 64 | TRUE | 96.77028 | 195.821 | 49% |
2 | 16 | 8192 | 128 | FALSE | 28.07994 | 79.68677 | 35% |
2 | 16 | 8192 | 128 | TRUE | 57.68881 | 154.5856 | 37% |
2 | 32 | 8192 | 64 | FALSE | 42.40634 | 102.8581 | 41% |
2 | 32 | 8192 | 64 | TRUE | 88.55945 | 191.986 | 46% |
4 | 16 | 4096 | 128 | FALSE | 27.27483 | 78.8773 | 35% |
4 | 16 | 4096 | 128 | TRUE | 53.9501 | 151.3645 | 36% |
4 | 32 | 4096 | 64 | FALSE | 41.54896 | 97.96356 | 42% |
4 | 32 | 4096 | 64 | TRUE | 82.51222 | 184.9685 | 45% |
4 | 48 | 1024 | 64 | FALSE | 42.10615 | 80.2098 | 52% |
4 | 48 | 1024 | 64 | TRUE | 64.33444 | 127.5733 | 50% |
8 | 16 | 2048 | 128 | FALSE | 26.73036 | 78.13827 | 34% |
8 | 16 | 2048 | 128 | TRUE | 48.5033 | 145.088 | 33% |
8 | 32 | 2048 | 64 | FALSE | 40.18589 | 96.59209 | 42% |
8 | 32 | 2048 | 64 | TRUE | 73.87603 | 170.4972 | 43% |
16 | 16 | 1024 | 128 | FALSE | 26.08306 | 74.89045 | 35% |
16 | 16 | 1024 | 128 | TRUE | 40.21317 | 131.044 | 31% |
16 | 32 | 1024 | 64 | FALSE | 39.02387 | 90.7835 | 43% |
16 | 32 | 1024 | 64 | TRUE | 59.31046 | 149.4855 | 40% |
32 | 16 | 512 | 128 | FALSE | 24.72351 | 70.58867 | 35% |
32 | 16 | 512 | 128 | TRUE | 32.41974 | 104.963 | 31% |
32 | 32 | 512 | 64 | FALSE | 36.66681 | 80.65018 | 45% |
32 | 32 | 512 | 64 | TRUE | 45.07732 | 111.693 | 40% |
GEOMEAN | | | | | | | 40% |
Triton/XeTLA improves from 40% to 43% after removing all environment variables on agama 996.
CI: https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/11221318041/job/31193000301 — no performance improvement from agama 996, measured on a PVC 1550.
There are four new changes required to get better performance of the flash attention kernel in the bottom-up optimization pipeline.
This issue tracks the new design required for flash attention in the bottom-up optimization pipeline.
Status
Most of the optimization passes have been finished and checked into the llvm-target branch, and all the tasks in the old issue #878 are done. The GEMM Triton kernel with block pointer syntax reaches 90% of the performance of the XeTLA version. Flash attention with block pointers also shows promising performance after simple changes to the RewriteBlockPointer pass.
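For intuition, a boundary-checked load through a block pointer conceptually gathers a tile from a strided tensor and pads out-of-bounds elements. Below is a host-side sketch in plain Python (illustrative only; the function name and signature are made up and are not Triton's API):

```python
def load_block(flat, shape, strides, offsets, block_shape, padding=0.0):
    """Host-side emulation (illustrative) of a boundary-checked
    block-pointer load: gather a block_shape tile from a strided 2D
    tensor stored flat, padding out-of-bounds elements."""
    rows_n, cols_n = block_shape
    tile = [[padding] * cols_n for _ in range(rows_n)]
    for i in range(rows_n):
        for j in range(cols_n):
            r, c = offsets[0] + i, offsets[1] + j
            if r < shape[0] and c < shape[1]:  # boundary check
                tile[i][j] = flat[r * strides[0] + c * strides[1]]
    return tile

# A 3x4 row-major tensor (strides (4, 1)); a 2x2 tile at offsets (2, 3)
# touches the bottom-right corner, so three lanes get the padding value.
flat = list(range(12))
inner = load_block(flat, (3, 4), (4, 1), (0, 0), (2, 2))
edge = load_block(flat, (3, 4), (4, 1), (2, 3), (2, 2), padding=-1)
```

In a Triton kernel the same bundle of base, shape, strides, offsets, and block shape is what `tl.make_block_ptr` carries and what `tt.load` with boundary checking consumes.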
New problems
Two new problems were found while developing the bottom-up optimization pipeline:
- Enhance `tt.load` to support FP8 for flash attention.

Plan
To achieve the goals of both performance and functionality in the bottom-up phase, we need a different implementation from the one originally planned.
- Lower the `tt.load` operation with the block pointer as the memory pointer (optionally supporting fallback to Intel 1D block IO).

This design can also benefit future features such as the TMA descriptor.
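The lowering choice in the plan can be sketched as a legality check followed by either one 2D block load or a per-row 1D fallback. This is a rough illustration only: the alignment constants, instruction names, and predicate here are assumptions, not the backend's actual legality rules.

```python
def can_use_2d_block_io(base_addr, pitch_bytes, elem_bytes, inner_stride):
    """Illustrative legality predicate (assumed constants, not the real
    backend rules): a 2D block load wants a row-major view (innermost
    stride == 1) plus aligned base address and pitch."""
    return (inner_stride == 1
            and base_addr % 64 == 0      # assumed base alignment
            and pitch_bytes % 16 == 0    # assumed pitch alignment
            and elem_bytes in (1, 2, 4))

def lower_block_load(base_addr, pitch_bytes, elem_bytes, inner_stride, block_shape):
    """Emit a pseudo-instruction list: a single 2D block load when legal,
    otherwise fall back to one 1D block load per row of the tile."""
    rows, cols = block_shape
    if can_use_2d_block_io(base_addr, pitch_bytes, elem_bytes, inner_stride):
        return [f"load2d base={base_addr:#x} pitch={pitch_bytes} {rows}x{cols}"]
    return [f"load1d addr={base_addr + r * pitch_bytes:#x} len={cols}"
            for r in range(rows)]

# Aligned case lowers to one 2D load; a misaligned base falls back to 8 row loads.
fast = lower_block_load(0x1000, 512, 2, 1, (8, 64))
slow = lower_block_load(0x1001, 512, 2, 1, (8, 64))
```

The same split is what makes an optional 1D fallback attractive: the kernel-level `tt.load` stays uniform, and only the lowering decides which IO form to emit.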