Closed xinyazhang closed 3 months ago
The following passes do not exist in the newer code: https://github.com/ROCm/triton/blob/9b73a543a5545960bcaf2830900b0560eec443c5/lib/Target/LLVMIR/LLVMIRTranslation.cpp#L481-L484
Maybe we can try adding them to the make_llir function to see whether that fixes the problem.
We have those two passes in upstream
The hang happens in the add_builtin_func_to_llvmir pass: https://github.com/triton-lang/triton/blob/fa2271e37f4e0ccfa8829501e40533060937cfe5/third_party/amd/backend/compiler.py#L185
I am working on this, because it looks (very) slightly simpler than https://github.com/ROCm/triton-internal/issues/104
This is what I got so far. The hang is in the `convert-builtin-func-to-llvm` pass. I took `bug.mlir` (attached to this comment), which is the file just before `convert-builtin-func-to-llvm`, and commented out all the loads and some of the stores. With that, `triton-opt` terminates, but the output produced is massively large. The more stores we add back into `bug.mlir`, the more time it takes to complete (I think that if we leave it long enough it will eventually complete). The time is spent in the `mergeIdenticalBlocks` transformation contained in the `simplifyRegion` utility. If I disable that transformation (`enableRegionSimplify=false`), compilation is quite quick. I am also attaching `disable_simplify.mlir`, the output that comes from `bug.mlir` when passing `enableRegionSimplify=false` to the rewriter. If we run `triton-opt --canonicalize disable_simplify.mlir`, we see the same massive output as before, and `triton-opt` takes some time to finish. Instead, if we run `triton-opt --canonicalize="region-simplify=false" disable_simplify.mlir`, the output is normal and `triton-opt` terminates quickly.
Another note that might help: on the repro script, if I break the nested if below on line 367 (by just deleting the else there), then it doesn't hang. Perhaps this is related to stores in nested-if statements.
```python
if q_padded:
    if PADDED_HEAD:
        tl.store(O_block_ptr, acc.to(Out.type.element_ty), boundary_check=(0, 1))
    else:
        tl.store(O_block_ptr, acc.to(Out.type.element_ty), boundary_check=(0,))
else:
    if PADDED_HEAD:
        tl.store(O_block_ptr, acc.to(Out.type.element_ty), boundary_check=(1,))
    else:
        tl.store(O_block_ptr, acc.to(Out.type.element_ty))
```
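As a rough illustration of why these boundary-checked stores matter (a toy model I'm assuming for illustration, not Triton's actual lowering), each masked element of a predicated store becomes a conditional branch guarding a small store block plus a merge block, so the CFG grows linearly with the number of predicated accesses and hands the canonicalizer a large pile of structurally similar blocks to try to merge:

```python
# Toy model (assumption, not Triton's real lowering): each predicated
# scalar store expands to a guard pattern like
#     cond_br %mask, ^store_bb, ^merge_bb
# so every masked element contributes one store block and one merge block.

def predicated_store_blocks(num_elements: int) -> int:
    """Estimate CFG blocks added by lowering one masked vector store."""
    blocks_per_element = 2  # one store block + one merge block per element
    return num_elements * blocks_per_element

def kernel_cfg_blocks(num_stores: int, elements_per_store: int) -> int:
    """Blocks contributed by all predicated stores in a kernel."""
    return num_stores * predicated_store_blocks(elements_per_store)

# e.g. four fully-predicated 128-element store tiles:
print(kernel_cfg_blocks(num_stores=4, elements_per_store=128))  # 1024
```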
I guess the comment here basically confirms that:
# This pass (`add_builtin_func_to_llvmir`) serves as a temporary workaround to address the issue of excessive basic block
# count caused by predicated loads/stores. In certain kernels, the addition of these blocks can cause the MLIR
# canonicalizer to never finish when attempting to merge blocks. The permanent solution under consideration
# involves using MUBUF instructions that have built-in out-of-bounds checks, which would eliminate the need
# for conditional branching around memory accesses.
So yes, I was aware of this comment, but @antiagainst was asking if there could be a simpler solution than implementing buffer loads. I guess the main question is: is this a bug, or is it unavoidable because of so many blocks?
Should you, me, and @antiagainst have a chat to decide the best step forward?
So this is the situation we have in the CFG (cc @antiagainst ):
I think the problem is that MLIR is trying to produce a single big `if` block into which to put all those subgraphs.
So, all in all, I think this is a correct transformation, even in our case. What happens is that we hit the following cases:
leader block:
```mlir
^bb152:  // pred: ^bb151
  llvm.store %3316, %3245 : i16, !llvm.ptr<1>
  llvm.br ^bb153
```
blocks to merge:
```mlir
^bb181:  // pred: ^bb180
  "llvm.store"(%3409, %3375) <{ordering = 0 : i64}> : (i16, !llvm.ptr<1>) -> ()
  "llvm.br"()[^bb153] : () -> ()
```
In this case those blocks can be merged, and the merged block will have +2 operands
leader block:
```mlir
^bb151:  // 2 preds: ^bb149, ^bb150
  %3315 = llvm.insertelement %3148, %3251[%60 : i32] : vector<1xf16>
  %3316 = llvm.bitcast %3315 : vector<1xf16> to i16
  llvm.cond_br %3254, ^bb152(%3316, %3245 : i16, !llvm.ptr<1>), ^bb153
```
blocks to merge:
```mlir
^bb180:  // 2 preds: ^bb178, ^bb179
  %3412 = "llvm.insertelement"(%3385, %3336, %60) : (vector<1xf16>, f16, i32) -> vector<1xf16>
  %3413 = "llvm.bitcast"(%3412) : (vector<1xf16>) -> i16
  "llvm.cond_br"(%3384, %3413, %3379)[^bb152, ^bb153] <{operandSegmentSizes = array<i32: 1, 2, 0>}> : (i1, i16, !llvm.ptr<1>) -> ()
```
In this case the blocks are still structurally similar, but we are doubling the number of input operands of the merged block. When we do that 64 times, we get to blocks that have 32764 input operands, which is very slow to handle.
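To see why this blows up so fast, here is a toy model of the mechanism (my own sketch, not MLIR's actual `mergeIdenticalBlocks` code): each round of merging structurally identical predecessors turns the values that differ between them into block arguments, roughly doubling the merged block's input-operand count per round.

```python
# Toy model (assumption): every merge round doubles the merged block's
# argument count, because the values that differ between the merged
# blocks all become new block arguments.

def merge_rounds_until(threshold: int, start_args: int = 2) -> int:
    """Count doubling rounds until the argument count exceeds `threshold`."""
    args = start_args
    rounds = 0
    while args <= threshold:
        args *= 2
        rounds += 1
    return rounds

# Starting from 2 operands, only 14 doublings are needed to exceed the
# 32764 input operands mentioned above.
print(merge_rounds_until(32764))  # 14
```

Under this model the operand count is exponential in the number of merge rounds, so even a modest CFG reaches tens of thousands of block arguments quickly.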
We can introduce a threshold: don't merge the blocks if the merge results in more than K (defaulted to 16?) input operands in the resulting block.
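A minimal sketch of that guard (hypothetical names; the real change would live inside MLIR's block-merging logic, not in Python):

```python
# Hypothetical sketch of the proposed threshold check. In the real fix
# this would sit inside MLIR's mergeIdenticalBlocks; the option name
# max_block_arguments is made up here, following the discussion above.

def can_merge(existing_args: int, divergent_values: int,
              max_block_arguments: int = 16) -> bool:
    """Allow the merge only if the merged block's input-operand count
    (existing block arguments plus one new argument per value that
    differs between the blocks) stays within the threshold."""
    return existing_args + divergent_values <= max_block_arguments

# The ^bb152/^bb181 pair above: 0 existing args + 2 divergent values -> merge OK.
print(can_merge(0, 2))   # True
# A block that already carries 16 args and would gain 2 more -> refuse.
print(can_merge(16, 2))  # False
```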
A threshold is a good idea, I think. Where would we implement the threshold? In the canonicalizer?
Yes, we can have an option like maxBlockArguments in the canonicalizer pass, defaulted to 16. I tried to hardcode that and indeed it works fine. I will try to upstream a patch.
I also want to underline that by not merging those blocks we are creating very branchy code that will probably be quite slow. So once I implement this, I will try to finish the buffer_load implementation.
Yup agreed that having a threshold in the greedy pattern rewriter configuration to control this would be good. Once you have the patch to mlir please add me as a reviewer.
They were faster than me :) : https://github.com/llvm/llvm-project/pull/95057
Not sure if the threshold solution is better or not, but I commented on the PR instead of creating a different one
(note, once the PR is merged, we should upgrade Triton commit to get the change)
@giuseros Have you verified that the upstream PR will address both use cases, this ticket and https://github.com/ROCm/triton-internal/issues/104?
Yes, it disables block merging during canonicalization, which is the root cause of both.
Update on this: they made a further change (or the change was already there and slipped past me) whereby they now enable block merging in the rewriter. If they stick with that, we will still have the hang (see https://github.com/llvm/llvm-project/pull/95057#discussion_r1636389533).
Either we convince them to disable merging in the rewriter, or I will (urgently) have to implement this:
After thinking about this, I guess we can set:
```cpp
GreedySimplifyRegionLevel enableRegionSimplification = GreedySimplifyRegionLevel::Normal;
```
When we instantiate the rewriter. And meanwhile I can work on https://github.com/llvm/llvm-project/issues/63230 to solve the core issue.
I was trying to follow your discussion with Mehdi on that upstream PR. What does it mean that "they disable block merging for canonicalization but enable it for the rewriter"?
Both the canonicalize pass and the rewriter use the simplifyRegions function. The solution Mehdi is proposing is to default to simplifyRegions(normal) in the canonicalize pass (block merging disabled) and simplifyRegions(aggressive) in the rewriter (block merging enabled -> hang). We can change the default behaviour of the rewriter in Triton (so that it calls simplifyRegions(normal)), but this means passing a config every time we invoke it (with config.enableRegionSimplification = Normal).
Does the rewriter call simplifyRegions (and probably other passes to canonicalize stuff) after it matches and rewrites all the ops?
> We can change the default behaviour of the rewriter in Triton (so that it calls simplifyRegions(normal)), but this means passing a config every time we invoke it (with config.enableRegionSimplification = Normal).
We only need to set it for the rewriter in the builtin_func_to_llvm pass, right? If so, are there any other drawbacks?
> Does the rewriter call simplifyRegions (and probably other passes to canonicalize stuff) after it matches and rewrites all the ops?
Yes
> We only need to set it for the rewriter in the builtin_func_to_llvm pass, right? If so, are there any other drawbacks?
And any time we invoke the rewriter after that. I see that builtin_func_to_llvm is the last pass, so it shouldn't be an issue.
Of course there is the core drawback that we will disable block merging in all cases. But this is something we can worry about later (and I will try to work on it in my "spare" time).
Problem Description
Full source code to reproduce: rep.py.gz
Triton version: upstream
d688063f731cfc4d9431bb8c0d0d73dce8cd1c38
Docker Container: rocm/pytorch-private:compute-rocm-rel-6.1-116_ubuntu22.04_py3.9_pytorch_rocm6.1_internal_testing_ae01701
Can be reproduced on both MI200 (gfx90a) and Navi3x. Debug prints show the compiler hangs during the ttgir->llir stage.
Operating System
Ubuntu 22.04.4 LTS (Jammy Jellyfish)
CPU
AMD Ryzen Threadripper PRO 5975WX 32-Cores
GPU
AMD Instinct MI210
ROCm Version
ROCm 6.1.0
ROCm Component
No response
Steps to Reproduce
Download the rep.py.gz in the Description section, and then
(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support
No response
Additional Information
No response