Closed — whitneywhtsang closed this 1 week ago
Huge SBIDStalls in the Triton kernel. Also more DistStalls in the Triton kernel (the red line). WIP to investigate and fine-tune.
Status: Ran many fine-tuning configs and experiments. Most fine-tune configs are harmful to performance. Modified the LLIR by hand and executed it; there were no big performance gains either (e.g., while DPAS insts are atomic, we get more spills, which offsets the benefit of atomic DPAS). The reason is higher SBIDStall and DistStall on load/sync insts than XeTLA, which shows we are short of registers. We may need help from the XeTLA team or others.
Cache hit of XeTLA:
Cache hit of Triton:
Triton with PR:
Overall status:
According to the new overall status:
extra N_CTX=512 shapes:
Still need some work to bring 92.99% up to 95%+.
What did we change to make the performance improvement?
I changed the grid info (global_range in the SYCL kernel submit). The main idea is to keep the split-M axis aligned with the num_warps * threads_per_warp axis. For the 32, 32, 512, 64, False
case, the range changes from {{32, 32, 4}, {128, 1, 1}}
to {{4, 32, 32}, {128, 1, 1}}
(in CUDA style), where 4 = N_CTX / BLOCK_M = 512 / 128. This change greatly improves our L3 cache hit rate (about 10x better) for N_CTX=512 cases, but is harmful to some other cases (especially causal=True, small-batch cases, as observed). The detailed mechanism is under investigation.
There are currently 3 shapes (causal=false, d_head=64) that have performance <95% of XeTLA.