ROCm / triton

Development repository for the Triton language and compiler
MIT License

tl.load: Support passing tl.constexpr(SOME_TUPLE) to boundary_check #517

Open xinyazhang opened 7 months ago


We would like to use tl.load's boundary_check in this way:

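# Inside a @triton.jit kernel
if CONDITION1:
    k_boundary = (1,0)   # check bounds along dims 1 and 0
elif CONDITION2:
    k_boundary = (0,)    # check bounds along dim 0 only
else:
    k_boundary = None    # no boundary check
k = tl.load(K_block_ptr, boundary_check=k_boundary, padding_option="zero")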

However, this is not possible because (1,0) is translated to a tl.tensor. The translation can be prevented by wrapping the tuple as tl.constexpr((1,0)), but the result is still not ideal: k_boundary becomes [constexpr[[constexpr[1], constexpr[0]]]], which the current _canonicalize_boundary_check cannot handle.
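To illustrate the nesting problem, here is a minimal standalone sketch that does not use Triton. `Constexpr` is a hypothetical stand-in for tl.constexpr (it only wraps a value, and its repr is modelled on the printed form quoted above), and `flatten_boundary_check` is a hypothetical helper, not Triton's `_canonicalize_boundary_check`. It recursively strips the wrappers and nested lists/tuples down to a flat tuple of dimension indices.

```python
from typing import Any, Tuple


class Constexpr:
    """Hypothetical stand-in for tl.constexpr: it only wraps a Python value."""

    def __init__(self, value: Any) -> None:
        # Avoid double wrapping, so Constexpr(Constexpr(x)).value is x
        self.value = value.value if isinstance(value, Constexpr) else value

    def __repr__(self) -> str:
        return f"constexpr[{self.value!r}]"


def flatten_boundary_check(boundary_check: Any) -> Tuple[int, ...]:
    """Unwrap Constexpr wrappers and nested lists/tuples into a flat tuple of dims."""
    if isinstance(boundary_check, Constexpr):
        return flatten_boundary_check(boundary_check.value)
    if boundary_check is None:
        return ()  # None means "no boundary check"
    if isinstance(boundary_check, (list, tuple)):
        dims: list = []
        for elem in boundary_check:
            dims.extend(flatten_boundary_check(elem))
        return tuple(dims)
    # Anything else is treated as a single dimension; validation happens later
    return (boundary_check,)


# The nested shape described in the issue text
nested = [Constexpr([Constexpr(1), Constexpr(0)])]
print(nested)                          # [constexpr[[constexpr[1], constexpr[0]]]]
print(flatten_boundary_check(nested))  # (1, 0)
```

Flattening recursively means the helper does not need to know how many layers of wrapping the frontend produced; type and range checks are left to a separate validation step.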

This PR handles this case and improves the diagnostic information provided by the assert below.
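The PR's actual assert is not reproduced here. As a hypothetical sketch of what more informative diagnostics can look like, the function below (`check_boundary_dims` is an illustrative name, not a Triton API) validates an already-flattened boundary_check against a block shape and reports the offending value, its type, and the block rank in each assert message. Returning the dims sorted is a choice made by this sketch.

```python
from typing import Sequence, Tuple


def check_boundary_dims(dims: Sequence[object], block_shape: Sequence[int]) -> Tuple[int, ...]:
    """Validate flattened boundary_check dims with descriptive assert messages."""
    rank = len(block_shape)
    for dim in dims:
        # bool is a subclass of int, so reject it explicitly
        assert isinstance(dim, int) and not isinstance(dim, bool), (
            f"boundary_check entries must be ints, got {dim!r} of type {type(dim).__name__}"
        )
        assert 0 <= dim < rank, (
            f"boundary_check dimension {dim} is out of range for a block of rank {rank} "
            f"(shape {tuple(block_shape)})"
        )
    assert len(set(dims)) == len(dims), f"duplicate dimension in boundary_check: {tuple(dims)}"
    return tuple(sorted(dims))


print(check_boundary_dims((1, 0), (64, 32)))  # (0, 1)
```

For example, passing (2,) for a 2-D block fails with "boundary_check dimension 2 is out of range for a block of rank 2 (shape (64, 32))" rather than a bare AssertionError.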