Closed: @tgaddair closed this issue 6 months ago.
Thanks! We'll investigate.
I think the issue (or at least a part of it) stems from the fact that the current Torch stream is not used. This would prevent graph capture through Torch APIs.
Thanks for the suggestion, @Yard1! I'm also running into this issue with the legacy cutlass kernel (still needed for expand). I tried making the kernel use the current PyTorch stream, but that doesn't seem to have resolved the issue. Let me know if I'm missing something:
```diff
diff --git a/server/punica_kernels/punica_kernels/sgmv/sgmv_cutlass.cuh b/server/punica_kernels/punica_kernels/sgmv/sgmv_cutlass.cuh
index 69e180c..1747ce6 100644
--- a/server/punica_kernels/punica_kernels/sgmv/sgmv_cutlass.cuh
+++ b/server/punica_kernels/punica_kernels/sgmv/sgmv_cutlass.cuh
@@ -1,4 +1,5 @@
 #pragma once
+#include <ATen/cuda/CUDAContext.h>
 #include <cuda_bf16.h>
 #include <cuda_fp16.h>
 #include <cuda_runtime.h>
@@ -63,6 +64,7 @@ template <typename DType>
 bool sgmv(DType *y, DType *x, DType **w, int32_t *s_start, int32_t *s_end,
           void *tmp_d, int num_problems, int d_in, int d_out, int layer_idx) {
   using cutlass_t = typename cutlass_dtype<DType>::type;
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   auto ptr_Y = alloc_from_buf<cutlass_t *>(&tmp_d, num_problems);
   auto ptr_X = alloc_from_buf<cutlass_t *>(&tmp_d, num_problems);
@@ -73,7 +75,7 @@ bool sgmv(DType *y, DType *x, DType **w, int32_t *s_start, int32_t *s_end,
   auto all_problems =
       alloc_from_buf<cutlass::gemm::GemmCoord>(&tmp_d, num_problems);
-  precompute_sgmv_args<<<num_problems, 1>>>(
+  precompute_sgmv_args<<<num_problems, 1, 0, stream>>>(
       all_problems, ptr_Y, ptr_X, ptr_W, ld_Y, ld_X, ld_W, (cutlass_t *)y,
       (cutlass_t *)x, (cutlass_t **)w, s_start, s_end, d_in, d_out, layer_idx);
@@ -112,7 +114,7 @@ bool sgmv(DType *y, DType *x, DType **w, int32_t *s_start, int32_t *s_end,
       ptr_Y, ld_X, ld_W, ld_Y, ld_Y);
   GemmGrouped gemm;
-  auto status = gemm.initialize(args);
+  auto status = gemm.initialize(args, nullptr, stream);
   if (status != cutlass::Status::kSuccess) {
     fprintf(stderr, "sgmv_cutlass gemm.initialize failed: %s\n",
             cutlassGetStatusString(status));
@@ -157,7 +159,7 @@ bool sgmv(DType *y, DType *x, DType **w, int32_t *s_start, int32_t *s_end,
       ptr_Y, ld_X, ld_W, ld_Y, ld_Y);
   GemmGrouped gemm;
-  auto status = gemm.initialize(args);
+  auto status = gemm.initialize(args, nullptr, stream);
   if (status != cutlass::Status::kSuccess) {
     fprintf(stderr, "sgmv_cutlass gemm.initialize failed: %s\n",
             cutlassGetStatusString(status));
```
Really nice catch, @Yard1.
@tgaddair Can you try https://github.com/punica-ai/punica/commit/b5f5e1a3aa46702c843e8815acf8f654a242e109 with `pytest -v -k cuda_graph`? This commit fixes the stream issue. But I think we need more work to make CUDA graphs run smoothly on SGMV, because the input is very dynamic. Maybe related: https://pytorch.org/docs/master/notes/cuda.html#partial-network-capture
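For the dynamic-input problem, the usual pattern (the one the partial-network-capture note describes) is to keep inputs in fixed-size static buffers, copy fresh data into them before each replay, and let the captured kernels re-read the same addresses. Here is a torch-free sketch of just that buffer discipline (all names hypothetical; `FakeGraph` stands in for a captured CUDA graph):

```python
class FakeGraph:
    """Minimal stand-in for a captured CUDA graph: it records a fixed
    sequence of operations once, then replays them against the same
    (static) buffers every time."""

    def __init__(self):
        self.ops = []

    def capture(self, *ops):
        self.ops = list(ops)

    def replay(self):
        for op in self.ops:
            op()


# Static buffers: the captured ops always read/write these exact objects.
static_in = [0.0] * 4
static_out = [0.0] * 4

graph = FakeGraph()
# The captured op reads static_in and writes static_out in place.
graph.capture(lambda: static_out.__setitem__(
    slice(None), [v * 2.0 for v in static_in]))

# Per step: copy the (variable) inputs into the static buffer, then replay.
static_in[:] = [1.0, 2.0, 3.0, 4.0]
graph.replay()
```

The key point is that `replay()` never sees new tensors; only the contents of the pre-registered buffers change between replays, which is exactly why dynamically shaped SGMV inputs are awkward under capture.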
Thanks for the quick fix, @abcdabcd987! I can verify that the changes in your commit appear to have resolved the issue (the graph can trace and replay) for both the shrink and cutlass SGMV kernels. One thing to note, though: the cutlass kernel still raises an error for me when run with `CUDA_LAUNCH_BLOCKING=1` (shrink no longer fails). I was only setting this for debugging and don't use it in production, so it should hopefully not be a problem, but it's worth calling out.
Regarding the dynamic input: my plan there was to give the segment index tensor shape `[batch_size + 1]` instead of `[num_segments + 1]` and pad it with -1, then make a small change within the kernel to short-circuit when the segment value is -1. That way we always run with a fixed batch size but a potentially variable number of segments. Let me know if you see any obvious issues with that approach.
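For what it's worth, here is a pure-Python sketch of that padding scheme (function names and shapes are hypothetical, just to illustrate the indexing; `run_problems` is a CPU stand-in for the per-problem kernel loop):

```python
def pad_segments(seg_indptr, batch_size):
    """Pad a segment-boundary array of length num_segments + 1 out to a
    fixed length of batch_size + 1, using -1 as a sentinel so the kernel
    can always launch with a fixed grid regardless of num_segments."""
    assert len(seg_indptr) <= batch_size + 1
    return seg_indptr + [-1] * (batch_size + 1 - len(seg_indptr))


def run_problems(padded, batch_size):
    """Stand-in for the kernel: problem i reads boundaries [i, i + 1)
    and short-circuits when it hits the -1 sentinel."""
    handled = []
    for i in range(batch_size):  # fixed launch size, graph-friendly
        s_start, s_end = padded[i], padded[i + 1]
        if s_start < 0 or s_end < 0:
            continue  # this problem slot is padding; do nothing
        handled.append((s_start, s_end))
    return handled
```

For example, with `batch_size = 4` and two real segments, `pad_segments([0, 2, 5], 4)` yields `[0, 2, 5, -1, -1]`, and only the first two problem slots do any work.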
Hey @abcdabcd987, through further testing, the shrink kernel holds up well, but it looks like there are in fact some issues with the cutlass SGMV kernel. I'm trying to put together a minimal repro, but under certain conditions the cutlass kernel becomes a no-op during graph replay, while the same inputs run outside the graph work as expected. It seems to involve interaction with other operations in the graph, as I haven't been able to reproduce it with a graph consisting only of the SGMV ops.
I can confirm that `CUDA_LAUNCH_BLOCKING=1 pytest -v -k cuda_graph` can trigger the issue. The problem is that we have a kernel, `precompute_sgmv_args`, that computes the input arguments for cutlass, and sadly I forgot to pass the stream to it in my previous commit.
Can you try https://github.com/punica-ai/punica/commit/07a40b9d30e98d88963e8a7e140120a25ac0d518 ? `CUDA_LAUNCH_BLOCKING=1 pytest -v -k cuda_graph` looks good this time.
Ah, that fixed it! Thanks @abcdabcd987. I can confirm that both the `CUDA_LAUNCH_BLOCKING=1` error and the replay issues are resolved with these changes. I'll close out this issue!
Awesome!
When attempting to run the `sgmv_shrink` kernel as part of a CUDA graph, an error occurs. Digging into it further, it seems that some of the operations performed in the kernel are not supported by CUDA graphs. I haven't dug too deeply, but going through the code, I wonder if it might be related to memcpy or other (presumably) blocking calls. Example:
Thanks for the great work on this project! We're huge fans at LoRAX.