microsoft / triton-shared

Shared Middle-Layer for Triton Compilation
MIT License

Implement lowerings for argmin and argmax #58

Closed nhat-nguyen closed 7 months ago

nhat-nguyen commented 7 months ago

Triton argmin and argmax both lower to tt.reduce ops whose semantics are identical to those of the linalg.reduce op, so we can clone the tt.reduce body into linalg.reduce directly. Unfortunately, we still need pattern matching to identify which kind of reduction we are dealing with, so that we know how to set the initial reduce values correctly.
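To illustrate why the kind of reduction matters, here is a minimal Python sketch (not the triton-shared implementation). It models the reduce body as a combine function over (value, index) pairs and shows that the seed has to differ per op: negative infinity for argmax, positive infinity for argmin. The names `arg_reduce`, `INIT`, and `COMBINE`, the `-1` seed index, and the tie-breaking toward the smaller index are all assumptions made for this example.

```python
import math
from functools import reduce

def argmax_combine(acc, cur):
    # Keep the pair with the larger value; on ties keep the smaller index (assumed).
    if cur[0] > acc[0] or (cur[0] == acc[0] and cur[1] < acc[1]):
        return cur
    return acc

def argmin_combine(acc, cur):
    # Keep the pair with the smaller value; on ties keep the smaller index (assumed).
    if cur[0] < acc[0] or (cur[0] == acc[0] and cur[1] < acc[1]):
        return cur
    return acc

# Hypothetical lookup tables: the seed depends on which op was recognized.
# A seed index of -1 is an assumption for this sketch.
INIT = {"argmax": (-math.inf, -1), "argmin": (math.inf, -1)}
COMBINE = {"argmax": argmax_combine, "argmin": argmin_combine}

def arg_reduce(kind, values):
    # Pair every value with its position, then fold with the op-specific seed.
    pairs = zip(values, range(len(values)))
    return reduce(COMBINE[kind], pairs, INIT[kind])[1]
```

Swapping the two seeds would make every element lose against the initial value, which is why the lowering has to know which op it is handling before it can pick one.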

We could do this generically, without pattern matching, by always using the first element along the reduction axis as the initial value and performing the reduction on the remaining elements. However, this creates sub-tensors whose sizes aren't always multiples of 2, which is sub-optimal for certain hardware.
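The generic alternative can be sketched the same way in plain Python, again as an illustration rather than the actual lowering. The helper `generic_arg_reduce` is hypothetical; it seeds the fold with the first (value, index) pair and also returns the length of the remaining slice to make the size problem visible.

```python
from functools import reduce

def argmax_combine(acc, cur):
    # Keep the pair with the larger value; on ties keep the smaller index (assumed).
    if cur[0] > acc[0] or (cur[0] == acc[0] and cur[1] < acc[1]):
        return cur
    return acc

def generic_arg_reduce(combine, values):
    # No op-specific seed: the first element is the initial value,
    # and the reduction runs over everything after it.
    pairs = list(zip(values, range(len(values))))
    init, rest = pairs[0], pairs[1:]
    index = reduce(combine, rest, init)[1]
    return index, len(rest)

# An input of 8 elements leaves a remainder of 7 elements to reduce.
index, remainder_size = generic_arg_reduce(
    argmax_combine, [2.0, 9.0, 4.0, 9.0, 1.0, 0.0, 3.0, 5.0]
)
```

The combine function is passed in unchanged, so no pattern matching is needed, but the remainder has one element fewer than the input, which is the odd-sized sub-tensor mentioned above.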