microsoft / triton-shared

Shared Middle-Layer for Triton Compilation
MIT License

Introduce triton-arith-to-linalg pass #85

Closed haishanzzzz closed 5 months ago

haishanzzzz commented 5 months ago

This PR introduces the triton-arith-to-linalg pass. Please see #81 for background.

Most of the logic of the pass is copied directly from triton-to-linalg. The main differences are:

nhat-nguyen commented 5 months ago

@haishanzzzz We have a new pattern added recently in this PR as well: #86

nhat-nguyen commented 5 months ago

It would be great if we could reuse the patterns instead of duplicating them. I'm afraid it would be hard to maintain both versions as people add more patterns; we already have one outstanding PR that adds more patterns to the TritonToLinalg pass.

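Sharing these patterns is a little tricky, but I spent some time experimenting and I think this would work for us with minimal effort:

  • copy all of the shared patterns to a header file, perhaps ConversionPatterns.hpp, and place it under triton/include/TritonArithToLinalg/ConversionPatterns.hpp

    • note that we still leave the specific patterns intact; I think the only difference is AddPtrConverter
  • remove all the pattern definitions in both TritonArithToLinalg.cpp and TritonToLinalg.cpp
  • now both lib/Conversion/TritonArithToLinalg/TritonArithToLinalg.cpp and lib/Conversion/TritonToLinalg/TritonToLinalg.cpp will need to include ConversionPatterns.hpp:

    • #include "triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp"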

This of course isn't ideal, since all of the implementation now lives in a header file, but at least it saves us from having to duplicate all this code and consolidate it later. Another alternative would be to place all these patterns in a separate lib and have both TritonToLinalg and TritonArithToLinalg depend on it, but given that this is all temporary, I don't think it's worth the trouble. After we retire the monolith pass, we can move this header back to a normal cpp file under TritonArithToLinalg.

Here's the rough diff for the approach I describe above: https://github.com/microsoft/triton-shared/commit/a2d16dfa89abf0b2f8ad5beb52eee4e1ba34fce3 Note that in the above diff I made TritonArithToLinalg reuse the old AddPtrConverter, but hopefully this illustrates what I'm trying to describe.
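For readers who want the shape of the header-sharing idea without opening the diff, here is a minimal, self-contained sketch. It does not use MLIR: `Pattern`, `populateSharedPatterns`, the two builder functions, and the pattern names other than AddPtrConverter are hypothetical stand-ins, and the real ConversionPatterns.hpp may look quite different. The point it illustrates is only that definitions marked `inline` in one header can be included by both passes, while each pass keeps its own AddPtrConverter.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <vector>

// ---- would live in ConversionPatterns.hpp (hypothetical contents) ----
namespace shared_patterns {

// Stand-in for a rewrite pattern: a name plus a rewrite callback.
struct Pattern {
  std::string name;
  std::function<std::string(const std::string &)> rewrite;
};

// `inline` allows several .cpp files to include this definition
// without violating the one-definition rule.
inline void populateSharedPatterns(std::vector<Pattern> &patterns) {
  patterns.push_back(Pattern{
      "SharedPatternA",
      [](const std::string &op) { return "linalg.generic(" + op + ")"; }});
  patterns.push_back(Pattern{
      "SharedPatternB",
      [](const std::string &op) { return "linalg.fill(" + op + ")"; }});
}

} // namespace shared_patterns

// ---- would live in TritonToLinalg.cpp (monolith pass) ----
inline std::vector<shared_patterns::Pattern> buildMonolithPatterns() {
  std::vector<shared_patterns::Pattern> patterns;
  shared_patterns::populateSharedPatterns(patterns);
  // Pass-specific pattern: the monolith keeps its own AddPtrConverter.
  patterns.push_back(shared_patterns::Pattern{
      "AddPtrConverter(monolith)",
      [](const std::string &op) { return "memref.reinterpret_cast(" + op + ")"; }});
  return patterns;
}

// ---- would live in TritonArithToLinalg.cpp (new pass) ----
inline std::vector<shared_patterns::Pattern> buildArithPatterns() {
  std::vector<shared_patterns::Pattern> patterns;
  shared_patterns::populateSharedPatterns(patterns);
  // Pass-specific pattern: a different AddPtrConverter for the new pass.
  patterns.push_back(shared_patterns::Pattern{
      "AddPtrConverter(arith)",
      [](const std::string &op) { return "linalg.generic(" + op + ")"; }});
  return patterns;
}
```

Both builders get identical shared patterns from one definition, so a pattern added to the header shows up in both passes without a second copy to maintain.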

haishanzzzz commented 5 months ago

> It would be great if we could reuse the patterns instead of duplicating them. I'm afraid it would be hard to maintain both versions as people add more patterns; we already have one outstanding PR that adds more patterns to the TritonToLinalg pass.
>
> Sharing these patterns is a little tricky, but I spent some time experimenting and I think this would work for us with minimal effort:
>
>   • copy all of the shared patterns to a header file, perhaps ConversionPatterns.hpp, and place it under triton/include/TritonArithToLinalg/ConversionPatterns.hpp
>
>     • note that we still leave the specific patterns intact; I think the only difference is AddPtrConverter
>   • remove all the pattern definitions in both TritonArithToLinalg.cpp and TritonToLinalg.cpp
>   • now both lib/Conversion/TritonArithToLinalg/TritonArithToLinalg.cpp and lib/Conversion/TritonToLinalg/TritonToLinalg.cpp will need to include ConversionPatterns.hpp:
>
>     • #include "triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp"
>
> This of course isn't ideal, since all of the implementation now lives in a header file, but at least it saves us from having to duplicate all this code and consolidate it later. Another alternative would be to place all these patterns in a separate lib and have both TritonToLinalg and TritonArithToLinalg depend on it, but given that this is all temporary, I don't think it's worth the trouble. After we retire the monolith pass, we can move this header back to a normal cpp file under TritonArithToLinalg.
>
> Here's the rough diff for the approach I describe above: a2d16df Note that in the above diff I made TritonArithToLinalg reuse the old AddPtrConverter, but hopefully this illustrates what I'm trying to describe.

Thank you for suggesting this, Nhat, and for trying it out for me. This is a great idea. I will address it before the PR closes.

Please feel free to continue with the review in the meantime, btw.

haishanzzzz commented 5 months ago

> @haishanzzzz Thanks for the refactor. This mostly reuses what we already have, so it all looks good to me.
>
> I do have one question before closing, regarding:
>
> > Add support for tt.addptr op. This is to tackle cases when the tensor of pointers does not have any structure and we have to materialize the tensor.
>
> So after triton-to-structured, the following types of tt.addptr may exist:
>
>   1. those that deal with scalars
>   2. those that we can't analyze
>
> So in this pass, we have the option of converting tt.addptr to linalg, but looking at AddPtrConverter, I don't see us filtering out scalar addptr at all. Should we add this?

Thank you for the review @nhat-nguyen!

We filter for non-scalar addptr in TritonArithToLinalgPass.cpp with the following:

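    // Scalar addptr (non-shaped result) stays legal and is left alone;
    // only addptr ops with a shaped (tensor) result are marked for conversion.
    if (addptrToLinalg) {
      target.addDynamicallyLegalOp<triton::AddPtrOp>([](triton::AddPtrOp op) {
        return !op.getResult().getType().isa<ShapedType>();
      });
    }
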
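To make the direction of the predicate explicit (returning true means "legal, do not convert"), here is a small stand-alone model of the same filter. It is only a sketch: `ResultType`, `AddPtrOp`, `isAddPtrLegal`, and `opsNeedingConversion` are hypothetical stand-ins written in plain C++, not the MLIR or Triton types, and it models nothing beyond the lambda quoted above.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for an op's result type: either a shaped
// (tensor-of-pointers) type or a scalar pointer.
struct ResultType {
  bool shaped;
};

// Hypothetical stand-in for triton::AddPtrOp.
struct AddPtrOp {
  ResultType result;
};

// Mirrors the lambda passed to addDynamicallyLegalOp: an op is legal
// (left untouched) exactly when its result is NOT shaped, i.e. scalar.
inline bool isAddPtrLegal(const AddPtrOp &op) { return !op.result.shaped; }

// Collects the indices of ops that the conversion would have to rewrite,
// i.e. the ones the predicate reports as illegal.
inline std::vector<std::size_t>
opsNeedingConversion(const std::vector<AddPtrOp> &ops) {
  std::vector<std::size_t> illegal;
  for (std::size_t i = 0; i < ops.size(); ++i) {
    if (!isAddPtrLegal(ops[i]))
      illegal.push_back(i);
  }
  return illegal;
}
```

Under this model, a list containing two scalar addptr ops and one tensor-typed addptr yields a single index to convert, which is the filtering of scalar addptr that the question asked about.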
aaronsm commented 5 months ago

I believe this change has broken the build because of hardcoded triton-shared directory names. See issue https://github.com/microsoft/triton-shared/issues/91 for more details.

nhat-nguyen commented 5 months ago

@aaronsm Would you mind sharing more details? Issue #91 was opened 3 days ago, but we only merged this yesterday. The failure that you see is expected, as triton had a big refactor in how third-party plugins work. I'm working on a fix right now. I also left a comment in the linked issue.