microsoft / triton-shared

Shared Middle-Layer for Triton Compilation
MIT License

[Potential Bug]: sanity check addState in handling modulos #80

Closed: haishanzzzz closed this issue 9 months ago

haishanzzzz commented 9 months ago

Triton python code

No response

Triton IR

No response

Crash log

No response

Additional information

My understanding of the modulo definition is that mod happens last in the address generation. E.g., for a 1D tensor of pointers, the addresses it generates are: (offsets[0] + i * strides[0]) % parent_sizes[0] for each i. Does that sound right to you?
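
Here is the definition above as a small reference model: the offset and the scaled index are added first, and the modulo is applied last. This is only a sketch. The function name is hypothetical, and the parameters simply mirror offsets[0], strides[0] and parent_sizes[0] from the comment.

```python
from typing import List


def modulo_addresses_1d(offset: int, stride: int, parent_size: int,
                        size: int) -> List[int]:
    # Hypothetical reference model of the 1D case: add the offset and the
    # scaled index first, then apply the modulo last.
    return [(offset + i * stride) % parent_size for i in range(size)]


print(modulo_addresses_1d(2, 1, 5, 4))  # [2, 3, 4, 0]
```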

If so, my question is about the handling of modulo in addState. Should we check that only one of the operands has a modulo set? It looks like we perform a related check in visitOperandAdd, but that check is limited to rank-1 tensors. Can you please explain the reasoning?

If my understanding is correct, I believe the pattern in this LIT test should fail, because we are adding a scalar (represented as an offset in dimension 0) to %arg9 / %15, which has a modulo in dimension 0.

nhat-nguyen commented 9 months ago

Happy new year Haishan! It's good you're lending another pair of eyes at this code again as I'm still not happy with the amount of complexity that the feature has :)

If so, my question is about the handling of modulo in addState. Should we check that only one of the operands has a modulo set?

I agree. We have an assert here in addState:

https://github.com/microsoft/triton-shared/blob/3230526f95f7e73f1885b9c9d008b1aec535f2d2/lib/Analysis/PtrAnalysis.cpp#L102C1-L105C1

My intention is this should assert that for every dimension, only one of the operands has a modulo expression. This should match your expectation I think.

It looks like we perform a related check in visitOperandAdd, but that check is limited to rank-1 tensors. Can you please explain the reasoning?

I think the assert here might look a little misleading at first glance:

if ((lhsState.getRank() == 1 && lhsState.hasModulo()) ||
    (rhsState.getRank() == 1 && rhsState.hasModulo())) {
  assert(0 && "Current do not support this pattern: a + arange(0, K) % M");
}

This reads as follows: if we're rank 1 and either lhs or rhs already has a modulo expression, then we must be dealing with something resembling the pattern a + (arange(0, K) % M), where lhs is a and rhs is arange(0, K) % M, which has a modulo.

The pattern we're looking for, like you said, is (offsets[0] + i * strides[0]) % parent_sizes[0] for each i; the add is between offsets[0] and i * strides[0], neither of which involves a modulo expression.
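
To see why a + (arange(0, K) % M) is singled out, evaluate it next to the supported form (a + arange(0, K)) % M. The helper names and the values a = 2, K = 4, M = 3 are arbitrary choices for illustration.

```python
from typing import List


def mod_last(a: int, k: int, m: int) -> List[int]:
    # Supported shape: the add happens first, the modulo is applied last.
    return [(a + i) % m for i in range(k)]


def mod_first(a: int, k: int, m: int) -> List[int]:
    # Shape rejected by the rank-1 assert: a + (arange(0, K) % M).
    return [a + (i % m) for i in range(k)]


print(mod_last(2, 4, 3))   # [2, 0, 1, 2]
print(mod_first(2, 4, 3))  # [2, 3, 4, 2]
```

For these values, no choice of offset, stride and parent size of the form (offset + i * stride) % parent_size reproduces the second list, so the modulo-first pattern does not fit the modulo-last model discussed in this thread.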

In any other case, including when we're rank 2, we just skip over the assert and let addState verify that there is at most one modulo expression in each dimension.
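
The per-dimension rule can be sketched with a simplified state. Everything here is hypothetical: this PtrState keeps only offsets, strides and an optional modulo per dimension, which is much less than triton-shared's real PtrState. It only illustrates the "at most one modulo per dimension" check.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class PtrState:
    # Hypothetical, simplified stand-in: one entry per dimension.
    offsets: List[int]
    strides: List[int]
    modulos: List[Optional[int]]  # None means "no modulo on this dimension"


def add_state(lhs: PtrState, rhs: PtrState) -> PtrState:
    assert len(lhs.offsets) == len(rhs.offsets), "ranks must match"
    modulos: List[Optional[int]] = []
    for dim, (lm, rm) in enumerate(zip(lhs.modulos, rhs.modulos)):
        # At most one operand may carry a modulo on any given dimension.
        assert lm is None or rm is None, f"both operands have a modulo on dim {dim}"
        modulos.append(lm if lm is not None else rm)
    return PtrState(
        offsets=[a + b for a, b in zip(lhs.offsets, rhs.offsets)],
        strides=[a + b for a, b in zip(lhs.strides, rhs.strides)],
        modulos=modulos,
    )


# A broadcast value with no modulo plus a state with a modulo on dimension 0.
plain = PtrState(offsets=[4, 0], strides=[0, 0], modulos=[None, None])
wrapped = PtrState(offsets=[2, 3], strides=[1, 1], modulos=[5, None])
print(add_state(plain, wrapped).modulos)  # [5, None]
```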

I believe the pattern in this LIT test should fail, because we are adding a scalar (represented as an offset in dimension 0) to %arg9 / %15, which has a modulo in dimension 0.

So the addptr we're interested in is

    %15 = tt.addptr %14, %13 : tensor<4x4x!tt.ptr<f32>>, tensor<4x4xi32>

%13 is the 2D tensor that contains a modulo expression in dimension 0, and %14 is the offset you're referring to.

%14 is just a broadcast of %arg0, so it has no modulo states for its two dimensions. %13 has a modulo state in dimension 0. So we satisfy the conditions in both addState and visitOperandAdd.

haishanzzzz commented 9 months ago

Happy New Year to you too Nhat! Thank you for your detailed answer!

In any other case, including when we're rank 2, we just skip over the assert and let addState verify that there is at most one modulo expression in each dimension.

Ah thanks. That makes sense to me. I missed that we also have a check in addState.

So the addptr we're interested in is...

Sorry I should have been more clear. The addptr I am referring to is:

%33 = tt.addptr %arg9, %30 : tensor<4x4x!tt.ptr<f32>>, tensor<4x4xi32>

Where %30 is a scalar, which I believe is represented as an offset on dim0. Should the analysis fail on this pattern?

nhat-nguyen commented 9 months ago

Sorry I should have been more clear. The addptr I am referring to is:

%33 = tt.addptr %arg9, %30 : tensor<4x4x!tt.ptr<f32>>, tensor<4x4xi32>

Where %30 is a scalar, which I believe is represented as an offset on dim0. Should the analysis fail on this pattern?

Ah no, the analysis should not fail on this pattern. %30 is a scalar and does not have any modulo expression involved, so as per the current code, we can take the modulo state directly from %arg9 during addState.

Reading the IR is hard, so here's the rough Triton code that I used to generate this lit test; hopefully it will make more sense:

import triton
import triton.language as tl


@triton.jit
def wrap_stacked_masked_loop(
    a_ptr, c_ptr, M, N, stride_am, stride_an, stride_cm, stride_cn
):
    BLOCK_SIZE_K = 4
    # Rows wrap around modulo M (the modulo is on dimension 0); columns start at 3.
    offs_am = (2 + tl.arange(0, 4)) % M
    offs_an = 3 + tl.arange(0, 4)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_an[None, :] * stride_an)

    offs_cm = tl.arange(0, 4)
    offs_cn = tl.arange(0, 4)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]

    for k in range(0, 2):
        a = tl.load(a_ptrs)
        tl.store(c_ptrs, a)
        # Move both blocks horizontally (along the columns) by BLOCK_SIZE_K.
        a_ptrs += BLOCK_SIZE_K * stride_an
        c_ptrs += BLOCK_SIZE_K * stride_cn
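
The kernel's index arithmetic for a can also be replayed in plain Python, without Triton. This sketch assumes a row-major a with M = 5 rows and a row stride of 16 (so stride_an = 1). Those values and the helper name are illustrative assumptions.

```python
from typing import List


def a_block_offsets(k: int, M: int, stride_am: int, stride_an: int,
                    block_size_k: int = 4) -> List[List[int]]:
    # Rows are (2 + i) % M (modulo on dimension 0); columns start at 3.
    rows = [(2 + i) % M for i in range(4)]
    cols = [3 + j for j in range(4)]
    # After k iterations, a_ptrs has been advanced by k * BLOCK_SIZE_K * stride_an.
    shift = k * block_size_k * stride_an
    return [[r * stride_am + c * stride_an + shift for c in cols] for r in rows]


for k in range(2):
    print(k, a_block_offsets(k, M=5, stride_am=16, stride_an=1))
# 0 [[35, 36, 37, 38], [51, 52, 53, 54], [67, 68, 69, 70], [3, 4, 5, 6]]
# 1 [[39, 40, 41, 42], [55, 56, 57, 58], [71, 72, 73, 74], [7, 8, 9, 10]]
```

In both iterations, the last row of the block comes from row 0 of a. This is the stacked wraparound shown in the visualization that follows.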

Since this is a "stacked wraparound" case, here's the visualization for the above code:

                     cols
             wrappedAroundOff
    --------------*------------*-------------------------------------------
    |        d2   |            |                         |            |   |
    |             |------------|                         |------------|   |
rows|                                                                     |
    |                                                                     |
    |           targetOffset                                              |
    |             *------------|                         *------------|   |
    |             |            |                         |            |   |
    |         d1  |            |   we move along this    |            |   |
    |             |            |          dim            |            |   |
    |             | clampedOff |  -------------------->  |            |   |
    --------------*--------------------------------------------------------
                  |  overflow  |
                  *-------------
               nextOff

The triton modulo patterns that we have seen always take the modulo on one dimension and move along the other dimension.
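
The labels in the diagram suggest how the two chunks can be derived from a single collapsed target offset. The sketch below is one reading of that picture, not the code in createStackedCastOps. It assumes a row-major layout where stride_row is the row stride and the column offset within a row is smaller than stride_row. The modulo covers mod_rows rows, and the block has block_rows rows. All names are hypothetical.

```python
from typing import List, Tuple


def stacked_split(target_offset: int, stride_row: int, mod_rows: int,
                  block_rows: int) -> Tuple[int, int, int, int]:
    # Same column, but in row 0: where the wrapped chunk restarts ("wrappedAroundOff").
    wrapped_around_off = target_offset % stride_row
    # Same column, one row past the last valid row ("clampedOff").
    clamped_off = mod_rows * stride_row + wrapped_around_off
    # Rows that fit before the wraparound (d1) and rows that wrap (d2).
    d1 = min((clamped_off - target_offset) // stride_row, block_rows)
    d2 = block_rows - d1
    return wrapped_around_off, clamped_off, d1, d2


def stacked_offsets(target_offset: int, stride_row: int, stride_col: int,
                    mod_rows: int, block_rows: int, block_cols: int) -> List[List[int]]:
    wrapped_around_off, _, d1, d2 = stacked_split(
        target_offset, stride_row, mod_rows, block_rows)
    first = [[target_offset + i * stride_row + j * stride_col
              for j in range(block_cols)] for i in range(d1)]
    second = [[wrapped_around_off + i * stride_row + j * stride_col
               for j in range(block_cols)] for i in range(d2)]
    return first + second


# Block starting at row 2, column 3 of a 5-row tensor with row stride 16.
print(stacked_split(35, stride_row=16, mod_rows=5, block_rows=4))  # (3, 83, 3, 1)
```

With these numbers, three rows (2, 3, 4) come before the wrap and one row restarts at row 0. Moving the block horizontally only changes the column part of the target offset, so the same split applies to the next block.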

So in %33 = tt.addptr %arg9, %30 : tensor<4x4x!tt.ptr<f32>>, tensor<4x4xi32>, because %30 is just a scalar with no modulo information (but %arg9 does have it), we can reuse %arg9's modulo info and generate the same "stacked wraparound" offsets for the next blocks.

I hope this makes more sense; otherwise we can discuss this offline too. I have also recently added more modulo test cases at https://github.com/microsoft/triton-shared/blob/main/python/examples/test_modulo.py. These can run on the CPU backend that Shintaro contributed, so you can play around with them as well!

haishanzzzz commented 9 months ago

Thank you Nhat. Your explanation makes sense to me.

The only thing I want to make sure is that in your diagram, you are showing that the addptr performed in creating %33 moves the tensor horizontally, but in the IR %30 is a scalar and afaik the offset is recorded in dim0 (since at compiler level we can't tell which dimension the scalar value is incrementing, I believe the current code defaults to using dim0). Would that mean we are shifting vertically?

nhat-nguyen commented 9 months ago

I think there are a few key points that explain how this works: we are actually shifting horizontally, as illustrated in the diagram, even though the offset is recorded in dim0. We expect "stacked wraparound" to always shift horizontally and "side-by-side wraparound" to always shift vertically. (Of course, the user can decide not to increment the blocks the way we expect; unfortunately, we can't really detect this and will produce nonsensical codegen in such cases.)

1) When we rewrite the addptr, we determine which scenario (stacked or side-by-side) we are dealing with:

https://github.com/microsoft/triton-shared/blob/3230526f95f7e73f1885b9c9d008b1aec535f2d2/lib/Analysis/PtrAnalysis.cpp#L784C1-L793C1

2) Then, in both

createStackedCastOps https://github.com/microsoft/triton-shared/blob/3230526f95f7e73f1885b9c9d008b1aec535f2d2/lib/Analysis/PtrAnalysis.cpp#L148

and createSideBySideCastOps https://github.com/microsoft/triton-shared/blob/3230526f95f7e73f1885b9c9d008b1aec535f2d2/lib/Analysis/PtrAnalysis.cpp#L241

the formulas we use to derive the wraparound point depend only on a single pointer offset (i.e., all dimensional offsets have already been collapsed, so it does not matter that we default to dim0 for scalars):

  Value targetOffset =
      ofrToIndexValue(accumulateTargetOffset(loc, rewriter), loc, rewriter);
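
Assume each dimension's offset is already scaled by its stride. Then collapsing is just a sum, so recording a scalar under dim0 or dim1 gives the same target offset. The Python function below is a hypothetical model named after accumulateTargetOffset, not the C++ implementation, and the "already scaled" premise is an assumption of this sketch.

```python
from typing import List


def accumulate_target_offset(base: int, offsets: List[int]) -> int:
    # Hypothetical model: each per-dimension offset is assumed to already be in
    # element units (scaled by its stride), so collapsing is a plain sum.
    return base + sum(offsets)


scalar = 4
row_and_col = [2 * 8, 3]  # row 2 with row stride 8, column 3
in_dim0 = accumulate_target_offset(0, [row_and_col[0] + scalar, row_and_col[1]])
in_dim1 = accumulate_target_offset(0, [row_and_col[0], row_and_col[1] + scalar])
print(in_dim0, in_dim1)  # 23 23
```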

This is also the reason why we have the following limitation:

  //////////////////////////////////////////////////////////////////////////////
  //
  // Handling side-by-side wraparound
  //
  // Note: We do not support cases where the target has already overflown the
  // number of columns! This is because in PtrAnalysis, the offset has already
  // been collapsed into a single dimension, so it is ambiguous to determine
  // whether the offset actually overflows or just refers to an element on the
  // subsequent rows.
  //
  // Same limitations apply to the stacked wraparound case.
  //
  //////////////////////////////////////////////////////////////////////////////
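
The ambiguity the comment describes is easy to reproduce. Once offsets are collapsed, "row 0, column 8" and "row 1, column 0" are the same number when the row stride is 8. The helper is hypothetical and assumes a column stride of 1.

```python
def decompose(offset: int, stride_row: int):
    # Split a collapsed offset back into (row, col); this is only valid when
    # every column offset is smaller than stride_row.
    return divmod(offset, stride_row)


stride_row = 8
overflowed = 0 * stride_row + 8  # row 0, column 8: past the end of the row
next_row = 1 * stride_row + 0    # row 1, column 0
print(overflowed == next_row, decompose(overflowed, stride_row))  # True (1, 0)
```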

As long as the triton user writes the correct code to move the blocks horizontally, our modulo code will also move horizontally.

All of these wraparound shenanigans boil down to this:

Assuming the Triton user writes the expected pattern that we support, createStackedCastOps or createSideBySideCastOps generate the wraparound using the final pointer offset and the recorded modulo amount. The code doesn't really know whether the user wants to shift horizontally or vertically; that semantic has already been baked into the final offset that we use to derive the wraparound points.
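
The mirrored side-by-side case has the modulo on the columns, with blocks advancing vertically. It can be sketched the same way. This is an interpretation of the discussion, not the formula in createSideBySideCastOps. The function name and parameters are hypothetical, and the sketch assumes the column offset within a row is smaller than stride_row.

```python
from typing import Tuple


def side_by_side_split(target_offset: int, stride_row: int, stride_col: int,
                       mod_cols: int, block_cols: int) -> Tuple[int, int, int, int]:
    # Start of the row containing the target: where the wrapped chunk restarts.
    wrapped_around_off = target_offset - target_offset % stride_row
    # One column past the last valid column in that row.
    clamped_off = wrapped_around_off + mod_cols * stride_col
    # Columns before the wraparound (d1) and columns that wrap (d2).
    d1 = min((clamped_off - target_offset) // stride_col, block_cols)
    d2 = block_cols - d1
    return wrapped_around_off, clamped_off, d1, d2


# Row 1, column 4, columns wrapping at 6, row stride 8; then one row further down.
print(side_by_side_split(12, 8, 1, 6, 4))  # (8, 14, 2, 2)
print(side_by_side_split(20, 8, 1, 6, 4))  # (16, 22, 2, 2)
```

Moving down a row keeps the same two-and-two split. This matches the expectation that side-by-side blocks advance vertically.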

haishanzzzz commented 9 months ago

Thank you so much Nhat! This really clarifies!

We expect "stacked wraparound" to always shift horizontally and "side-by-side wraparound" to always shift vertically.

That's what I am missing! The code that you pointed at is doing exactly what you are saying.

Tbh this also makes me feel a little uneasy, because there is not going to be an easy way for us to detect whether our assumption holds, even if we keep an offset for each dimension (because of this splatting of constants into scalars). E.g., in the LIT test we are discussing, if the user for some reason wanted to shift vertically, our address calculation would produce wrong results. That is not sane code, as you mentioned, but it is still possible.
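
A small model makes this concern concrete. It uses a hypothetical stacked-wraparound formula (read off the diagram, not taken from triton-shared) and compares it with the addresses the kernel would actually form. It assumes a row-major a with M = 5 and a row stride of 16. A horizontal move matches; a vertical move of 4 rows does not. Those vertical addresses even run past the 5 valid rows.

```python
from typing import List


def stacked_offsets(target: int, stride_row: int, mod_rows: int,
                    rows: int, cols: int) -> List[List[int]]:
    # Hypothetical stacked-wraparound model (column stride 1): d1 rows starting
    # at the target, then the remaining rows restart at row 0, same column.
    wrapped = target % stride_row
    clamped = mod_rows * stride_row + wrapped
    d1 = max(0, min((clamped - target) // stride_row, rows))
    starts = [target + i * stride_row for i in range(d1)]
    starts += [wrapped + i * stride_row for i in range(rows - d1)]
    return [[s + j for j in range(cols)] for s in starts]


def kernel_offsets(row_shift: int, col_shift: int, stride_row: int,
                   mod_rows: int, rows: int, cols: int) -> List[List[int]]:
    # Addresses the kernel actually forms: rows (2 + i) % M and columns from 3,
    # then shifted by whole rows/columns with no further modulo applied.
    return [[((2 + i) % mod_rows + row_shift) * stride_row + 3 + col_shift + j
             for j in range(cols)] for i in range(rows)]


base = 2 * 16 + 3  # row 2, column 3
# Horizontal move by 4 columns: the model agrees with the kernel.
print(stacked_offsets(base + 4, 16, 5, 4, 4) == kernel_offsets(0, 4, 16, 5, 4, 4))  # True
# Vertical move by 4 rows: the collapsed offset is wrapped incorrectly.
print(stacked_offsets(base + 4 * 16, 16, 5, 4, 4) == kernel_offsets(4, 0, 16, 5, 4, 4))  # False
```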

I just posted the proposal (#81) about teasing apart the current triton-to-linalg btw. Please feel free to take a look. I will probably take the same approach as you in dealing with this pattern there (and maybe add some warning). Any feedback from you is of course greatly appreciated!

Let me close this issue. Thank you again for explaining this to me!

nhat-nguyen commented 9 months ago

I'm glad the explanations helped :) I agree the handling for modulo isn't great. We have a lot of assumptions baked in, and the code itself isn't simple to understand, nor does it work for all cases. Hopefully, with help from the community, we can improve this further. I will take a look at the proposal; thanks for putting it up!