Closed: nhat-nguyen closed this pull request 4 months ago.
@haishanzzzz I still can't add you as a reviewer so tagging you here instead.
Thank you @nhat-nguyen. For some reason, GitHub told me the invite didn't take effect. Will review early next week!
This PR adds the final pass, `structured-to-memref`, as outlined in the discussion on making the `triton-to-linalg` pass more modular: https://github.com/microsoft/triton-shared/discussions/81

Details
The pass expects to run after `triton-to-structured` and `triton-arith-to-linalg`, and lowers all ops in the TritonStructured dialect to memref:
- `tts.make_tptr` to `memref.reinterpret_cast`
- `tts.load` to a combination of `memref.alloc`, `memref.copy`, and `bufferization.to_tensor`
- `tts.store` to `bufferization.materialize_in_destination`
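To make the three lowerings above more concrete, here is a minimal Python model of their semantics. It is a sketch for intuition, not the pass's implementation: a flat Python list stands in for a memref's underlying buffer, and `StridedView`, `reinterpret_cast`, `load_tile`, and `store_tile` are hypothetical names. A structured pointer becomes a strided view (offset, sizes, strides) over the same buffer, as with `memref.reinterpret_cast`; a load allocates a fresh buffer and copies the view into it (the `memref.alloc` + `memref.copy` pair, with the result then treated as a tensor via `bufferization.to_tensor`); a store writes a tile back through the view, which is the effect of `bufferization.materialize_in_destination`.

```python
from dataclasses import dataclass
from itertools import product
from typing import List, Sequence, Tuple


@dataclass
class StridedView:
    """Hypothetical stand-in for the result of memref.reinterpret_cast."""
    base: List[float]         # underlying flat buffer (shared, never copied)
    offset: int               # element offset of the view's first element
    sizes: Tuple[int, ...]    # shape of the view
    strides: Tuple[int, ...]  # element stride per dimension

    def linear_index(self, idx: Sequence[int]) -> int:
        # Strided-layout addressing: offset + sum(i_k * stride_k).
        return self.offset + sum(i * s for i, s in zip(idx, self.strides))


def reinterpret_cast(base, offset, sizes, strides) -> StridedView:
    # Models tts.make_tptr -> memref.reinterpret_cast: a new view, no data moves.
    return StridedView(base, offset, tuple(sizes), tuple(strides))


def load_tile(view: StridedView) -> List[float]:
    # Models tts.load -> memref.alloc + memref.copy: allocate a fresh contiguous
    # buffer and copy every element of the view into it in row-major order.
    # The returned list plays the role of the tensor from bufferization.to_tensor.
    return [view.base[view.linear_index(idx)]
            for idx in product(*(range(n) for n in view.sizes))]


def store_tile(view: StridedView, tile: Sequence[float]) -> None:
    # Models tts.store -> bufferization.materialize_in_destination:
    # write the row-major tile into the memory the view points at.
    for flat, idx in enumerate(product(*(range(n) for n in view.sizes))):
        view.base[view.linear_index(idx)] = tile[flat]


# Usage: a 4x4 row-major buffer holding 0.0 .. 15.0.
mem = [float(i) for i in range(16)]
# 2x2 tile at row 1, column 1: offset 1*4 + 1 = 5, strides (4, 1).
v = reinterpret_cast(mem, offset=5, sizes=[2, 2], strides=[4, 1])
tile = load_tile(v)                    # [5.0, 6.0, 9.0, 10.0]
store_tile(v, [t * 2 for t in tile])   # doubles mem[5], mem[6], mem[9], mem[10]
```

The point of the model is that only the load materializes a copy; making the pointer and storing both work through a view of the original buffer.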
Differences
Overall, the resulting IR is cleaner than what the original monolithic `triton-to-linalg` pass produces, and differs from it very little except in the handling of scalar loads and stores.

Since `triton-to-structured` leaves `tt.addptr`, `tt.load`, and `tt.store` untouched for unstructured pointers and scalars, this pass also lowers the ones that deal with scalars, using a combination of `memref.reinterpret_cast` and `memref.extract_strided_metadata`.
How to use
With this pass, the whole pipeline for lowering Triton IR to linalg is:
Testing
lit tests
I have added all the existing lit tests from `triton-to-linalg` and manually verified that the outputs make sense.

CPU backend testing
I have updated our CPU backend to use the new pipeline and verified that all the tests ran correctly.