microsoft / triton-shared

Shared Middle-Layer for Triton Compilation
MIT License

Support Block Pointer in triton-to-structured #90

Closed: haishanzzzz closed this 8 months ago.

haishanzzzz commented 8 months ago

This PR enhances the PtrAnalysis logic to support tt.make_tensor_ptr and tt.advance.

In addition to supporting the ops, this PR makes the following changes:

The order field gives us a way to differentiate the block pointer type, so later passes can lower it accordingly.
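To make the idea concrete, here is a minimal sketch of how a pass could tell the two kinds of pointer apart by looking at order. StructuredPtrInfo and isBlockPtr are hypothetical names, and the struct is a simplified stand-in for the information tts.make_tptr carries, not the op's real accessor API. It also assumes that order is left empty for regular (possibly wraparound) pointers and is only filled in when lowering tt.make_tensor_ptr.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical, simplified mirror of the data carried by tts.make_tptr.
// Field names are illustrative; the real op has more (and different) fields.
struct StructuredPtrInfo {
  std::vector<int64_t> sizes;    // block (tile) shape
  std::vector<int64_t> strides;  // per-dimension stride, in elements
  std::vector<int64_t> offsets;  // start offsets, already scaled by strides
  std::vector<int64_t> order;    // assumed empty unless from tt.make_tensor_ptr
};

// Assumed convention: a non-empty order marks a block pointer, while an
// empty order marks a regular (possibly wraparound) structured pointer.
bool isBlockPtr(const StructuredPtrInfo &info) {
  // When order is present it must describe every dimension of the block.
  assert(info.order.empty() || info.order.size() == info.sizes.size());
  return !info.order.empty();
}
```

A later pass (e.g. structured-to-memref) could branch on isBlockPtr to pick the matching lowering.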

Testing

Tested using the existing block_ptr_advance.mlir.

nhat-nguyen commented 8 months ago

I have a question regarding the order field. I notice that it's only being used to distinguish between a wraparound tensor and a block pointer tensor. Will this field come into play later (e.g., in structured-to-memref)?

Edited: From the Triton tutorial, it looks like this is extra information describing how the tensor is laid out in memory:

Note that the order argument is set to (1, 0), which means the second axis is the inner dimension in terms of storage, and the first axis is the outer dimension. This information may sound redundant, but it is necessary for some hardware backends to optimize for better performance.

We aren't really handling this at all in the old TritonToLinalg pass, so I guess this means we just assume regular row-major storage. Does this match your understanding?
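As an illustration of what order means for a densely packed tensor, the sketch below derives per-dimension strides from a shape and a Triton-style order, where order[0] is the innermost (fastest-varying) dimension, as in the tutorial quote above. contiguousStrides is a hypothetical helper; tt.make_tensor_ptr already receives explicit strides, so this only shows the layout that order describes, not how any pass computes strides.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative only: for a densely packed tensor, derive the stride (in
// elements) of each dimension from its shape and a Triton-style order,
// where order[0] is the innermost (fastest-varying) dimension.
std::vector<int64_t> contiguousStrides(const std::vector<int64_t> &shape,
                                       const std::vector<int64_t> &order) {
  assert(shape.size() == order.size());
  std::vector<int64_t> strides(shape.size(), 0);
  int64_t running = 1;
  // Walk dimensions from innermost to outermost, accumulating the stride.
  for (int64_t dim : order) {
    strides[static_cast<size_t>(dim)] = running;
    running *= shape[static_cast<size_t>(dim)];
  }
  return strides;
}
```

For a 4x8 tensor, order (1, 0) yields strides (8, 1), the row-major layout the old pass assumes, while order (0, 1) yields strides (1, 4), a column-major layout.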

nhat-nguyen commented 8 months ago

I think this approach makes sense: we reuse tts::MakeTensorPtrOp and store additional info so that it can also represent a block pointer created by triton::MakeTensorPtrOp.

triton::MakeTensorPtrOp offsets are expressed in rows / columns (not multiplied by strides). Because our tts.make_tptr offsets already have their respective strides multiplied in, we have to scale the offsets by the strides when handling tt.advance.
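The offset bookkeeping described above can be sketched with plain integers. scaleOffsets and advanceOffsets are hypothetical helpers, and the sketch assumes constant offsets and strides; in the actual pass these may be SSA values that need emitted arithmetic instead.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: convert element indices per dimension (as used by
// tt.make_tensor_ptr) into stride-scaled offsets (as stored by tts.make_tptr).
std::vector<int64_t> scaleOffsets(const std::vector<int64_t> &indices,
                                  const std::vector<int64_t> &strides) {
  assert(indices.size() == strides.size());
  std::vector<int64_t> scaled(indices.size());
  for (size_t i = 0; i < indices.size(); ++i)
    scaled[i] = indices[i] * strides[i];
  return scaled;
}

// Hypothetical helper for tt.advance: the advance amounts are also element
// indices, so they are scaled by the strides before being added to the
// already-scaled offsets.
std::vector<int64_t> advanceOffsets(const std::vector<int64_t> &scaledOffsets,
                                    const std::vector<int64_t> &advance,
                                    const std::vector<int64_t> &strides) {
  std::vector<int64_t> delta = scaleOffsets(advance, strides);
  assert(delta.size() == scaledOffsets.size());
  std::vector<int64_t> result(scaledOffsets.size());
  for (size_t i = 0; i < result.size(); ++i)
    result[i] = scaledOffsets[i] + delta[i];
  return result;
}
```

For a row-major matrix with strides (64, 1), a block starting at row 2, column 16 is stored as offsets (128, 16); advancing it by (0, 32) gives (128, 48), and advancing by (1, 0) gives (192, 16).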

My only concern is that using the order field to distinguish between a wraparound ptr and a block ptr can make the logic a little hard to follow; its name doesn't convey what it actually means, and I had to refer back to the official Triton tutorial. The same can be said for sizes too.

I think we don't need to follow the Triton naming convention here too closely and can probably just keep parentSizes for both cases. order can probably be renamed to storageOrder or something similar.

haishanzzzz commented 8 months ago

Thanks for the careful review @nhat-nguyen. Will address the nits.

We aren't really handling this at all in the old TritonToLinalg pass, so I guess this means we just assume regular row-major storage. Does this match your understanding?

It seems that the prior implementation supports only row-major order. See this.

My only concern is that using the order field to distinguish between a wraparound ptr and a block ptr can make the logic a little hard to follow; its name doesn't convey what it actually means, and I had to refer back to the official Triton tutorial. The same can be said for sizes too.

I think we don't need to follow the Triton naming convention here too closely and can probably just keep parentSizes for both cases. order can probably be renamed to storageOrder or something similar.

I am neutral on the naming. I chose the same names to lower the bar for understanding these ops/fields, since anyone looking at them will almost certainly have looked at the Triton ops already. Do you think that makes sense?