Open alexbaden opened 2 days ago
@arunjose696 The idea of the algorithm is as follows:
Look through each basic block of the function to find one that starts with a PhiNode.
When we find a basic block that starts with a PhiNode, process that basic block by first checking to see if any of the PhiNode values are null/0 constants.
If no PhiNode values are null/zero, no further action is needed. If we have a null or zero, then we iterate the instructions in the BB to see if any sdiv/srem instructions use that null/zero value. If so, we freeze the output of the PhiNode and replace the operand in the sdiv/srem instruction with that frozen value.
The first loop only looks at the first instruction, but iterating all the instructions and breaking is a relatively easy way to do this (and is done in many other LLVM passes). The second loop has to look at all instructions in the BB.
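To make the steps above concrete, here is a minimal sketch on a toy IR model. This is not the pass itself and does not use LLVM's API: `Op`, `Inst`, `Block`, `isZeroConstant`, and `freezeMaskedDivRem` are hypothetical names invented for illustration, values are modeled as plain strings, and the `.fr` suffix for the frozen value is an assumption. It only mirrors the logic: check the leading phi for a null/zero incoming value, then rewrite sdiv/srem operands in the same block to use a frozen copy.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Toy stand-ins for LLVM IR. None of these types are LLVM's.
enum class Op { Phi, Freeze, SDiv, SRem, Other };

struct Inst {
  Op op;
  std::string name;                  // result name, e.g. "%x"
  std::vector<std::string> operands; // value names or constant literals
};

using Block = std::vector<Inst>;

// Assumption: constants are spelled "0" or "null" in this toy model.
bool isZeroConstant(const std::string &v) { return v == "0" || v == "null"; }

// Returns the number of sdiv/srem operands that were rewritten.
int freezeMaskedDivRem(Block &bb) {
  // Step 1: only blocks that start with a phi are candidates.
  if (bb.empty() || bb.front().op != Op::Phi)
    return 0;

  // Step 2: does the phi have a null/zero incoming value?
  const std::string phiName = bb.front().name; // copy: bb is modified below
  bool hasZero = false;
  for (const std::string &v : bb.front().operands)
    if (isZeroConstant(v))
      hasZero = true;
  if (!hasZero)
    return 0;

  // Step 3: scan every instruction in the block and redirect sdiv/srem
  // uses of the phi to the frozen value.
  const std::string frozenName = phiName + ".fr";
  int rewritten = 0;
  for (Inst &inst : bb) {
    if (inst.op != Op::SDiv && inst.op != Op::SRem)
      continue;
    for (std::string &operand : inst.operands) {
      if (operand == phiName) {
        operand = frozenName;
        ++rewritten;
      }
    }
  }

  // Step 4: materialize the freeze after the leading run of phis, since
  // phis have to stay grouped at the top of a block.
  if (rewritten > 0) {
    std::size_t pos = 0;
    while (pos < bb.size() && bb[pos].op == Op::Phi)
      ++pos;
    bb.insert(bb.begin() + static_cast<std::ptrdiff_t>(pos),
              Inst{Op::Freeze, frozenName, {phiName}});
  }
  return rewritten;
}
```

In this sketch, uses of the phi outside sdiv/srem are left untouched, so only the division and remainder operands see the frozen value.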
I'd rather have a lit test than the current test. But I'm open to having both.
I can work on a lit test, but the regression test is far more important as the concern is keeping the mask false path intact throughout the LLVM optimization pipeline.
Close #2726
From the code comments:
The Triton masked load pattern can generate instances where the mask value causes undefined behavior in sdiv/srem instructions. The language allows this UB as the result of those arithmetic instructions is never used, and control flow to avoid computation of these instructions would negatively affect performance. But, LLVM SimplifyCFG aggressively marks code paths with undefined behavior as dead. This can result in removal of the mask path and incorrect results from legal Triton kernels due to masked elements being used in computation. Run a pass to add a freeze instruction between masked loads and sdiv/srem to signal to LLVM we consider the sdiv/srem operands to be well defined.
The strategy here is to basically invalidate the assumptions under which SimplifyCFG can remove UB for sdiv/srem. The rationale is that, unlike C/C++, Triton explicitly allows UB in sdiv/srem instructions (likely because the hardware Triton is targeting allows that). Inserting a `freeze` instruction both signals that we expect the behavior of sdiv/srem to be well defined and hides the constant 0 in the phi from SimplifyCFG's UB optimizations.
The pass needs to run after every instance of `InstCombine` because the LLVM optimization that removes UB only occurs if the sdiv/srem are in the same BB as the phi, which can happen after any `InstCombine`.
Note that the directory structure for this pass is a little different than `BreakStructPhiNodesPass` because we are already using those directories in `third_party` for MLIR code. If we want to change that, I can open an issue but let's do it separately from this PR.