ROCm / triton

Development repository for the Triton language and compiler
MIT License

Dot slicing pass #440

Closed oplavsic closed 7 months ago

oplavsic commented 8 months ago

This PR implements the dot slicing pass. It slices a dot operation along the k dimension (for dot A, B with shape(A) = (m, k) and shape(B) = (k, n)). The slice size along the k dimension is set by 'sliceKTile'. The algorithm has three main steps:

1) Modify load instruction layouts for dot operands.

2) Slice the dot operands.

3) Slice the dot operation along the k dimension.
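To make the idea behind steps 2 and 3 concrete, here is a minimal plain-Python sketch of slicing a dot along k. It only illustrates the arithmetic (a chain of smaller dots that share one accumulator) and is not the pass itself, which operates on Triton IR. The helper names `matmul` and `sliced_matmul` are hypothetical, and the requirement that k be a multiple of the slice size is an assumption of this sketch.

```python
def matmul(a, b):
    """Reference dot: a is (m, k), b is (k, n), both lists of rows."""
    m, k, n = len(a), len(b), len(b[0])
    return [[sum(a[i][p] * b[p][j] for p in range(k)) for j in range(n)]
            for i in range(m)]


def sliced_matmul(a, b, slice_k_tile):
    """Same dot, computed as a chain of smaller dots over k-slices."""
    m, k, n = len(a), len(b), len(b[0])
    if k % slice_k_tile != 0:
        # Assumption of this sketch, not a documented rule of the pass.
        raise ValueError("k must be a multiple of slice_k_tile in this sketch")
    acc = [[0] * n for _ in range(m)]  # accumulator carried between slices
    for k0 in range(0, k, slice_k_tile):
        a_slice = [row[k0:k0 + slice_k_tile] for row in a]  # (m, slice_k_tile)
        b_slice = b[k0:k0 + slice_k_tile]                   # (slice_k_tile, n)
        partial = matmul(a_slice, b_slice)
        acc = [[acc[i][j] + partial[i][j] for j in range(n)] for i in range(m)]
    return acc


a = [[i * 8 + p for p in range(8)] for i in range(4)]  # shape (4, 8)
b = [[p - j for j in range(3)] for p in range(8)]      # shape (8, 3)
assert sliced_matmul(a, b, slice_k_tile=2) == matmul(a, b)
```

Each slice needs only an (m, slice_k_tile) piece of A and a (slice_k_tile, n) piece of B at a time, while the result matches the unsliced dot.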

NOTE: Currently, this pass can only slice the first GEMM in flash attention. Once the view_slice instruction supports the MFMA layout, dot slicing will be (trivially) updated to slice the second GEMM as well.

oplavsic commented 8 months ago

@zhanglx13 @jayfurmanek I will add some lit tests to this PR and start working on supporting the second GEMM, but it is mostly ready to be reviewed. Currently, performance is a bit lower than what we have without slicing. I will investigate the generated assembly and work on code scheduling after I support slicing for the second GEMM as well. The good news is that performance is better than what we got when slicing at the Python level, and the increase in register pressure is not as severe. I will add performance numbers and more details once I start working on performance analysis. You can play around with this PR by using the slice_k_tile kernel parameter to enable this pass for GEMM and flash attention (for example, slice_k_tile=32).

zhanglx13 commented 8 months ago

Thank you @oplavsic. I'll take a look sometime later.

oplavsic commented 8 months ago

@zhanglx13 @jayfurmanek I added support for chained dot slicing (the second GEMM in FA). It seems that the triton-opt tool can't pass values to pass options (such as slice-k-tile) and always uses the default value, which prevents me from adding lit tests at this point (I will talk to the upstream guys about how they handle this). I think this PR is ready for review at this point. I will now focus on improving the scheduling pass (ReorderInstructions.cpp) to handle dot slicing and squeeze the best performance out of it.
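For readers unfamiliar with the chained case, the following plain-Python sketch shows what slicing both GEMMs means arithmetically. It is an illustration under simplifying assumptions, not the pass: the softmax and scaling between the two GEMMs of flash attention are omitted, and the shapes, tile sizes, and helper names (`dot`, `zeros`, `sliced_dot`) are hypothetical.

```python
def dot(a, b, acc):
    """Return acc + a @ b for lists of rows (a dot with an accumulator)."""
    return [[acc[i][j] + sum(a[i][p] * b[p][j] for p in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def zeros(m, n):
    return [[0] * n for _ in range(m)]


def sliced_dot(a, b, tile):
    """a @ b as a chain of dots over k-slices of width `tile`."""
    acc = zeros(len(a), len(b[0]))
    for k0 in range(0, len(b), tile):
        acc = dot([row[k0:k0 + tile] for row in a], b[k0:k0 + tile], acc)
    return acc


# Hypothetical tiny shapes: q is (m, d), k_t is (d, n), v is (n, d).
q = [[1, 2, 3, 4], [0, 1, 0, 1]]                            # (2, 4)
k_t = [[(r * c + 1) % 4 for c in range(6)] for r in range(4)]  # (4, 6)
v = [[(r + c) % 3 for c in range(4)] for r in range(6)]        # (6, 4)

s = sliced_dot(q, k_t, tile=2)  # first GEMM: its k dimension is d = 4
o = sliced_dot(s, v, tile=3)    # second GEMM: its k dimension is n = 6

# Unsliced reference for both GEMMs.
s_ref = dot(q, k_t, zeros(2, 6))
o_ref = dot(s_ref, v, zeros(2, 4))
assert s == s_ref and o == o_ref
```

Note that the k dimension of the second GEMM is the n dimension of the first, so the operand being sliced there is the result of the first dot rather than a loaded tensor.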

zhanglx13 commented 8 months ago

@oplavsic Can you add some lit tests to check whether the blocked layout can be sliced correctly? Then I think the PR is good to go.

oplavsic commented 7 months ago

Done.