Triton argmin and argmax both lower to tt.reduce ops whose semantics are identical
to those of the linalg.reduce op, so we can clone the tt.reduce body into
linalg.reduce directly. Unfortunately, we still need to perform pattern matching
to determine which kind of reduction we are dealing with, so that we know how to
initialize the reduce values correctly.
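As an illustration of why the initial value depends on the op, here is a minimal Python sketch (not the actual MLIR lowering) of an argmin/argmax reduction over (value, index) pairs. The combiners, the tie-breaking rule (smaller index wins), and the sentinel index are all assumptions made for this example; only the idea that argmax needs a different seed than argmin comes from the text above.

```python
import math
from typing import Callable, Dict, List, Tuple

Pair = Tuple[float, int]  # (value, index), the state carried through the reduce body

def argmax_combine(a: Pair, b: Pair) -> Pair:
    # Keep the larger value; on a tie, keep the smaller index (assumed tie-break).
    if b[0] > a[0] or (b[0] == a[0] and b[1] < a[1]):
        return b
    return a

def argmin_combine(a: Pair, b: Pair) -> Pair:
    # Keep the smaller value; on a tie, keep the smaller index (assumed tie-break).
    if b[0] < a[0] or (b[0] == a[0] and b[1] < a[1]):
        return b
    return a

# Hypothetical sentinel: larger than any real index, so real elements win ties
# against the initial value (e.g. when every value is -inf for argmax).
SENTINEL_INDEX = 2**31 - 1

# The seed differs per op, which is why the kind of reduction must be known.
INIT: Dict[str, Pair] = {
    "argmax": (-math.inf, SENTINEL_INDEX),
    "argmin": (math.inf, SENTINEL_INDEX),
}
COMBINE: Dict[str, Callable[[Pair, Pair], Pair]] = {
    "argmax": argmax_combine,
    "argmin": argmin_combine,
}

def reduce_with_init(kind: str, values: List[float]) -> int:
    """Reduce all elements, starting from the op-specific initial value."""
    acc = INIT[kind]
    combine = COMBINE[kind]
    for i, v in enumerate(values):
        acc = combine(acc, (v, i))
    return acc[1]
```

For example, `reduce_with_init("argmax", [1.0, 3.0, 3.0, 2.0])` returns the first index holding the maximum. Using the wrong seed (say `+inf` for argmax) would make the initial value win every comparison, which is the failure mode that pattern matching avoids.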
We could do this generically, without pattern matching, by always using the
first element along the reduction axis as the initial value and performing the
reduction on the remaining elements. However, this creates sub-tensors whose
sizes aren't always multiples of 2, which is sub-optimal for certain hardware.
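The generic alternative described above can be sketched in Python as follows. This is an illustration under assumptions, not the lowering itself: the function name is hypothetical, and the argmax combiner (larger value wins, smaller index on ties) is assumed. It shows that seeding with element 0 needs no op-specific initial value, but leaves a remaining slice of length N - 1, which is odd whenever N is even.

```python
from typing import Callable, List, Tuple

Pair = Tuple[float, int]  # (value, index)

def argmax_combine(a: Pair, b: Pair) -> Pair:
    # Keep the larger value; on a tie, keep the smaller index (assumed tie-break).
    if b[0] > a[0] or (b[0] == a[0] and b[1] < a[1]):
        return b
    return a

def reduce_seeded_with_first(
    combine: Callable[[Pair, Pair], Pair], values: List[float]
) -> Tuple[int, int]:
    """Hypothetical helper: seed with element 0, reduce over the rest.

    Returns (result index, length of the sub-tensor actually reduced).
    """
    acc: Pair = (values[0], 0)  # works for any combiner, no pattern matching
    rest = values[1:]           # the "remaining elements" sub-tensor
    for offset, v in enumerate(rest, start=1):
        acc = combine(acc, (v, offset))
    return acc[1], len(rest)
```

With an 8-element input, the reduced sub-tensor has 7 elements, so a hardware-friendly power-of-two tiling of the reduction axis no longer divides it evenly.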