triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License
13.47k stars, 1.66k forks

[BACKEND] Fix ProgramPoint passing in AxisInfoAnalysis #5181

Closed: aakhundov closed this 4 days ago

aakhundov commented 5 days ago

Fixes #5122.

The ProgramPoint here is created on the stack. Its address is then passed to the MLIR SparseAnalysis code, where it is added as a dependency and later dereferenced. By the time the ProgramPoint is dereferenced in AbstractSparseForwardDataFlowAnalysis::visit, AxisInfoAnalysis::visitForOpInductionVar has already returned and the ProgramPoint stack variable has been destroyed. This leads to a segfault, which can be reproduced on the base revision with the lit test added in this PR.
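The lifetime problem can be illustrated with a small standalone C++ model. It does not use MLIR: `ProgramPoint`, `ToySolver`, `addDependency`, `visitDependents`, and the two visit functions are hypothetical stand-ins. To avoid actual undefined behavior, the toy solver records dependencies as `std::weak_ptr`, so it can detect that a point has died. The real code instead held a raw pointer and dereferenced it.

```cpp
#include <memory>
#include <vector>

// Hypothetical stand-in for mlir::ProgramPoint (not the MLIR type).
struct ProgramPoint {
  int id;
};

// Hypothetical solver that records dependencies now and visits them later.
// Dependencies are held as weak_ptr only so this demo can detect a dead
// point; the real code held a raw pointer and dereferenced it.
struct ToySolver {
  std::vector<std::weak_ptr<ProgramPoint>> dependents;

  void addDependency(const std::shared_ptr<ProgramPoint> &point) {
    dependents.push_back(point);
  }

  // Models the later visit: counts recorded points that are still alive.
  int visitDependents() const {
    int alive = 0;
    for (const auto &weak : dependents)
      if (weak.lock()) ++alive;
    return alive;
  }
};

// Buggy shape: the point lives only for the duration of this call, like the
// stack-allocated ProgramPoint in visitForOpInductionVar.
void visitWithLocalPoint(ToySolver &solver) {
  auto local = std::make_shared<ProgramPoint>(ProgramPoint{1});
  solver.addDependency(local);
}  // `local` is destroyed here; the recorded dependency now refers to a dead point.

// Fixed shape: the point is owned by longer-lived storage (standing in for
// state owned by the analysis/solver), so it survives until the visit.
void visitWithOwnedPoint(ToySolver &solver,
                         std::vector<std::shared_ptr<ProgramPoint>> &owned) {
  owned.push_back(std::make_shared<ProgramPoint>(ProgramPoint{2}));
  solver.addDependency(owned.back());
}
```

In this model, calling `visitWithLocalPoint` and then `visitDependents()` reports 0 live points, while `visitWithOwnedPoint` reports 1. In the real code the dependency is a raw `ProgramPoint*`, so the first case is a use-after-scope that crashes rather than a detectable expiry.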

The code modified in this PR was originally added in #4927, together with the update of the llvm-project hash to b5cc222d7429. However, as noted in https://github.com/llvm/llvm-project/pull/110344 (the llvm-project PR whose refactoring prompted the AxisInfo.cpp change in #4927):

For dense forward data-flow analysis and other analysis (except dense backward data-flow analysis), the program point corresponding to the original operation can be obtained by getProgramPointAfter(op)

As AxisInfoAnalysis (in Triton) inherits from SparseForwardDataFlowAnalysis (in MLIR), this PR follows the guidance above. That resolves the segfault, because the ProgramPoint is now stored in the instance-level state of the pass instead of on the stack.
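The fix relies on program points being owned and uniqued by long-lived state rather than by the visiting function. The sketch below models that ownership pattern in plain C++. It assumes a hypothetical `PointStore` whose `getProgramPointAfter` only mimics the role of the MLIR method of the same name. It is not MLIR's implementation.

```cpp
#include <cstddef>
#include <map>
#include <memory>

// Hypothetical stand-ins that only model the ownership pattern;
// they are not the MLIR classes.
struct Operation {};

struct ProgramPoint {
  const Operation *after;  // models "the point right after this op"
};

// The store owns every point it hands out, so a pointer obtained during one
// visit stays valid for as long as the store (i.e. the solver) lives.
class PointStore {
 public:
  // Mimics the role of getProgramPointAfter(op): the same op always yields
  // the same uniqued point.
  ProgramPoint *getProgramPointAfter(const Operation *op) {
    std::unique_ptr<ProgramPoint> &slot = points_[op];
    if (!slot) slot = std::make_unique<ProgramPoint>(ProgramPoint{op});
    return slot.get();
  }

  std::size_t size() const { return points_.size(); }

 private:
  std::map<const Operation *, std::unique_ptr<ProgramPoint>> points_;
};
```

Keying the storage by operation makes repeated requests return the same pointer. Because each point is heap-allocated behind a `std::unique_ptr`, its address remains stable regardless of the container. A pointer recorded as a dependency therefore stays valid for as long as the store lives.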

P.S. The lit test added in this PR is not exactly minimal. I did my best to minimize it, starting from the 400-line repro TTGIR in #5122, but further minimization no longer seems to expose the segfault.

aakhundov commented 5 days ago

@ThomasRaoux

should it be:

getLatticeElementFor(op, op.getLowerBound())->getValue();

IIUC, getLatticeElementFor here takes a ProgramPoint* as its first argument (it took a ProgramPoint before https://github.com/llvm/llvm-project/pull/110344), and a ProgramPoint* is not directly initializable from scf::ForOp. There is also getLatticeElement, which takes just a Value and returns the same thing as getLatticeElementFor. Do you mean we can use getLatticeElement here instead of getLatticeElementFor, as in the original code? My impression was that adding a dependency is intended.
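Here is a hypothetical toy model of the distinction under discussion. It is not the MLIR API. It contrasts a plain lattice lookup with a lookup that also registers a dependency. In MLIR's data-flow framework, a dependency appears to make the solver revisit the recorded program point whenever the queried lattice changes. That is also why the recorded `ProgramPoint*` must outlive the visit. For simplicity, the model uses integer IDs instead of pointers and values.

```cpp
#include <map>
#include <set>
#include <vector>

using ValueId = int;  // hypothetical stand-in for mlir::Value
using PointId = int;  // hypothetical stand-in for ProgramPoint*

class ToyAnalysis {
 public:
  // Like getLatticeElement: returns the state and records nothing.
  int getLatticeElement(ValueId v) { return lattice_[v]; }

  // Like getLatticeElementFor: returns the state AND records that `point`
  // must be revisited whenever the lattice of `v` changes.
  int getLatticeElementFor(PointId point, ValueId v) {
    dependents_[v].insert(point);
    return lattice_[v];
  }

  // Updates v's lattice value; if it changed, enqueues every dependent point.
  void update(ValueId v, int newValue) {
    int &current = lattice_[v];
    if (current == newValue) return;
    current = newValue;
    for (PointId p : dependents_[v]) worklist_.push_back(p);
  }

  const std::vector<PointId> &worklist() const { return worklist_; }

 private:
  std::map<ValueId, int> lattice_;
  std::map<ValueId, std::set<PointId>> dependents_;
  std::vector<PointId> worklist_;
};
```

With the plain lookup, a later update enqueues nothing. With the dependency-registering lookup, the update enqueues the recorded point, so the induction-variable information is recomputed when the loop bounds' lattice changes.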

ThomasRaoux commented 4 days ago

You're right; it looks like I was looking at an older MLIR version, where there was an implicit conversion from Operation to ProgramPoint. The code looks right.

bertmaher commented 4 days ago

Cherry-picked to rc/3.2.x, thanks for the fix @aakhundov