Closed haishanzzzz closed 5 months ago
@haishanzzzz We have a new pattern added recently in this PR as well: #86
It would be great if we can reuse the patterns instead of duplicating them. I'm afraid it would be hard to maintain both versions as people add more patterns; we already have one outstanding PR that adds more patterns in the `TritonToLinalg` pass.
Sharing these patterns is a little tricky, but I spent some time experimenting and I think this would work for us with minimal effort:
- copy all of the shared patterns to a header file, perhaps `ConversionPatterns.hpp`, and place it under `triton/include/TritonArithToLinalg/ConversionPatterns.hpp`
  - note that we still leave the specific patterns intact; I think the only difference is `AddPtrConverter`
- remove all the pattern definitions in both `TritonArithToLinalg.cpp` and `TritonToLinalg.cpp`
- now both `lib/Conversion/TritonArithToLinalg/TritonArithToLinalg.cpp` and `lib/Conversion/TritonToLinalg/TritonToLinalg.cpp` will need to include `ConversionPatterns.hpp`:

      #include "triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp"

This of course isn't the best since all of the implementation is now in a header file, but at least it saves us from having to duplicate all this code and consolidate it later. Another alternative would be to place all these patterns in another lib and have both `TritonToLinalg` and `TritonArithToLinalg` depend on it, but I think given that this is all temporary, it's not worth the trouble. After we retire the monolith pass, we can move this header back to a normal cpp file under `TritonArithToLinalg`.
Here's the rough diff for the approach I describe above:
https://github.com/microsoft/triton-shared/commit/a2d16dfa89abf0b2f8ad5beb52eee4e1ba34fce3
Note that in the above diff I made `TritonArithToLinalg` reuse the old `AddPtrConverter`, but hopefully this illustrates what I'm trying to describe.
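For readers skimming the thread, the header-sharing idea can be sketched in a single self-contained file. This is only an illustration under stated assumptions: `PatternSet`, `SharedPatternA`, `SharedPatternB`, `populateSharedPatterns`, and the two namespaces are hypothetical stand-ins (not MLIR or triton-shared names), and the section comments mark where each piece would live in the layout described above.

```cpp
#include <cassert>
#include <string>
#include <vector>

// ---- would live in ConversionPatterns.hpp (shared header) ----

// Hypothetical stand-in for a pattern set; not the MLIR class.
struct PatternSet {
  std::vector<std::string> names;
  void add(const std::string &name) { names.push_back(name); }
};

// Member functions defined inside a class body are implicitly inline,
// so a header with these definitions can be included from several .cpp files.
struct SharedPatternA {
  static std::string name() { return "SharedPatternA"; }
};
struct SharedPatternB {
  static std::string name() { return "SharedPatternB"; }
};

// A free function defined in a header must be marked inline for the same reason.
inline void populateSharedPatterns(PatternSet &patterns) {
  patterns.add(SharedPatternA::name());
  patterns.add(SharedPatternB::name());
}

// ---- would live in TritonArithToLinalg.cpp ----
namespace triton_arith_to_linalg {
// Pass-specific pattern stays in the pass's own file.
struct AddPtrConverter {
  static std::string name() { return "TritonArithToLinalg::AddPtrConverter"; }
};

PatternSet buildPatterns() {
  PatternSet patterns;
  populateSharedPatterns(patterns);       // shared, from the header
  patterns.add(AddPtrConverter::name());  // specific to this pass
  return patterns;
}
} // namespace triton_arith_to_linalg

// ---- would live in TritonToLinalg.cpp ----
namespace triton_to_linalg {
struct AddPtrConverter {
  static std::string name() { return "TritonToLinalg::AddPtrConverter"; }
};

PatternSet buildPatterns() {
  PatternSet patterns;
  populateSharedPatterns(patterns);
  patterns.add(AddPtrConverter::name());
  return patterns;
}
} // namespace triton_to_linalg
```

Both passes pick up the same shared definitions, so a pattern added to the header reaches both, while each pass keeps its own `AddPtrConverter`.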
Thank you for suggesting this Nhat and trying it out for me. This is a great idea. Will address before PR closes.
Please feel free to continue with the review in the meantime, btw.
@haishanzzzz Thanks for the refactor. This mostly reuses what we already have, so it all looks good to me.
I do have one question before closing regarding
> Add support for `tt.addptr` op. This is to tackle cases when the tensors of pointers do not have any structure and we have to materialize the tensor.

So after `triton-to-structured`, the following types of `tt.addptr` may exist:
- those that deal with scalars
- those that we can't analyze

So in this pass, we have the option of converting `tt.addptr` to linalg, but looking at `AddPtrConverter`, I don't see us filtering out scalar `addptr` at all. Should we add this?
Thank you for the review @nhat-nguyen!
We filter for non-scalar `addptr` in `TritonArithToLinalgPass.cpp` with the following:

    if (addptrToLinalg) {
      target.addDynamicallyLegalOp<triton::AddPtrOp>([](triton::AddPtrOp op) {
        return !op.getResult().getType().isa<ShapedType>();
      });
    }
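To make the effect of that predicate concrete, here is a small standalone model of it. Everything in it is an assumption for illustration: `MockType`, `MockAddPtrOp`, `MockConversionTarget`, and `makeTarget` are hypothetical stand-ins, not MLIR classes, and the model treats an op with no registered callback as simply left alone. It only demonstrates the rule "an `addptr` is legal (not converted) exactly when its result is not a shaped type".

```cpp
#include <cassert>
#include <functional>
#include <utility>

// Hypothetical stand-ins; these are not MLIR types.
struct MockType {
  bool isShaped; // true for tensor-like results, false for scalars
};
struct MockAddPtrOp {
  MockType resultType;
};

// Simplified model of a conversion target with one dynamically legal op.
class MockConversionTarget {
public:
  using Callback = std::function<bool(const MockAddPtrOp &)>;

  void addDynamicallyLegalOp(Callback callback) {
    callback_ = std::move(callback);
  }

  // In this model, an op with no callback registered is left alone (legal).
  bool isLegal(const MockAddPtrOp &op) const {
    return !callback_ || callback_(op);
  }

private:
  Callback callback_;
};

// Mirrors the shape of the snippet above: register the filter only when
// the option is enabled.
inline MockConversionTarget makeTarget(bool addptrToLinalg) {
  MockConversionTarget target;
  if (addptrToLinalg) {
    target.addDynamicallyLegalOp([](const MockAddPtrOp &op) {
      // Legal means "do not convert": scalar addptr passes through,
      // shaped (tensor) addptr is marked illegal and must be rewritten.
      return !op.resultType.isShaped;
    });
  }
  return target;
}
```

With the option enabled, a scalar `addptr` is reported legal and a tensor `addptr` illegal, which is why the converter itself does not need a separate scalar check.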
I believe this change has broken the build due to hardcoded triton-shared directory names. See issue https://github.com/microsoft/triton-shared/issues/91 for more details.
@aaronsm Would you mind sharing more details? Issue #91 was opened 3 days before but we only merged this yesterday. The failure that you see is expected as triton had a big refactor in how 3rd party plugins work. I'm working on a fix right now. I also left a comment in the linked issue.
This PR introduces the `triton-to-structured` pass. Please see #81 for background.
Most of the logic of the pass is directly copied from `triton-to-linalg`. The main differences are:
- Add support for `tt.addptr` op. This is to tackle cases when the tensors of pointers do not have any structure and we have to materialize the tensor.
- `tt.func` -> `func.func`
- `tt.get_program_id` -> function arguments
- `tt.assert` -> `cf.assert`