triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

Support computation pipelining after SWP refactoring #5185

Open · manman-ren opened 1 week ago

manman-ren commented 1 week ago

With the recent SWP refactoring, it is much easier to support arbitrary stage assignments where computations can be separated into different stages. Computation pipelining is essentially splitting the computation across stages. Take flash attention as an example: currently the two loads are in stage 0 (S0) and all other ops are in the last stage (S2), so the steady-state loop body looks like:

MMA0(i)
Softmax(i)
MUL(i)
MMA1(i)
LoadV(i+2)
LoadK(i+2)
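For reference, here is a minimal sketch of such an inner loop, with comments marking where each op falls under the current default schedule. This is an illustrative reduction of the tutorial kernel, not this PR's code: the flat Q/K/V layouts and the simplified online-softmax bookkeeping are assumptions made for brevity.

```python
import triton
import triton.language as tl

@triton.jit
def attn_fwd_inner(Q, K, V, Out, seqlen,
                   BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                   BLOCK_D: tl.constexpr):
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_D)
    # One query tile stays resident across the whole loop.
    q = tl.load(Q + offs_m[:, None] * BLOCK_D + offs_d[None, :])
    acc = tl.zeros((BLOCK_M, BLOCK_D), dtype=tl.float32)
    m_i = tl.full((BLOCK_M,), float("-inf"), dtype=tl.float32)
    l_i = tl.zeros((BLOCK_M,), dtype=tl.float32)
    for start_n in range(0, seqlen, BLOCK_N):
        # Stage 0 today: both loads are assigned together, so the pipelined
        # body prefetches K and V for iteration i+2.
        k = tl.load(K + offs_d[:, None] + (start_n + offs_n)[None, :] * BLOCK_D)  # LoadK
        v = tl.load(V + (start_n + offs_n)[:, None] * BLOCK_D + offs_d[None, :])  # LoadV
        # Last stage (S2) today: all computation for iteration i.
        qk = tl.dot(q, k)                          # MMA0
        m_new = tl.maximum(m_i, tl.max(qk, 1))     # Softmax: running max ...
        p = tl.exp(qk - m_new[:, None])            # ... and exponentials
        alpha = tl.exp(m_i - m_new)
        l_i = l_i * alpha + tl.sum(p, 1)
        acc = acc * alpha[:, None]                 # MUL: rescale accumulator
        acc += tl.dot(p.to(v.dtype), v)            # MMA1
        m_i = m_new
    acc = acc / l_i[:, None]
    tl.store(Out + offs_m[:, None] * BLOCK_D + offs_d[None, :], acc)
```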

This patch defines two different pipeline schedules for attention-like kernels:

1> First dot in S2, other computations in S3, loadK in stage 0, loadV in stage 1:

MMA0(i+1)
Softmax(i)
MUL(i)
MMA1(i)
loadK(i+3)
loadV(i+2)

2> Second dot in S3, other computations in S2, loadK in stage 0, loadV in stage 1:

MMA0(i+1)
MMA1(i)
Softmax(i+1)
MUL(i+1)
loadK(i+3)
loadV(i+2)
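The iteration offsets in these bodies follow mechanically from the stage assignment: in a pipeline whose last stage is S, an op placed in stage s runs S - s iterations ahead. A tiny self-contained model of that rule (plain Python, not the PR's code; the dicts just restate the two assignments above):

```python
# Stage assignments for the two proposed schedules (4 stages, last stage = 3).
SCHEDULE_1 = {"MMA0": 2, "Softmax": 3, "MUL": 3, "MMA1": 3,
              "loadK": 0, "loadV": 1}
SCHEDULE_2 = {"MMA0": 2, "MMA1": 3, "Softmax": 2, "MUL": 2,
              "loadK": 0, "loadV": 1}

def steady_state_body(stages):
    """An op in stage s executes for iteration i + (last_stage - s)."""
    last = max(stages.values())
    return [f"{op}(i+{last - s})" if s != last else f"{op}(i)"
            for op, s in stages.items()]

print(steady_state_body(SCHEDULE_1))
# ['MMA0(i+1)', 'Softmax(i)', 'MUL(i)', 'MMA1(i)', 'loadK(i+3)', 'loadV(i+2)']
print(steady_state_body(SCHEDULE_2))
# ['MMA0(i+1)', 'MMA1(i)', 'Softmax(i+1)', 'MUL(i+1)', 'loadK(i+3)', 'loadV(i+2)']
```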

Preliminary performance numbers on H100 for flash attention:

| (Batch, Heads, SeqLen, Dhead) | triton_tutorial_flash_v2_opt-tflops | triton_tutorial_flash_v2_tma-tflops | triton_tutorial_flash_v2-tflops |
|---|---|---|---|
| (8, 16, 8192, 128) | 517.528 | 504.565 | 481.402 |

The implementation and the frontend are preliminary and intended for discussion.

manman-ren commented 1 week ago

@pawelszczerbuk The frontend is an annotation on the loop. Inside the LoopSchedule pass, we use the annotation to check whether the ttgir matches the specific schedule; if it does, we perform the corresponding <stage, cluster> assignment.
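To make that concrete, a hypothetical sketch of the two halves; every spelling here (`loop_schedule`, the schedule label, the helper name, the cluster numbers) is a placeholder for illustration, not the PR's actual code, and the real matching happens on ttgir inside the pass rather than in Python:

```python
# Frontend (hypothetical spelling): the author opts a loop into a named
# schedule, e.g.
#   for start_n in tl.range(0, seqlen, BLOCK_N, loop_schedule="FA_firstDot"):
#       ...
# The compiler falls back to the default schedule if the body doesn't match.

def assign_stages(loop_ops, annotation):
    """Toy model of the pass: match the annotated loop's op skeleton,
    then hand out <stage, cluster> pairs (cluster numbers illustrative)."""
    patterns = {
        "FA_firstDot": {"loadK": (0, 0), "loadV": (1, 0), "MMA0": (2, 0),
                        "Softmax": (3, 0), "MUL": (3, 1), "MMA1": (3, 2)},
    }
    expected = patterns.get(annotation)
    if expected is None or set(loop_ops) != set(expected):
        return None  # not a recognized attention-like body: keep default SWP
    return expected

print(assign_stages(["loadK", "loadV", "MMA0", "Softmax", "MUL", "MMA1"],
                    "FA_firstDot"))
```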

I understand that you are working on further refactoring, and possibly on a frontend design for specifying a loop schedule. This PR is mostly to share the performance numbers and the preliminary implementation. Happy to work together on enabling this!