triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/

Better Scheduling/Rematerialization at MIR level for NVGPU? #3557

Open manman-ren opened 5 months ago

manman-ren commented 5 months ago

We noticed that rematerialization of address calculations can sometimes change kernel performance by more than 5%, and it is hard to keep the remat persistent through the LLVM passes. Even though nvcc does scheduling, and likely remat, to reduce register pressure, there are some patterns that may be specific to Triton-generated LLVM IR and could benefit from a scheduling pass at the PTX level. As another example (https://triton-lang.slack.com/archives/C06EJRSF7CN/p1707938806061369), we could potentially do fine-grained scheduling with tensor cores. CC @bertmaher
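To make the remat question concrete, here is a hypothetical kernel sketch (not taken from any real workload) where the loop-invariant pointer arithmetic can either stay live in registers across the loop or be recomputed right before each load; which choice the backend makes is the kind of thing that moves perf by a few percent:

```python
import triton
import triton.language as tl


@triton.jit
def row_sum(x_ptr, out_ptr, n_blocks, stride, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = tl.arange(0, BLOCK)
    # Loop-invariant address calculation: the backend can keep `base` live in
    # registers for the whole loop (higher pressure) or rematerialize the cheap
    # pointer arithmetic next to each load (more scalar instructions).
    base = x_ptr + pid * stride + offs
    acc = tl.zeros([BLOCK], dtype=tl.float32)
    for i in range(0, n_blocks):
        acc += tl.load(base + i * BLOCK)
    tl.store(out_ptr + pid * BLOCK + offs, acc)
```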

Another option is to see if there is any flag/parameter within nvcc that we can tweak to achieve a better schedule for the above two cases. CC @joker-eph

We can start by enabling machine scheduling for PTX in LLVM and trying simple address-calculation remat there. The pass can be enabled/disabled from Triton, and it can have tunable parameters that we can tune for performance-critical kernels.

Feedback is highly appreciated! CC @ThomasRaoux

jlebar commented 5 months ago

ptxas is an optimizing assembler, and our experiments indicate that it does scheduling and CSE itself. This, plus the fact that ptxas also does register allocation, makes this very difficult.

If you decide to go down this path, I strongly encourage you to look at optimized SASS, not the PTX. But also, @pawelszczerbuk looked at this extensively and came up empty-handed, so I would recommend focusing on other things.

ThomasRaoux commented 5 months ago

In my experience, scheduling pre-ptxas for some high-pressure, high-latency loops (like a matmul loop) can have an important impact on perf. Also note that the LLVM vectorizer currently moves instructions around more aggressively than it should and effectively does some scheduling that helps us by luck; but since it is not vectorizing with the right vector length, it is actually something we should disable. That being said, scheduling is an NP-hard problem, and doing it in a generic way is surely extremely difficult. One thing that could be experimented with is creating specific strategies that are known to work well for some kernels and having Triton pick between them (or maybe even letting the user pick).

Disabling the LLVM vectorizer or changing the scheduling strategy in LLVM to schedule for register pressure rather than latency has a significant effect on perf from what I have measured in the past, which kind of proves that it has an impact (even if getting it right is hard, as mentioned).

Rematerialization is a trickier problem and, as mentioned, it is not obvious there is a solution.

manman-ren commented 4 months ago

I recently hit a related perf issue on a PT2 benchmark: disabling LSR gives an 11% perf win. The changes in the PTX are related to scalar operations associated with address calculations. I checked register liveness at the PTX level with and without LSR; the differences are not as big as the differences at the SASS level (230 vs. 255 registers).
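For reference, one way to compare the SASS-level register counts (file names and arch here are assumptions, not from the benchmark): let ptxas report register usage for the two PTX variants with -v.

```python
import subprocess

# Hypothetical file names for the PTX produced with and without LSR.
for name in ("kernel_with_lsr.ptx", "kernel_no_lsr.ptx"):
    res = subprocess.run(
        ["ptxas", "-arch=sm_90", "-v", name, "-o", "/dev/null"],
        capture_output=True, text=True, check=True,
    )
    # ptxas -v prints lines like "ptxas info : Used 255 registers ..." on stderr.
    print(name, [l for l in res.stderr.splitlines() if "registers" in l])
```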

In https://github.com/openai/triton/pull/3732, I am extending DISABLE_LLVM_OPT so it can pass in a list of flags that disable individual LLVM optimizations. Longer term, I plan to add some basic classes for specifying an NVPTX scheduling strategy, which will make adding new strategies easier.
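As a rough sketch of how the extended knob could be used once that PR lands (the exact value syntax is an assumption on my side; see the PR for what is actually accepted):

```python
import os

# Hypothetical usage: instead of a plain boolean, pass the name of the LLVM
# optimization to disable (here LSR). Check the PR above for the accepted syntax.
os.environ["DISABLE_LLVM_OPT"] = "disable-lsr"

import triton  # kernels compiled after this point pick up the flag
```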

> Disabling the LLVM vectorizer or changing the scheduling strategy in LLVM to schedule for register pressure rather than latency has a significant effect on perf from what I have measured in the past

Do you mean changing the existing scheduling strategy in LLVM? Can you share the command line flags?

Kind of related: I saw a bunch of commits about `_experimental_descriptor_load` from @ThomasRaoux, and we are interested in testing it out on our kernels. Is it in a reasonable state for us to test for performance? I understand it may be taken out later. Using TMA should reduce address calculations, decreasing the register pressure caused by scalar ops.

ThomasRaoux commented 4 months ago

> with and without LSR; the differences are not as big as the differences at the SASS level (230 vs. 255 registers).

255 is the maximum number of registers per thread. Note that high register pressure may also prevent aggressive scheduling, so even if the same number of registers ends up being used, the code may be slower.

> Do you mean changing the existing scheduling strategy in LLVM? Can you share the command line flags?

There is an option, -nvptx-sched4reg, that makes the selection-DAG scheduler schedule for register pressure.
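A rough way to see what the flag does in isolation (workflow and file names are assumptions): take the .llir that Triton leaves next to the .ptx in its cache directory and re-run LLVM codegen by hand, with and without the flag, then diff the resulting PTX.

```python
import subprocess

# Re-run NVPTX codegen on the LLVM IR Triton produced (hypothetical file name),
# asking the selection-DAG scheduler to schedule for register pressure.
subprocess.run(
    ["llc", "-march=nvptx64", "-mcpu=sm_90", "-nvptx-sched4reg",
     "kernel.llir", "-o", "kernel.sched4reg.ptx"],
    check=True,
)
```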

> Kind of related: I saw a bunch of commits about `_experimental_descriptor_load` from @ThomasRaoux, and we are interested in testing it out on our kernels. Is it in a reasonable state for us to test for performance?

Not yet; there are a few more changes needed to get to reasonable perf. I'm currently working on those.