microsoft / triton-shared

Shared Middle-Layer for Triton Compilation
MIT License

Add support for torch inductor modulo op pattern #68

Closed: nhat-nguyen closed this 7 months ago

nhat-nguyen commented 7 months ago

This PR addresses several issues:

  1. Support the modulo pattern generated by torch inductor: We now support tl.arange(0, size)[:, None] % mod (expanding the shape before applying the modulo). Fixes #14 and #48.

  2. Add more modulo tests: While running these tests on the CPU backend, I found that our current use of memref.reinterpret_cast to support modulo ops in a loop was incorrect. Previously, we inserted two "no-op" memref.reinterpret_cast ops for the two blocks of the mod pointers so that LoadConverter could determine the sizes of the blocks to copy into the local buffers. However, when lowering all the way to LLVM, this meant that the offsets of the blocks being yielded were reset in every loop iteration. To fix this, I have replaced the casts with memref.dim ops, which return the correct sizes.

  3. Fix type mismatches of individual modulo blocks in a loop: Previously, the type of each individual modulo block could have static strides, while the corresponding yield values of the loop have dynamic strides, which caused a type mismatch. The strides are now always dynamic from the start.

  4. Support lowering to CPU in more cases: Lowering to memref can produce additional affine ops after the passes that would handle them have already run in the current pass ordering. I have added two more passes to the pass list to fix this.

  5. Add the softmax tutorial test for the CPU backend.
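As a rough illustration of the pattern in item 1, here is a NumPy sketch that uses np.arange as a host-side stand-in for tl.arange. This is an assumption made for illustration only: it mimics the shapes and index arithmetic, not Triton or triton-shared codegen, and the helper name mod_pattern_offsets and its row_stride/ncols parameters are hypothetical. Expanding to a column with [:, None] before applying % yields a (size, 1) block of wrapped row indices, which then broadcasts against the column offsets.

```python
import numpy as np


def mod_pattern_offsets(size: int, mod: int, row_stride: int, ncols: int) -> np.ndarray:
    """Hypothetical host-side mimic of offsets built from tl.arange(0, size)[:, None] % mod."""
    rows = np.arange(0, size)[:, None] % mod   # expand to (size, 1) first, then wrap
    cols = np.arange(0, ncols)[None, :]        # (1, ncols) column offsets
    return rows * row_stride + cols            # broadcasts to (size, ncols)


offs = mod_pattern_offsets(size=6, mod=4, row_stride=10, ncols=3)
print(offs.shape)           # (6, 3)
print(offs[:, 0].tolist())  # [0, 10, 20, 30, 0, 10]
```

Rows 4 and 5 wrap back to rows 0 and 1. That wrap-around is what makes the accessed region non-contiguous.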
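The "two blocks" mentioned in item 2 arise because a modulo range wraps around. The following is a conceptual NumPy model, not the triton-shared lowering. split_wrapped_block and load_wrapped are hypothetical helpers, and the model assumes size <= mod. It copies a wrapped window of rows into a local buffer as two contiguous pieces. The size of each piece is read from the piece itself, which is loosely analogous to the PR querying sizes with memref.dim.

```python
import numpy as np


def split_wrapped_block(buf: np.ndarray, start: int, size: int, mod: int):
    """Split rows (start + i) % mod, i in [0, size), into two contiguous views of buf."""
    assert 0 < size <= mod and buf.shape[0] == mod  # model assumption: at most one wrap
    start %= mod
    first = buf[start:min(start + size, mod)]   # rows up to the wrap point
    second = buf[0:size - first.shape[0]]       # rows that wrapped around to row 0
    return first, second


def load_wrapped(buf: np.ndarray, start: int, size: int, mod: int) -> np.ndarray:
    """Copy the wrapped window into a fresh local buffer, one piece at a time."""
    first, second = split_wrapped_block(buf, start, size, mod)
    # Piece sizes come from the views themselves (loosely like querying memref.dim).
    n1, n2 = first.shape[0], second.shape[0]
    local = np.empty((size,) + buf.shape[1:], dtype=buf.dtype)
    local[:n1] = first
    local[n1:n1 + n2] = second
    return local


# Usage: advance the window by 2 rows per "loop iteration" over a 5-row buffer.
buf = np.arange(5)
for it in range(3):
    print(load_wrapped(buf, start=2 * it, size=3, mod=5).tolist())
# [0, 1, 2]
# [2, 3, 4]
# [4, 0, 1]
```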