
S-LoRA: Serving Thousands of Concurrent LoRA Adapters
https://arxiv.org/abs/2311.03285
Apache License 2.0

Question about cuda kernel #10

Closed: harryhan618 closed this issue 11 months ago

harryhan618 commented 11 months ago

Hello! Thanks for open-sourcing the repository! I'm learning to write CUDA code, so I've learned a lot here. I have two questions.

  1. In Section 5.3 of your paper, you mention that the CUDA kernels differ between the prefill stage and the decode stage. I wonder why this is done. It seems that bgmv would still work if the shape [bs, seqlen, dim] is flattened to [bs * seqlen, dim].

  2. This question is a more detailed one about the CUDA kernel implementation. In bgmv_multi_lora_rank_expand_kernel (line #202), the feat_in dimension is reduced via __shfl_down_sync. I think __shfl_down_sync only works within one warp. Is this because feat_in is relatively small and will not exceed one warp (i.e., 32 threads)? That would also mean feat_in must not exceed 32 * vec_size (for fp16, vec_size is 8), right? I've sketched the reduction pattern I mean below.
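
For context, here is a minimal, self-contained sketch of the warp-shuffle reduction pattern I am referring to. It only illustrates __shfl_down_sync; it is not the actual S-LoRA kernel, and the warp_dot name and float types are made up for the example:

```cuda
// Warp-level dot product using __shfl_down_sync (illustration only).
// Each of the 32 lanes holds one partial product; after the loop,
// lane 0 holds the full sum. The shuffle exchanges data only between
// lanes of the SAME warp, which is why the reduced dimension has to
// fit into one warp's worth of (vectorized) loads.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void warp_dot(const float* a, const float* b, float* out, int n) {
  int lane = threadIdx.x;          // assume blockDim.x == 32 (one warp)
  float partial = (lane < n) ? a[lane] * b[lane] : 0.0f;

  // Tree reduction across the warp: offsets 16, 8, 4, 2, 1.
  for (int offset = 16; offset > 0; offset >>= 1) {
    partial += __shfl_down_sync(0xffffffff, partial, offset);
  }
  if (lane == 0) *out = partial;   // only lane 0 has the complete sum
}

int main() {
  const int n = 32;
  float ha[n], hb[n], hout = 0.0f;
  for (int i = 0; i < n; ++i) { ha[i] = 1.0f; hb[i] = 2.0f; }

  float *da, *db, *dout;
  cudaMalloc(&da, n * sizeof(float));
  cudaMalloc(&db, n * sizeof(float));
  cudaMalloc(&dout, sizeof(float));
  cudaMemcpy(da, ha, n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(db, hb, n * sizeof(float), cudaMemcpyHostToDevice);

  warp_dot<<<1, 32>>>(da, db, dout, n);
  cudaMemcpy(&hout, dout, sizeof(float), cudaMemcpyDeviceToHost);
  printf("dot = %f (expected 64)\n", hout);

  cudaFree(da); cudaFree(db); cudaFree(dout);
  return 0;
}
```

Since the shuffle only moves data between lanes of one warp, the reduction cannot extend past 32 lanes without an extra shared-memory step.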

caoshiyi commented 11 months ago

Thanks for your question!

  1. You are absolutely right that bgmv also works for the prefill stage. However, in the prefill stage we do want to tile along the seq dimension to increase operational intensity. BGMV works well when bs * seq_len is small (i.e., in the decode stage, where each request contributes only one token). When bs * seq_len is large, we essentially want something like CUTLASS's grouped GEMM kernel, but it is non-trivial to modify it to support non-contiguous single-adapter weights that align with our memory pool, so we tried Triton instead. This is currently a temporary workaround: in our benchmarks the Triton kernel still cannot outperform CUTLASS's grouped GEMM. Improving this is on our TODO list. (A toy sketch of this contrast appears after this list.)
  2. You are right about this! Since LoRA adapters normally use a very small r (e.g., 16 and 64 for Llama-7B), we believe one warp is enough.
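
To make the decode/prefill contrast concrete, here is a toy BGMV-style sketch. It assumes a contiguous [num_adapters, r, feat_in] weight layout and a per-token adapter_idx lookup; the toy_bgmv name, the layout, and the lookup array are illustrative only and do not match S-LoRA's actual kernels, which gather adapter weights from a non-contiguous memory pool:

```cuda
// Toy batched gather matrix-vector (BGMV) sketch, not the S-LoRA kernel.
// Each block handles one token and multiplies its activation by the LoRA
// weight of the adapter that token's request uses.
// W: [num_adapters, r, feat_in], x: [num_tokens, feat_in], y: [num_tokens, r].
#include <cstdio>
#include <cuda_runtime.h>

__global__ void toy_bgmv(const float* __restrict__ W,
                         const float* __restrict__ x,
                         float* __restrict__ y,
                         const int* __restrict__ adapter_idx,  // [num_tokens]
                         int feat_in, int r) {
  int token = blockIdx.x;  // one block per token
  const float* Wa = W + (size_t)adapter_idx[token] * r * feat_in;
  const float* xt = x + (size_t)token * feat_in;

  // One output element of the r-dimension per thread; loop over feat_in.
  for (int j = threadIdx.x; j < r; j += blockDim.x) {
    float acc = 0.0f;
    for (int k = 0; k < feat_in; ++k) acc += Wa[j * feat_in + k] * xt[k];
    y[(size_t)token * r + j] = acc;
  }
}

int main() {
  // Tiny smoke test: 2 adapters, 3 tokens, feat_in = 8, r = 4.
  const int A = 2, T = 3, feat_in = 8, r = 4;
  float hW[A * r * feat_in], hx[T * feat_in], hy[T * r];
  int hidx[T] = {0, 1, 0};  // which adapter each token belongs to
  for (int i = 0; i < A * r * feat_in; ++i) hW[i] = 0.01f * i;
  for (int i = 0; i < T * feat_in; ++i) hx[i] = 1.0f;

  float *dW, *dx, *dy; int *didx;
  cudaMalloc(&dW, sizeof(hW)); cudaMalloc(&dx, sizeof(hx));
  cudaMalloc(&dy, sizeof(hy)); cudaMalloc(&didx, sizeof(hidx));
  cudaMemcpy(dW, hW, sizeof(hW), cudaMemcpyHostToDevice);
  cudaMemcpy(dx, hx, sizeof(hx), cudaMemcpyHostToDevice);
  cudaMemcpy(didx, hidx, sizeof(hidx), cudaMemcpyHostToDevice);

  toy_bgmv<<<T, 32>>>(dW, dx, dy, didx, feat_in, r);
  cudaMemcpy(hy, dy, sizeof(hy), cudaMemcpyDeviceToHost);
  printf("y[0][0] = %f\n", hy[0]);

  cudaFree(dW); cudaFree(dx); cudaFree(dy); cudaFree(didx);
  return 0;
}
```

In decode, the grid has only bs blocks (one token per request), so the lack of weight reuse is not a problem. In prefill, every token of a request reloads the same adapter weights from global memory, which is exactly what tiling over the seq dimension (as in a grouped GEMM) avoids by reusing a weight tile across many tokens.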
jcao-ai commented 11 months ago

@Ying1123 @caoshiyi Thanks for your explanation. Another question: does bgmv outperform the Triton kernel in the decode phase? (I saw that during decoding the only kernel used is bgmv.) Thanks.