punica-ai / punica

Serving multiple LoRA finetuned LLM as one
https://arxiv.org/abs/2310.18547
Apache License 2.0

Fixed deadlock in sgmv_shrink kernel caused by skewed segments #35

Closed tgaddair closed 5 months ago

tgaddair commented 5 months ago

When two or more segments in a batch differ greatly in size (>= 65 elements in the batch), the sgmv_shrink kernel can deadlock.

The crux of the issue is that each block executes a dynamic number of steps depending on the size of its segment (s_end - s_start), and calls grid.sync() once per step. Because grid.sync() is a grid-wide barrier that releases only when every block has reached it, a block that executes more steps than another ends up waiting at a barrier the other block never reaches, and the kernel deadlocks.
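To make the mismatch concrete, here is a small host-side C++ model of the barrier counts. It is only a sketch and not the kernel itself: it assumes one block per segment and one grid.sync() per step, and `kRowsPerStep`, `barrier_calls_per_block`, and `all_blocks_agree` are hypothetical names and values chosen for illustration.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical number of rows a block handles per step (not taken from the kernel).
constexpr int kRowsPerStep = 64;

// Number of barrier calls each block makes when every step ends with one
// grid-wide sync. `s` holds segment boundaries: block i owns rows [s[i], s[i + 1]).
std::vector<int> barrier_calls_per_block(const std::vector<int>& s) {
    std::vector<int> calls;
    for (std::size_t i = 0; i + 1 < s.size(); ++i) {
        int rows = s[i + 1] - s[i];
        calls.push_back((rows + kRowsPerStep - 1) / kRowsPerStep);  // ceiling division
    }
    return calls;
}

// A grid-wide barrier releases only when every block arrives, so the grid can
// run to completion only if all blocks make the same number of barrier calls.
bool all_blocks_agree(const std::vector<int>& calls) {
    for (int c : calls) {
        if (c != calls.front()) return false;
    }
    return true;
}
```

Under these assumptions, boundaries `{0, 1, 66}` give segments of 1 and 65 rows, so the two blocks would make 1 and 2 barrier calls respectively, and the second block would wait forever at its second barrier.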

The solution presented here is to compute the maximum number of steps from the largest segment, and then have each block make extra grid.sync() calls at the end of the kernel, one for each step by which it falls short of the maximum, so that every block calls grid.sync() the same number of times.
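The padding scheme can be sketched as a host-side C++ model. This is an illustration under stated assumptions rather than the code in this PR: it assumes one block per segment and one grid.sync() per step, and `kRowsPerStep`, `SyncPlan`, `steps_for`, `max_steps`, and `plan_syncs` are hypothetical names.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical number of rows a block handles per step (not taken from the kernel).
constexpr int kRowsPerStep = 64;

// Steps needed for the segment [s_start, s_end), using ceiling division.
int steps_for(int s_start, int s_end) {
    return (s_end - s_start + kRowsPerStep - 1) / kRowsPerStep;
}

// Largest step count over all segments described by the boundary vector `s`.
int max_steps(const std::vector<int>& s) {
    int best = 0;
    for (std::size_t i = 0; i + 1 < s.size(); ++i) {
        best = std::max(best, steps_for(s[i], s[i + 1]));
    }
    return best;
}

struct SyncPlan {
    int work_syncs;     // barrier calls made while processing the segment
    int padding_syncs;  // extra barrier calls issued at the end of the kernel
};

// For each block, pad its barrier calls up to the maximum so that every block
// reaches the grid-wide barrier the same number of times.
std::vector<SyncPlan> plan_syncs(const std::vector<int>& s) {
    const int limit = max_steps(s);
    std::vector<SyncPlan> plan;
    for (std::size_t i = 0; i + 1 < s.size(); ++i) {
        int steps = steps_for(s[i], s[i + 1]);
        plan.push_back({steps, limit - steps});
    }
    return plan;
}
```

In the kernel itself, `padding_syncs` would correspond to a trailing loop that calls grid.sync() that many times after the block finishes its own steps. In this model every block's `work_syncs + padding_syncs` equals `max_steps(s)`, which is the property that removes the deadlock.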

Because the s vector is generally very short (< batch size), the loop that computes the maximum should not introduce noticeable latency. However, it may be worth exploring more optimized solutions to this problem in a follow-up.

Note that this issue only occurs when using cooperative groups.

Related:

tgaddair commented 5 months ago

cc @abcdabcd987

codecov[bot] commented 5 months ago

Codecov Report

All modified and coverable lines are covered by tests :white_check_mark:

Comparison is base (07a40b9) 43.27% compared to head (2a32749) 43.27%. Report is 1 commits behind head on master.

Additional details and impacted files

```diff
@@           Coverage Diff           @@
##           master      #35   +/-   ##
=======================================
  Coverage   43.27%   43.27%
=======================================
  Files          10       10
  Lines         647      647
=======================================
  Hits          280      280
  Misses        367      367
```


abcdabcd987 commented 5 months ago

Thanks!

@yzh119 Can you take a look?