ROCm / triton

Development repository for the Triton language and compiler
MIT License

[Issue]: Triton Compiler Takes Indefinite Time in ttgir -> llir Stage. #596

Closed by xinyazhang 3 months ago

xinyazhang commented 4 months ago

Problem Description

Full source code to reproduce: rep.py.gz

Triton version: upstream d688063f731cfc4d9431bb8c0d0d73dce8cd1c38
Docker container: rocm/pytorch-private:compute-rocm-rel-6.1-116_ubuntu22.04_py3.9_pytorch_rocm6.1_internal_testing_ae01701

Can be reproduced on both MI200 (gfx90a) and Navi3x. Debug prints show the compiler hangs during the ttgir->llir stage.

Operating System

Ubuntu 22.04.4 LTS (Jammy Jellyfish)

CPU

AMD Ryzen Threadripper PRO 5975WX 32-Cores

GPU

AMD Instinct MI210

ROCm Version

ROCm 6.1.0

ROCm Component

No response

Steps to Reproduce

Download rep.py.gz from the Description section, and then:

    gunzip rep.py.gz
    python rep.py

(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support

No response

Additional Information

No response

xinyazhang commented 4 months ago

The following passes do not exist in the newer code https://github.com/ROCm/triton/blob/9b73a543a5545960bcaf2830900b0560eec443c5/lib/Target/LLVMIR/LLVMIRTranslation.cpp#L481-L484

Maybe we can try adding them to the make_llir function to see whether that fixes the problem.

zhanglx13 commented 3 months ago

Update 06/05/2024

We have those two passes in upstream

The hanging happens in the add_builtin_func_to_llvmir pass: https://github.com/triton-lang/triton/blob/fa2271e37f4e0ccfa8829501e40533060937cfe5/third_party/amd/backend/compiler.py#L185

giuseros commented 3 months ago

I am working on this, because it looks (very) slightly simpler than https://github.com/ROCm/triton-internal/issues/104

This is what I got so far:

repro.zip

jayfurmanek commented 3 months ago

Another note that might help: on the repro script, if I break up the nested if below at line 367 (by just deleting the else there), then it doesn't hang. Perhaps this is related to stores in nested if statements.

    if q_padded:
        if PADDED_HEAD:
            tl.store(O_block_ptr, acc.to(Out.type.element_ty), boundary_check=(0,1))
        else:
            tl.store(O_block_ptr, acc.to(Out.type.element_ty), boundary_check=(0,))
    else:
        if PADDED_HEAD:
            tl.store(O_block_ptr, acc.to(Out.type.element_ty), boundary_check=(1,))
        else:
            tl.store(O_block_ptr, acc.to(Out.type.element_ty))
jayfurmanek commented 3 months ago

I guess the comment here basically confirms that:

    # This pass (`add_builtin_func_to_llvmir`) serves as a temporary workaround to address the issue of excessive basic block
    # count caused by predicated loads/stores. In certain kernels, the addition of these blocks can cause the MLIR
    # canonicalizer to never finish when attempting to merge blocks. The permanent solution under consideration
    # involves using MUBUF instructions that have built-in out-of-bounds checks, which would eliminate the need
    # for conditional branching around memory accesses.
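To make the failure mode concrete, here is a rough back-of-the-envelope model (a hypothetical sketch, not the actual lowering in the Triton AMD backend) of why predicating each scalar access balloons the basic-block count:

```python
def blocks_for_predicated_accesses(num_accesses: int) -> int:
    """Rough model: each predicated load/store is wrapped in a guard
    block ('access only if in bounds') plus a continuation block, on
    top of a single entry block. Hypothetical helper for illustration;
    the real lowering is in MLIR's LLVM dialect conversion."""
    return 1 + 2 * num_accesses

# A tile with 64 boundary-checked scalar stores yields 129 blocks under
# this model, and the canonicalizer then tries to merge many of them.
print(blocks_for_predicated_accesses(64))
```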
giuseros commented 3 months ago

So yes, I was aware of this comment, but @antiagainst was asking if there could be a simpler solution than implementing buffer loads. I guess the main question is: is this a bug, or is it unavoidable because of the sheer number of blocks?

Should you, @antiagainst, and I have a chat to decide the best step forward?

giuseros commented 3 months ago

So this is the situation we have in the CFG (cc @antiagainst): [image attachment]

I think the problem is that MLIR is trying to produce a single big if-block into which to put all those subgraphs.

giuseros commented 3 months ago

So, all in all, I think this is a correct transformation, even in our case. What happens is that we hit the following cases:

Store case

leader block:

    ^bb152:  // pred: ^bb151
      llvm.store %3316, %3245 : i16, !llvm.ptr<1>
      llvm.br ^bb153

blocks to merge:

    ^bb181:  // pred: ^bb180
      "llvm.store"(%3409, %3375) <{ordering = 0 : i64}> : (i16, !llvm.ptr<1>) -> ()
      "llvm.br"()[^bb153] : () -> ()

In this case those blocks can be merged, and the merged block gains 2 operands.

Insertelement case

leader block:

    ^bb151:  // 2 preds: ^bb149, ^bb150
      %3315 = llvm.insertelement %3148, %3251[%60 : i32] : vector<1xf16>
      %3316 = llvm.bitcast %3315 : vector<1xf16> to i16
      llvm.cond_br %3254, ^bb152(%3316, %3245 : i16, !llvm.ptr<1>), ^bb153

blocks to merge:

    ^bb180:  // 2 preds: ^bb178, ^bb179
      %3412 = "llvm.insertelement"(%3385, %3336, %60) : (vector<1xf16>, f16, i32) -> vector<1xf16>
      %3413 = "llvm.bitcast"(%3412) : (vector<1xf16>) -> i16
      "llvm.cond_br"(%3384, %3413, %3379)[^bb152, ^bb153] <{operandSegmentSizes = array<i32: 1, 2, 0>}> : (i1, i16, !llvm.ptr<1>) -> ()

In this case the blocks are still structurally similar, but merging doubles the number of input operands of the merged block. When that happens 64 times, we end up with blocks that have 32764 input operands, which is very slow to handle.
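The blow-up follows from simple arithmetic: if every merge of two structurally similar blocks doubles the surviving block's argument list, growth is exponential. A minimal sketch (an idealized model; the exact counts in the real IR differ slightly from these round numbers):

```python
def args_after_merges(initial_args: int, merges: int) -> int:
    """Block-argument count if each merge doubles the operand list."""
    return initial_args * 2 ** merges

# Starting from the 2 operands of the store case above, a handful of
# merge rounds is enough to reach tens of thousands of block arguments.
for k in (1, 4, 8, 14):
    print(k, args_after_merges(2, k))
```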

Possible (quick) workaround

We can introduce a threshold: don't merge the blocks if doing so results in more than K (defaulted to 16?) input operands in the resulting block.
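The proposed threshold check could be sketched like this (a minimal Python model with hypothetical names; the real change would live in MLIR's C++ block-merging logic):

```python
MAX_BLOCK_ARGUMENTS = 16  # the proposed default K

def should_merge_blocks(merged_arg_count: int,
                        limit: int = MAX_BLOCK_ARGUMENTS) -> bool:
    """Decline a block merge when the resulting block would carry more
    than `limit` block arguments (the slowdown described above)."""
    return merged_arg_count <= limit

print(should_merge_blocks(2))      # simple store case: merge is fine
print(should_merge_blocks(32764))  # pathological case: keep blocks separate
```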

jayfurmanek commented 3 months ago

A threshold is a good idea, I think. Where would we implement the threshold? In the canonicalizer?

giuseros commented 3 months ago

Yes, we can have an option like maxBlockArguments in the canonicalizer pass, defaulted to 16. I tried to hardcode that and indeed it works fine. I will try to put up a patch.

I also want to underline that by not merging those blocks we are creating very branchy code that will probably be quite slow. So once I implement this, I will try to finish the buffer_load implementation.

antiagainst commented 3 months ago

Yup, agreed that having a threshold in the greedy pattern rewriter configuration to control this would be good. Once you have the patch to MLIR, please add me as a reviewer.

giuseros commented 3 months ago

They were faster than me :) : https://github.com/llvm/llvm-project/pull/95057

Not sure whether the threshold solution is better, but I commented on that PR instead of creating a different one.

(Note: once the PR is merged, we should bump the Triton commit to pick up the change.)

jerryyin commented 3 months ago

@giuseros Have you verified that the upstream PR will address both use cases, this ticket and https://github.com/ROCm/triton-internal/issues/104?

giuseros commented 3 months ago

Yes, it disables block merging during canonicalization, which is the root cause of both.

giuseros commented 3 months ago

Update on this: they made a further change (or the change was there and escaped my eye) by which they now enable block merging in the rewriter. If they stick with that, we will still have the hang (see https://github.com/llvm/llvm-project/pull/95057#discussion_r1636389533).

Either we convince them to disable merging in the rewriter, or I will (urgently) have to finish the buffer_load implementation.

giuseros commented 3 months ago

After thinking about this, I guess we can set:

    GreedySimplifyRegionLevel enableRegionSimplification = GreedySimplifyRegionLevel::Normal;

when we instantiate the rewriter. Meanwhile, I can work on https://github.com/llvm/llvm-project/issues/63230 to solve the core issue.

zhanglx13 commented 3 months ago

I was trying to follow your discussion with Mehdi on that upstream PR. What is meant by "they disable block merging for canonicalization but enable it for the rewriter"?

giuseros commented 3 months ago

Both the canonicalize pass and the rewriter use the simplifyRegions function. The solution Mehdi is proposing is to default to simplifyRegions(normal) in the canonicalize pass (block merging disabled) and simplifyRegions(aggressive) in the rewriter (block merging enabled, hence the hang). We can change the default behaviour of the rewriter in Triton so that it calls simplifyRegions(normal), but this means passing a config every time we invoke it (with config.enableRegionSimplification = Normal).
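The distinction can be modeled as follows (a Python sketch mirroring MLIR's `GreedySimplifyRegionLevel`; the real API is C++, and this toy predicate is purely illustrative):

```python
from enum import Enum

class GreedySimplifyRegionLevel(Enum):
    DISABLED = 0
    NORMAL = 1      # region simplification without block merging
    AGGRESSIVE = 2  # also merges structurally identical blocks

def blocks_may_merge(level: GreedySimplifyRegionLevel) -> bool:
    """Block merging (the source of the hang) only runs at AGGRESSIVE."""
    return level is GreedySimplifyRegionLevel.AGGRESSIVE

# Proposed defaults from the upstream discussion:
canonicalize_default = GreedySimplifyRegionLevel.NORMAL   # merging off
rewriter_default = GreedySimplifyRegionLevel.AGGRESSIVE   # merging on -> hang risk

print(blocks_may_merge(canonicalize_default))  # False
print(blocks_may_merge(rewriter_default))      # True
```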

zhanglx13 commented 3 months ago

Does the rewriter call simplifyRegions (and probably other canonicalization passes) after it matches and rewrites all the ops?

We can change the default behaviour of the rewriter in Triton so that it calls simplifyRegions(normal), but this means passing a config every time we invoke it (with config.enableRegionSimplification = Normal).

We only need to set it for the rewriter in the builtin_func_to_llvm pass, right? If so, are there any other drawbacks?

giuseros commented 3 months ago

Does the rewriter call simplifyRegions (and probably other passes to canonicalize stuff) after it matches and rewrites all the ops?

Yes

We only need to set it for the rewriter in the builtin_func_to_llvm pass, right? If so, are there any other drawbacks?

And any time we invoke the rewriter after that. I see that builtin_func_to_llvm is the last pass, so it shouldn't be an issue.

Of course there is the core drawback that we will disable block merging in all cases. But this is something we can worry about later (and I will try to work on it in my "spare" time).