punica-ai / punica

Serving multiple LoRA finetuned LLM as one
https://arxiv.org/abs/2310.18547
Apache License 2.0

bug: sgmv_shrink does not support CUDA graph tracing #33

Closed tgaddair closed 6 months ago

tgaddair commented 6 months ago

When attempting to run the sgmv_shrink kernel as part of a CUDA graph, an error occurs. Digging in further, it seems that some of the operations performed in the kernel are not supported by CUDA graphs. I haven't dug too deeply into it, but going through the code, I wonder if it might be related to memcpy or other (presumably) blocking calls.

Example:

import torch
# sgmv_shrink is punica's SGMV shrink kernel binding; arguments elided.

torch.cuda.synchronize(device)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, pool=None):
    sgmv_shrink(...)  # fails during graph capture

torch.cuda.synchronize(device)

Thanks for the great work on this project! We're huge fans at LoRAX.

abcdabcd987 commented 6 months ago

Thanks! We'll investigate.

Yard1 commented 6 months ago

I think the issue (or at least part of it) stems from the fact that the kernels are not launched on the current Torch stream. That would prevent graph capture through the Torch APIs.
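
For illustration, here is a minimal sketch (not punica code; the tensor and op are just for demonstration) of why the stream matters: torch.cuda.graph captures on a side stream, so only kernels submitted to the current Torch stream are recorded, and a custom kernel launched on the implicit default stream bypasses the capture.

import torch

x = torch.ones(8, device="cuda")
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    # torch.cuda.graph switches to a side capture stream; only work submitted
    # to this stream is recorded into the graph.
    assert torch.cuda.current_stream() != torch.cuda.default_stream()
    y = x * 2  # a regular torch op lands on the capture stream and is recorded
# A custom extension that launches with kernel<<<grid, block>>>(...) uses the
# default stream, bypasses the capture stream, and invalidates the capture;
# launching on at::cuda::getCurrentCUDAStream() avoids that.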

tgaddair commented 6 months ago

Thanks for the suggestion, @Yard1! I'm also running into this issue with the legacy cutlass kernel (still needed for expand). I tried making the kernel use the current PyTorch stream, but that doesn't seem to have resolved the issue. Let me know if I'm missing something:

diff --git a/server/punica_kernels/punica_kernels/sgmv/sgmv_cutlass.cuh b/server/punica_kernels/punica_kernels/sgmv/sgmv_cutlass.cuh
index 69e180c..1747ce6 100644
--- a/server/punica_kernels/punica_kernels/sgmv/sgmv_cutlass.cuh
+++ b/server/punica_kernels/punica_kernels/sgmv/sgmv_cutlass.cuh
@@ -1,4 +1,5 @@
 #pragma once
+#include <ATen/cuda/CUDAContext.h>
 #include <cuda_bf16.h>
 #include <cuda_fp16.h>
 #include <cuda_runtime.h>
@@ -63,6 +64,7 @@ template <typename DType>
 bool sgmv(DType *y, DType *x, DType **w, int32_t *s_start, int32_t *s_end, 
           void *tmp_d, int num_problems, int d_in, int d_out, int layer_idx) {
   using cutlass_t = typename cutlass_dtype<DType>::type;
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

   auto ptr_Y = alloc_from_buf<cutlass_t *>(&tmp_d, num_problems);
   auto ptr_X = alloc_from_buf<cutlass_t *>(&tmp_d, num_problems);
@@ -73,7 +75,7 @@ bool sgmv(DType *y, DType *x, DType **w, int32_t *s_start, int32_t *s_end,
   auto all_problems =
       alloc_from_buf<cutlass::gemm::GemmCoord>(&tmp_d, num_problems);

-  precompute_sgmv_args<<<num_problems, 1>>>(
+  precompute_sgmv_args<<<num_problems, 1, 0, stream>>>(
       all_problems, ptr_Y, ptr_X, ptr_W, ld_Y, ld_X, ld_W, (cutlass_t *)y,
       (cutlass_t *)x, (cutlass_t **)w, s_start, s_end, d_in, d_out, layer_idx);

@@ -112,7 +114,7 @@ bool sgmv(DType *y, DType *x, DType **w, int32_t *s_start, int32_t *s_end,
                                          ptr_Y, ld_X, ld_W, ld_Y, ld_Y);

     GemmGrouped gemm;
-    auto status = gemm.initialize(args);
+    auto status = gemm.initialize(args, nullptr, stream);
     if (status != cutlass::Status::kSuccess) {
       fprintf(stderr, "sgmv_cutlass gemm.initialize failed: %s\n",
               cutlassGetStatusString(status));
@@ -157,7 +159,7 @@ bool sgmv(DType *y, DType *x, DType **w, int32_t *s_start, int32_t *s_end,
                                          ptr_Y, ld_X, ld_W, ld_Y, ld_Y);

     GemmGrouped gemm;
-    auto status = gemm.initialize(args);
+    auto status = gemm.initialize(args, nullptr, stream);
     if (status != cutlass::Status::kSuccess) {
       fprintf(stderr, "sgmv_cutlass gemm.initialize failed: %s\n",
               cutlassGetStatusString(status));

abcdabcd987 commented 6 months ago

Really nice catch @Yard1 .

@tgaddair Can you try https://github.com/punica-ai/punica/commit/b5f5e1a3aa46702c843e8815acf8f654a242e109 and run pytest -v -k cuda_graph? This commit fixes the stream issue, but I think we need more work to make CUDA graphs run smoothly with SGMV, because the inputs are very dynamic. Maybe related: https://pytorch.org/docs/master/notes/cuda.html#partial-network-capture
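
For reference, a rough sketch of the partial-capture pattern from that doc, with hypothetical names (graph_safe_part, sgmv_shrink_eager, the shapes) standing in for the real model pieces: capture only the graph-safe portion with static buffers and keep the dynamic SGMV call eager between replays.

import torch

B, D = 8, 4096
static_x = torch.zeros(B, D, device="cuda")

# Warm up on a side stream before capture, as the PyTorch docs recommend.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_h = graph_safe_part(static_x)  # hypothetical graph-safe submodule
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_h = graph_safe_part(static_x)

def step(x, lora_args):
    static_x.copy_(x)      # fill the static input buffer
    g.replay()             # graph-safe part replays at fixed shapes
    return sgmv_shrink_eager(static_h, lora_args)  # dynamic SGMV stays eager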

tgaddair commented 6 months ago

Thanks for the quick fix, @abcdabcd987! I can verify that the changes in your commit appear to have resolved the issue (the graph can trace and replay) for both the shrink and cutlass sgmv kernels. One thing to note, though: the cutlass kernel still raises an error for me when run with CUDA_LAUNCH_BLOCKING=1 (shrink no longer fails). I was only setting that for debugging and don't use it in production, so it should hopefully not be an issue, but it's worth calling out.

Regarding the dynamic input: my plan was to give the segment index tensor shape [batch_size + 1] instead of [num_segments + 1] and pad it with -1, then make a small change within the kernel to short-circuit when a segment value is -1. That way we always run with a fixed batch size but a potentially variable number of segments. Let me know if you see any obvious issues with that approach.
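
For concreteness, here is a small sketch of that padding (pad_segments is a hypothetical helper, not a punica API): grow the segment-boundary tensor to a fixed length of batch_size + 1 and mark the unused slots with -1 so the kernel can skip them.

import torch

def pad_segments(segments: torch.Tensor, batch_size: int) -> torch.Tensor:
    # segments has shape [num_segments + 1]; pad it to a fixed [batch_size + 1]
    padded = torch.full((batch_size + 1,), -1,
                        dtype=segments.dtype, device=segments.device)
    padded[: segments.numel()] = segments
    return padded

# e.g. segments = [0, 3, 7] with batch_size = 4 -> [0, 3, 7, -1, -1]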

tgaddair commented 6 months ago

Hey @abcdabcd987, in further testing the shrink kernel holds up well, but there are in fact some issues with the cutlass sgmv kernel. I'm trying to put together a minimal repro, but it looks like under certain conditions the cutlass sgmv kernel becomes a no-op during graph replay, while the same inputs run outside the graph work as expected. This seems to involve some interaction with other operations in the graph, as I haven't been able to reproduce it with a graph that consists only of the sgmv ops.

abcdabcd987 commented 6 months ago

I can confirm that CUDA_LAUNCH_BLOCKING=1 pytest -v -k cuda_graph triggers the issue. The problem is that we have a kernel, precompute_sgmv_args, that computes the input arguments for cutlass, and sadly I forgot to pass the stream to it in my previous commit.

Can you try https://github.com/punica-ai/punica/commit/07a40b9d30e98d88963e8a7e140120a25ac0d518? CUDA_LAUNCH_BLOCKING=1 pytest -v -k cuda_graph looks good this time.

tgaddair commented 6 months ago

Ah, that fixed it! Thanks @abcdabcd987. I can confirm that both the CUDA_LAUNCH_BLOCKING=1 error and the replay issue are resolved with these changes. I'll close out this issue!

abcdabcd987 commented 6 months ago

Awesome!