harryhan618 opened 7 months ago
Hmm... That's interesting... BTW, thanks for providing this script. Super helpful for reproducing the bug!
We'll take a look at this. In the meantime, you can use punica.ops.sgmv() for SGMV-shrink and punica.ops.add_lora_sgmv_custom_cutlass() for LoRA. Note that our custom kernel assumes a column-major weight, whereas our cutlass kernel assumes a row-major weight.
The following works:
import torch
import punica.ops

bs = 4
h1 = 4096
h2 = 32
num_layers = 1
dtype = torch.float16
device = torch.device("cuda:0")
problem_sizes = [2, 2]  # number of rows of x belonging to each problem

# One weight tensor per problem, stored as (num_layers, h2, h1).
w = [
    torch.ones((num_layers, h2, h1), dtype=dtype, device=device)
    for _ in range(len(problem_sizes))
]
w_ptr = torch.tensor([t.data_ptr() for t in w],
                     dtype=torch.int64,
                     device=device)
# Segment offsets: rows s[i]..s[i+1] of x belong to problem i.
s = torch.cumsum(
    torch.tensor([0] + problem_sizes, device=device),
    dim=0,
    dtype=torch.int32)
x = torch.ones((s[-1], h1), dtype=dtype, device=device)
y = torch.zeros((s[-1], h2), dtype=dtype, device=device)
punica.ops.sgmv(y, x, w_ptr, s, layer_idx=0)
print(y)
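For reference, here is a plain-PyTorch equivalent of what this call should compute, as a sketch continuing the snippet above: each segment of x is multiplied by the corresponding (h2, h1) weight slice as a transposed matmul, so with all-ones inputs every entry of y should be h1 = 4096.

# Plain-PyTorch reference for the sgmv call above (y was zero-initialized,
# so accumulate-vs-overwrite semantics don't matter for this check).
y_ref = torch.zeros_like(y)
for i in range(len(problem_sizes)):
    seg = slice(int(s[i]), int(s[i + 1]))  # rows of x belonging to problem i
    y_ref[seg] = x[seg] @ w[i][0].T        # w[i][layer_idx=0] has shape (h2, h1)
torch.testing.assert_close(y, y_ref)       # every entry should equal h1 = 4096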
Thanks for your reply! I'm curious why you chose a column-major weight layout. My basic understanding is that row-major is friendlier for data loading. Sorry, I haven't read the kernel code yet.
@harryhan618 Modern GPUs support transpose at the fragment level (with ldmatrix.***.trans / movmatrix instructions) at very low cost, so there should not be a significant performance difference between column-major and row-major layouts.
We will support row-major for the shrink kernel in the next release.
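(To make the layout point concrete: the two conventions differ only by a transpose, so the math is unchanged either way. A quick standalone PyTorch sketch, not tied to any punica kernel:)

import torch

W_col = torch.randn(32, 4096)  # (h2, h1): the layout the custom kernel assumes
W_row = W_col.T.contiguous()   # (h1, h2): the row-major equivalent
x = torch.randn(4, 4096)
torch.testing.assert_close(x @ W_row, x @ W_col.T)  # same result either way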
@abcdabcd987 @yzh119 I also hit a case where the kernel launch fails with rank == 64 for sgmv_shrink. Usage:
import torch
import punica.ops

bs = 1
h1 = 1024
h2 = 64  # the rank that triggers the failure
num_layers = 32
dtype = torch.float16
device = torch.device("cuda:0")
problem_sizes = [1]

w = [
    torch.randn((num_layers, h2, h1), dtype=dtype, device=device)
    for _ in range(len(problem_sizes))
]
w_ptr = torch.tensor([t.data_ptr() for t in w],
                     dtype=torch.int64,
                     device=device)
s = torch.cumsum(
    torch.tensor([0] + problem_sizes, device=device),
    dim=0,
    dtype=torch.int32)
x = torch.ones((s[-1], h1), dtype=dtype, device=device)
y = torch.zeros((s[-1], h2), dtype=dtype, device=device)
# punica.ops.sgmv_cutlass(y, x, w_ptr, s, layer_idx=0)
punica.ops.sgmv(y, x, w_ptr, s, layer_idx=0)
print(y)
Output:
RuntimeError: No suitable kernel. dtype=Half d_out=64
NVM, I found this is related to shared memory. PR: https://github.com/punica-ai/punica/pull/20
Hi, any updates on why the cutlass group gemm calculates wrong results?
I just added a few test cases. https://github.com/punica-ai/punica/commit/0c7cf81b0238674be53d4b582ea3b926ae0f09a5
Cutlass only has this problem for shrink. Since we are deprecating cutlass shrink, we probably won't fix this. Before our custom expand lands, you can use punica.add_lora_sgmv_custom_cutlass() for LoRA.
Hi lequn, I think I found the bug in cutlass_shrink.
Please first see CUTLASS example 24 (group GEMM). The second parameter for LinearCombination should be 128 / cutlass::sizeof_bits<ElementOutput>::value. For dtype float16, this should be 8. (Although I don't know where this formula comes from.)
In your code, for shrink, you wrote 4. I think this is the bug. For expand, you wrote 8, which is correct.
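(For context on that formula: CUTLASS epilogues use 128-bit vectorized loads/stores, so that parameter is simply the number of output elements that fit in one 128-bit access. A quick sketch:)

# Elements per 128-bit vector access, i.e. 128 / sizeof_bits(ElementOutput):
for name, bits in [("float16", 16), ("bfloat16", 16), ("float32", 32)]:
    print(f"{name}: {128 // bits}")  # float16 -> 8, bfloat16 -> 8, float32 -> 4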
By the way, to get the code to compile correctly, I had to change the thread block shape and warp shape to GemmShape<16, 128, 64> and GemmShape<16, 32, 64>.
So I'm also wondering how to choose these shapes, since that seems to be the key difference between shrink and expand. I'm looking forward to your insight!
Thank you!
I'm running the following code and the answer comes out wrong. I initialize x and w to be all ones, so every value in the output y should be h1 = 4096. But my output is not: half of the output is 4096 and the other half is 2528. Weird! My observation is that the wrong answer happens when h2 >= 32 for shrink.
The following code is adapted from benchmarks/bench_sgmv_cutlass.py.
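The original script is not reproduced in this thread. The following is a reconstruction from the description above (all-ones x and w, h1 = 4096, h2 = 32, calling punica.ops.sgmv_cutlass), mirroring the maintainer's snippet rather than the author's exact code:

import torch
import punica.ops

h1, h2, num_layers = 4096, 32, 1
dtype = torch.float16
device = torch.device("cuda:0")
problem_sizes = [2, 2]

w = [torch.ones((num_layers, h2, h1), dtype=dtype, device=device)
     for _ in range(len(problem_sizes))]
w_ptr = torch.tensor([t.data_ptr() for t in w],
                     dtype=torch.int64, device=device)
s = torch.cumsum(torch.tensor([0] + problem_sizes, device=device),
                 dim=0, dtype=torch.int32)
x = torch.ones((s[-1], h1), dtype=dtype, device=device)
y = torch.zeros((s[-1], h2), dtype=dtype, device=device)

punica.ops.sgmv_cutlass(y, x, w_ptr, s, layer_idx=0)
print(y)  # expected: all 4096; the report says half the entries come back as 2528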