punica-ai / punica

Serving multiple LoRA finetuned LLM as one
https://arxiv.org/abs/2310.18547
Apache License 2.0

sgmv_cutlass calculates wrong output #11

Open harryhan618 opened 7 months ago

harryhan618 commented 7 months ago

I'm running the following code and getting the wrong answer. I initialize x and w to all ones, so every element of the output y should equal h1 = 4096.

But it doesn't: half of the output values are 4096 and the other half are 2528. Weird! From what I've observed, the wrong answer appears for shrink when h2 >= 32.

The following code is adapted from benchmarks/bench_sgmv_cutlass.py:

import torch
import punica.ops

bs = 4
h1 = 4096
h2 = 32
num_layers = 1
dtype = torch.float16
device = torch.device("cuda:0")
problem_sizes = [2, 2]

# sgmv_cutlass expects row-major weights: one (num_layers, h1, h2) tensor per segment
w = [
    torch.ones((num_layers, h1, h2), dtype=dtype, device=device)
    for _ in range(len(problem_sizes))
]
w_ptr = torch.tensor([t.data_ptr() for t in w],
                     dtype=torch.int64,
                     device=device)
s = torch.cumsum(
    torch.tensor([0] + problem_sizes, device=device),
    dim=0,
    dtype=torch.int32)
x = torch.ones((s[-1], h1), dtype=dtype, device=device)
y = torch.zeros((s[-1], h2), dtype=dtype, device=device)
punica.ops.sgmv_cutlass(y, x, w_ptr, s, layer_idx=0)

print(y)
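For comparison, here is a slow pure-Python sketch of what the shrink call above should produce, assuming the segment semantics implied by the script: rows s[i]:s[i+1] of x are multiplied by the i-th weight, stored row-major as (h1, h2) like the tensors passed to sgmv_cutlass. It is not punica's implementation, and the name sgmv_shrink_reference is hypothetical. It returns a fresh output, which matches the script only because y starts at zero. With all-ones inputs every output element is a sum of h1 ones, i.e. 4096.

```python
def sgmv_shrink_reference(x, weights, seg_starts):
    """Hypothetical reference: y[r] = x[r] @ weights[i] for seg_starts[i] <= r < seg_starts[i + 1].

    x          -- list of rows, each of length h1
    weights    -- one row-major h1 x h2 matrix (list of h1 rows of length h2) per segment
    seg_starts -- cumulative segment boundaries, e.g. [0, 2, 4]
    """
    y = []
    for i in range(len(seg_starts) - 1):
        w = weights[i]
        h2 = len(w[0])
        for r in range(seg_starts[i], seg_starts[i + 1]):
            row = x[r]
            # Output column j is the dot product of row r with column j of w.
            y.append([sum(a * w_k[j] for a, w_k in zip(row, w)) for j in range(h2)])
    return y


h1, h2 = 4096, 32
problem_sizes = [2, 2]

# Same boundaries the script builds with torch.cumsum over [0] + problem_sizes.
seg_starts = [0]
for n in problem_sizes:
    seg_starts.append(seg_starts[-1] + n)

x = [[1.0] * h1 for _ in range(seg_starts[-1])]
weights = [[[1.0] * h2 for _ in range(h1)] for _ in problem_sizes]

y = sgmv_shrink_reference(x, weights, seg_starts)
print(seg_starts)                     # [0, 2, 4]
print({v for row in y for v in row})  # {4096.0}
```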
abcdabcd987 commented 7 months ago

Hmm... That's interesting... BTW, thanks for providing this script. Super helpful for reproducing the bug!

We'll take a look at this. In the meantime, you can use punica.ops.sgmv() for SGMV shrink and punica.ops.add_lora_sgmv_custom_cutlass() for LoRA. Note that our custom kernel assumes column-major weights, whereas our CUTLASS kernel assumes row-major weights.

The following works:

import torch
import punica.ops

bs = 4
h1 = 4096
h2 = 32
num_layers = 1
dtype = torch.float16
device = torch.device("cuda:0")
problem_sizes = [2, 2]

# punica.ops.sgmv expects column-major weights, i.e. (num_layers, h2, h1) tensors
w = [
    torch.ones((num_layers, h2, h1), dtype=dtype, device=device)
    for _ in range(len(problem_sizes))
]
w_ptr = torch.tensor([t.data_ptr() for t in w],
                     dtype=torch.int64,
                     device=device)
s = torch.cumsum(
    torch.tensor([0] + problem_sizes, device=device),
    dim=0,
    dtype=torch.int32)
x = torch.ones((s[-1], h1), dtype=dtype, device=device)
y = torch.zeros((s[-1], h2), dtype=dtype, device=device)
punica.ops.sgmv(y, x, w_ptr, s, layer_idx=0)

print(y)
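The layout note above can be made concrete with a tiny matrix. Take the logical shrink weight W with h1 rows and h2 columns, so y = x @ W. Its column-major storage is element-for-element the row-major storage of W transposed, which is why the working script allocates (num_layers, h2, h1) tensors for punica.ops.sgmv while the failing sgmv_cutlass script used (num_layers, h1, h2). The helper functions below are illustrative only.

```python
def flatten_row_major(m):
    """Flatten a matrix (list of rows) in row-major order."""
    return [v for row in m for v in row]


def flatten_col_major(m):
    """Flatten a matrix (list of rows) in column-major order."""
    return [m[i][j] for j in range(len(m[0])) for i in range(len(m))]


def transpose(m):
    return [list(col) for col in zip(*m)]


# A small logical weight W with h1 = 3 rows and h2 = 2 columns (y = x @ W).
W = [[1, 2],
     [3, 4],
     [5, 6]]

row_major = flatten_row_major(W)  # memory order of an (h1, h2) tensor
col_major = flatten_col_major(W)  # memory order a column-major kernel reads
print(row_major)  # [1, 2, 3, 4, 5, 6]
print(col_major)  # [1, 3, 5, 2, 4, 6]
# Column-major W has the same memory order as row-major W^T, an (h2, h1) tensor.
print(col_major == flatten_row_major(transpose(W)))  # True
```

In PyTorch, an existing row-major weight of shape (num_layers, h1, h2) can be rearranged this way with w.transpose(-1, -2).contiguous(). The .contiguous() call matters: a bare transpose is a view whose data_ptr() still points at the original row-major buffer.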
harryhan618 commented 7 months ago

Thanks for your reply! I'm curious why you chose column-major weights. My basic understanding is that row-major is friendlier for data loading. Sorry, I haven't read the kernel code yet.

yzh119 commented 7 months ago

@harryhan618 Modern GPUs support transposing at the fragment level (with the ldmatrix.***.trans / movmatrix instructions) at very low cost, so there should be no significant performance difference between column-major and row-major layouts.

We will support row-major weights for the shrink kernel in the next release.

jcao-ai commented 7 months ago

@abcdabcd987 @yzh119 I also ran into a case where the kernel launch fails with rank == 64 when using sgmv_shrink:

import torch
import punica.ops

bs = 1
h1 = 1024
h2 = 64
num_layers = 32
dtype = torch.float16
device = torch.device("cuda:0")
problem_sizes = [1]

w = [
    torch.randn((num_layers, h2, h1), dtype=dtype, device=device)
    for _ in range(len(problem_sizes))
]

w_ptr = torch.tensor([t.data_ptr() for t in w],
                     dtype=torch.int64,
                     device=device)
s = torch.cumsum(
    torch.tensor([0] + problem_sizes, device=device),
    dim=0,
    dtype=torch.int32)
x = torch.ones((s[-1], h1), dtype=dtype, device=device)
y = torch.zeros((s[-1], h2), dtype=dtype, device=device)
# punica.ops.sgmv_cutlass(y, x, w_ptr, s, layer_idx=0)
punica.ops.sgmv(y, x, w_ptr, s, layer_idx=0)

print(y)

Output:

RuntimeError: No suitable kernel. dtype=Half d_out=64

NVM, I found this is related to shared memory. PR: https://github.com/punica-ai/punica/pull/20

harryhan618 commented 7 months ago

Hi, any updates on why the CUTLASS grouped GEMM calculates wrong results?

abcdabcd987 commented 7 months ago

> Hi, any updates on why the CUTLASS grouped GEMM calculates wrong results?

I just added a few test cases. https://github.com/punica-ai/punica/commit/0c7cf81b0238674be53d4b582ea3b926ae0f09a5

The CUTLASS kernel only has this problem for shrink. Since we are deprecating the CUTLASS shrink kernel, we probably won't fix it. Until our custom expand kernel lands, you can use punica.add_lora_sgmv_custom_cutlass() for LoRA.

harryhan618 commented 6 months ago

Hi Lequn, I think I found the bug in cutlass_shrink.

Please first see CUTLASS example 24 (grouped GEMM). The second template parameter of LinearCombination should be 128 / cutlass::sizeof_bits<ElementOutput>::value, which is 8 for float16 (although I don't understand the reasoning behind this formula).
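On the formula: CUTLASS's own comments on LinearCombination describe this Count parameter as the number of elements computed per operation, usually 128 / sizeof_bits<ElementOutput>, i.e. as many output elements as fit in one 128-bit (16-byte) vectorized access. The arithmetic is sketched below; elements_per_128bit_access is a hypothetical helper, not a CUTLASS function, and the bit widths are the standard storage sizes.

```python
# Standard storage widths in bits (what cutlass::sizeof_bits<T>::value reports).
SIZEOF_BITS = {"float32": 32, "float16": 16, "bfloat16": 16, "int8": 8}

VECTOR_ACCESS_BITS = 128  # one 16-byte vectorized load/store


def elements_per_128bit_access(dtype):
    """Hypothetical helper mirroring 128 / cutlass::sizeof_bits<ElementOutput>::value."""
    return VECTOR_ACCESS_BITS // SIZEOF_BITS[dtype]


for dt in ("float16", "bfloat16", "float32", "int8"):
    print(dt, elements_per_128bit_access(dt))
# float16 8
# bfloat16 8
# float32 4
# int8 16
```

By the same arithmetic, a Count of 4 with float16 output corresponds to 64-bit rather than 128-bit accesses.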

In your code, you wrote 4 for shrink, which I think is the bug. For expand, you wrote 8, which is correct.

By the way, to make the code compile correctly, I had to change the threadblock shape and warp shape to GemmShape<16, 128, 64> and GemmShape<16, 32, 64>, respectively.
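One basic constraint when choosing these shapes, as CUTLASS's default MMA configurations compute it: the number of warps per threadblock is the threadblock shape divided by the warp shape along each of M, N and K, so every dimension must divide evenly, and the thread count is 32 times the warp count. A sketch of that arithmetic for the shapes above (warp_layout is a hypothetical helper):

```python
def warp_layout(threadblock, warp):
    """Warps per threadblock along (M, N, K), i.e. GemmShape<TB> / GemmShape<Warp>."""
    counts = []
    for tb, w in zip(threadblock, warp):
        if tb % w != 0:
            raise ValueError(f"threadblock dim {tb} is not a multiple of warp dim {w}")
        counts.append(tb // w)
    return tuple(counts)


tb_shape = (16, 128, 64)   # GemmShape<16, 128, 64>
warp_shape = (16, 32, 64)  # GemmShape<16, 32, 64>
m, n, k = warp_layout(tb_shape, warp_shape)
warps = m * n * k
print((m, n, k), warps, warps * 32)  # (1, 4, 1) 4 128
```

So these shapes give 4 warps (128 threads), split along N. Assuming the shrink GEMM maps a segment's rows to M and h2 to N, a 16 x 128 output tile is larger than each segment's 2 x 32 output in the first script, so most of every tile would lie outside the problem.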

So I'm also wondering how to choose these shapes, since they are the key difference between shrink and expand. I'm looking forward to your insight!

Thank you!