microsoft / BitBLAS

BitBLAS is a library to support mixed-precision matrix multiplications, especially for quantized LLM deployment.
MIT License

Support for int8xint8 matmul with scaling #28

Closed ruofan-wu closed 5 months ago

ruofan-wu commented 5 months ago

Hi, thanks for the great work!

Does BitBLAS currently support int8xint8 matmul with scaling? I tried the following code; the tuner evaluated several configs, but the run ended with a segmentation fault.

import bitblas
import torch

bitblas.set_log_level("DEBUG")

in_features = 4096
out_features = 4096

matmul_config = bitblas.MatmulConfig(
    M=1,
    N=out_features,
    K=in_features,
    A_dtype="int8",
    W_dtype="int8",
    accum_dtype="int32",
    out_dtype="float16",
    layout="nt",
    with_bias=False,
    group_size=None,
    with_scaling=True,
    with_zeros=False,
    zeros_mode=None,
)
matmul = bitblas.Matmul(config=matmul_config)

input_shape = (1, 4096)
weight_shape = (4096, 4096)
scaling_shape = (4096, 1)
output_shape = (1, 4096)

scaling = torch.rand(scaling_shape, dtype=torch.float16).cuda()
input_tensor = torch.randint(0, 7, input_shape, dtype=torch.int8).cuda()
weight_tensor = torch.randint(0, 7, weight_shape, dtype=torch.int8).cuda()

output_tensor = matmul(input_tensor, weight_tensor, scale=scaling)

The outputs are:

2024-04-28 03:58:24 [BitBLAS:DEBUG]: Cannot find the appropriate index map for tensorcore
2024-04-28 03:58:39 [BitBLAS:DEBUG]: Cannot find the appropriate index map for tensorcore
2024-04-28 03:58:39 [BitBLAS:DEBUG]: Apply config {'block': [2], 'thread': [2], 'rstep': [4096], 'reduce_thread': [64], 'vectorize': {'A': 16, 'B': 16}}
2024-04-28 03:58:39 [BitBLAS:DEBUG]: Apply config {'block': [4], 'thread': [4], 'rstep': [4096], 'reduce_thread': [32], 'vectorize': {'A': 16, 'B': 16}}
2024-04-28 03:58:39 [BitBLAS:DEBUG]: Apply config {'block': [16], 'thread': [16], 'rstep': [1024], 'reduce_thread': [8], 'vectorize': {'A': 8, 'B': 16}}
2024-04-28 03:58:39 [BitBLAS:DEBUG]: Apply config {'block': [1], 'thread': [1], 'rstep': [4096], 'reduce_thread': [128], 'vectorize': {'A': 16, 'B': 16}}
2024-04-28 03:58:39 [BitBLAS:DEBUG]: Apply config {'block': [8], 'thread': [8], 'rstep': [2048], 'reduce_thread': [16], 'vectorize': {'A': 16, 'B': 16}}
2024-04-28 03:58:39 [BitBLAS:DEBUG]: Apply config {'block': [32], 'thread': [32], 'rstep': [512], 'reduce_thread': [4], 'vectorize': {'A': 4, 'B': 16}}
2024-04-28 03:58:39 [BitBLAS:DEBUG]: Apply config {'block': [64], 'thread': [64], 'rstep': [256], 'reduce_thread': [2], 'vectorize': {'A': 2, 'B': 16}}
2024-04-28 03:58:39 [BitBLAS:DEBUG]: Apply config {'block': [128], 'thread': [128], 'rstep': [128], 'vectorize': {'B': 16}}
2024-04-28 03:58:39 [BitBLAS:DEBUG]: Apply config {'block': [256], 'thread': [128], 'rstep': [128], 'vectorize': {'B': 16}}
2024-04-28 03:58:43 [BitBLAS:INFO]: Evaluation with config {'block': [2], 'thread': [2], 'rstep': [4096], 'reduce_thread': [64], 'vectorize': {'A': 16, 'B': 16}}
2024-04-28 03:58:43 [BitBLAS:INFO]: Time cost of this config: 0.015 ms
2024-04-28 03:58:43 [BitBLAS:INFO]: Evaluation with config {'block': [1], 'thread': [1], 'rstep': [4096], 'reduce_thread': [128], 'vectorize': {'A': 16, 'B': 16}}
2024-04-28 03:58:43 [BitBLAS:INFO]: Time cost of this config: 0.015 ms
2024-04-28 03:58:43 [BitBLAS:INFO]: Evaluation with config {'block': [4], 'thread': [4], 'rstep': [4096], 'reduce_thread': [32], 'vectorize': {'A': 16, 'B': 16}}
2024-04-28 03:58:43 [BitBLAS:INFO]: Time cost of this config: 0.017 ms
2024-04-28 03:58:43 [BitBLAS:INFO]: Evaluation with config {'block': [16], 'thread': [16], 'rstep': [1024], 'reduce_thread': [8], 'vectorize': {'A': 8, 'B': 16}}
2024-04-28 03:58:43 [BitBLAS:INFO]: Time cost of this config: 0.018 ms
2024-04-28 03:58:43 [BitBLAS:INFO]: Evaluation with config {'block': [8], 'thread': [8], 'rstep': [2048], 'reduce_thread': [16], 'vectorize': {'A': 16, 'B': 16}}
2024-04-28 03:58:43 [BitBLAS:INFO]: Time cost of this config: 0.468 ms
2024-04-28 03:58:43 [BitBLAS:INFO]: Evaluation with config {'block': [32], 'thread': [32], 'rstep': [512], 'reduce_thread': [4], 'vectorize': {'A': 4, 'B': 16}}
2024-04-28 03:58:43 [BitBLAS:INFO]: Time cost of this config: 0.015 ms
2024-04-28 03:58:43 [BitBLAS:INFO]: Evaluation with config {'block': [64], 'thread': [64], 'rstep': [256], 'reduce_thread': [2], 'vectorize': {'A': 2, 'B': 16}}
2024-04-28 03:58:43 [BitBLAS:INFO]: Time cost of this config: 0.022 ms
2024-04-28 03:58:43 [BitBLAS:INFO]: Evaluation with config {'block': [128], 'thread': [128], 'rstep': [128], 'vectorize': {'B': 16}}
2024-04-28 03:58:43 [BitBLAS:INFO]: Time cost of this config: 0.029 ms
2024-04-28 03:58:43 [BitBLAS:INFO]: Evaluation with config {'block': [256], 'thread': [128], 'rstep': [128], 'vectorize': {'B': 16}}
2024-04-28 03:58:43 [BitBLAS:INFO]: Time cost of this config: 0.893 ms
Segmentation fault (core dumped)

LeiWang1999 commented 5 months ago

We currently do not have an op implementation of int8xint8 with scaling, because int8 dequantization usually does not apply scaling to the weights. See, for example, the BitNet forward implementation:

https://github.com/microsoft/BitBLAS/blob/d536ddea210d5c0a97dfb55b4630d944421d13e2/integration/BitNet/utils_quant.py#L141-L153

ruofan-wu commented 5 months ago

Hi @LeiWang1999,

How does the int8xint8 matmul dequantize? Could you give a PyTorch example similar to the one below? Thanks very much!

rescaling_tensor = torch.zeros_like(weight_tensor, dtype=torch.float16).cuda()
# Compute reference result with manual scaling and zero-point adjustment
# rescale = (weight - zeros) * scaling
for i in range(in_features // group_size):
    for j in range(group_size):
        rescaling_tensor[:, i * group_size + j] = (
            weight_tensor[:, i * group_size + j].to(torch.float16) - zeros[:, i]
        ) * scaling[:, i]
ref_result = torch.matmul(input_tensor, rescaling_tensor.t().to(torch.float16))
# Assert that the results are close within a specified tolerance
print("Ref output:", ref_result)
print("BitBLAS output:", output_tensor)
torch.testing.assert_close(output_tensor, ref_result, rtol=1e-2, atol=1e-2)

LeiWang1999 commented 5 months ago

Hi @GisellWu, when both A_dtype and W_dtype are set to int8, no dequantization is involved; the kernel simply computes C[i, j] = Accum(A[i, k] * B[j, k], 'accum_dtype').astype('out_dtype').

In your case, accum_dtype is int32 and out_dtype is float16, so the accumulated result is simply cast, with no scaling applied. We currently do not support applying scaling to the output inside the kernel, so if you need scaling, we recommend implementing it with a simple PyTorch expression.

Since PyTorch does not support int8 matmul, you can simulate it with float32.

ref_tensor = torch.matmul(input_tensor.to(torch.float32), weight_tensor.to(torch.float32).T)
fp16_output_tensor = ref_tensor.to(torch.float16)

This approach allows you to simulate the operation without native int8 support.
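
To make the semantics described above concrete, here is a minimal pure-Python sketch (an illustration only, not BitBLAS code) of an "nt" integer matmul followed by a scale applied to the output outside the kernel. The helper names `int_matmul_nt` and `rescale_output`, and the one-scale-per-output-channel layout, are assumptions made for this example.

```python
def int_matmul_nt(a, w):
    """C[i][j] = sum_k a[i][k] * w[j][k], accumulated as Python ints."""
    return [[sum(x * y for x, y in zip(row, col)) for col in w] for row in a]


def rescale_output(acc, scale):
    """Apply one (hypothetical) scale per output channel j to the accumulators."""
    return [[v * scale[j] for j, v in enumerate(row)] for row in acc]


a = [[1, -2, 3, 4]]                  # activations, shape (M=1, K=4)
w = [[5, 6, -7, 8], [1, 0, 2, -3]]   # weights, shape (N=2, K=4), "nt" layout
scale = [0.5, 0.25]                  # assumed per-output-channel scales

acc = int_matmul_nt(a, w)            # integer accumulation, no scaling
out = rescale_output(acc, scale)     # scaling applied after the matmul
print(acc, out)
```

One thing to keep in mind when the output dtype is float16: its largest finite value is 65504, so large integer accumulators can overflow when cast, and applying the scale before the cast may be preferable.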

ruofan-wu commented 5 months ago

Hi @LeiWang1999,

Thanks for your explanation. I'd like to double-check that this implementation is specifically for pure int8 matmul without the quantization and dequantization processes, right?

LeiWang1999 commented 5 months ago

yeah, absolutely.

brisker commented 2 months ago

@LeiWang1999 In the w4a8 setting in this repo, is per-token dynamic int8 activation quantization supported?

LeiWang1999 commented 2 months ago

Hi @brisker, we currently do not support scaling for int GEMM within the kernel; instead, we rescale the output tensor, as shown here: https://github1s.com/microsoft/BitBLAS/blob/main/integration/BitNet/utils_quant.py. However, it is straightforward to integrate the rescale directly into the kernel by modifying the expression in bitblas/ops/general_matmul/tirscript.
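
As a rough illustration of what "per-token dynamic quantization plus output rescaling" means, here is a minimal pure-Python sketch. It uses a generic absmax scheme; the function names, the `qmax` parameter and the epsilon are assumptions for this example and are not taken from the BitBLAS sources.

```python
def quantize_per_token(x, qmax=127):
    """Quantize each row (token) of x to int8 using that row's own scale."""
    q_rows, scales = [], []
    for row in x:
        amax = max(abs(v) for v in row)
        s = qmax / max(amax, 1e-5)   # integer value = float value * s
        q_rows.append([max(-qmax - 1, min(qmax, round(v * s))) for v in row])
        scales.append(s)
    return q_rows, scales


def int_matmul_nt(a, w):
    """Integer "nt" matmul: C[i][j] = sum_k a[i][k] * w[j][k]."""
    return [[sum(p * q for p, q in zip(r, c)) for c in w] for r in a]


x = [[0.5, -1.0, 0.25], [2.0, 4.0, -8.0]]   # two tokens, K=3
w = [[1, 2, 3], [-4, 5, 6]]                 # integer weights, shape (N=2, K=3)

xq, s = quantize_per_token(x)
acc = int_matmul_nt(xq, w)                  # integer GEMM, no scaling inside
out = [[v / s[i] for v in row] for i, row in enumerate(acc)]  # rescale per token
print(out)
```

Because each token has its own scale, the rescale is a per-row division of the integer result, which is why it can be done on the output tensor after the GEMM.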

brisker commented 2 months ago

@LeiWang1999 Thanks for your prompt reply! Also, what is the basic w4a8 pipeline in this repo?

Do you convert the int4 weights to int8 first and then perform a w8a8 GEMM (which can reuse the w8a8 GEMM kernel)?

If so, and the w4 weights are quantized per channel with groups (that is, for a [cin, cout] linear layer there is a scale of shape [cin/group_size, cout]), how does the BitBLAS kernel handle w4a8?

LeiWang1999 commented 2 months ago

Currently, we do not implement int4 to int8 dequantization with group-wise scaling. Are you using smooth quantization?

brisker commented 2 months ago

@LeiWang1999 I want to accelerate a w4a8-g128 (group size = 128) quantized model.

So do you mean that only per-channel w4 quantization is supported in this repo?

LeiWang1999 commented 2 months ago

@brisker Yeah, we currently do not implement group-wise rescaling for int8xint4.
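
For context, the following pure-Python sketch shows why group-wise scales differ from a per-channel scale: a single scale factors out of the whole sum over K, so it can be applied to the output, whereas group scales must be applied to each group's partial sum. The function `grouped_dot` and the toy values are hypothetical and only illustrate the arithmetic; this is not how BitBLAS implements anything.

```python
def grouped_dot(a_q, w_q, group_scales, group_size):
    """Dot product in which every K-group of w_q carries its own scale."""
    total = 0.0
    for g, s in enumerate(group_scales):
        lo, hi = g * group_size, (g + 1) * group_size
        partial = sum(x * y for x, y in zip(a_q[lo:hi], w_q[lo:hi]))  # integer sum
        total += s * partial          # the scale is applied per group
    return total


a_q = [3, -2, 5, 1]     # int8 activations, K=4
w_q = [7, -8, 2, -1]    # int4 weights of one output channel
group_size = 2

# Same scale in every group: equivalent to one per-channel output rescale.
per_channel = grouped_dot(a_q, w_q, [0.5, 0.5], group_size)
# Different scales per group: cannot be written as one factor times the full sum.
group_wise = grouped_dot(a_q, w_q, [0.5, 0.25], group_size)
print(per_channel, group_wise)
```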

brisker commented 2 months ago

@LeiWang1999 For your per-channel quantized w4a8, is the w8a8 GEMM kernel reused? How is it used?

LeiWang1999 commented 2 months ago

Hi @brisker, you can analyze the generated CUDA source with:

import bitblas
import torch

# enabling debug output

bitblas.set_log_level("Debug")
matmul_config = bitblas.MatmulConfig(
    M=1024,  # M dimension
    N=1024,  # N dimension
    K=1024,  # K dimension
    A_dtype="int8",  # activation A dtype
    W_dtype="int4",  # weight W dtype
    accum_dtype="int32",  # accumulation dtype
    out_dtype="float32",  # output dtype
    layout="nt",  # matrix layout, "nt" indicates the layout of A is non-transpose and the layout of W is transpose
    with_bias=False,  # bias
    # configs for weight only quantization
    group_size=None,  # setting for grouped quantization
    with_scaling=False,  # setting for scaling factor
    with_zeros=False,  # setting for zeros
    zeros_mode=None,  # setting for how zeros are calculated
    fast_decoding=False,  # setting for fast decoding
)

matmul = bitblas.Matmul(config=matmul_config)

print(matmul.get_source())

The generated CUDA source is:

extern "C" __global__ void __launch_bounds__(128) main_kernel(signed char* __restrict__ A, signed char* __restrict__ B, float* __restrict__ D) {
  int C_reindex_shared_warp[32];
  __shared__ signed char A_reindex_reindex_shared[32768];
  __shared__ signed char B_shared[4096];
  __shared__ signed char B_decode_reindex_shared[4096];
  signed char B_local[8];
  signed char B_decode_reindex_local[16];
  signed char A_reindex_reindex_shared_warp[64];
  signed char B_decode_reindex_shared_warp[16];
  signed char B_local_1[8];
  signed char B_decode_reindex_local_1[16];
  signed char A_reindex_reindex_shared_warp_1[64];
  signed char B_decode_reindex_shared_warp_1[16];
  for (int ax1_0_3_init = 0; ax1_0_3_init < 4; ++ax1_0_3_init) {
    for (int i = 0; i < 8; ++i) {
C_reindex_shared_warp[(ax1_0_3_init * 8) + i] = 0.0;}
;
  }
  #pragma unroll
  for (int ax0_ax1_ax2_ax3_ax4_fused_0 = 0; ax0_ax1_ax2_ax3_ax4_fused_0 < 8; ++ax0_ax1_ax2_ax3_ax4_fused_0) {

  {
        unsigned int addr;
#if TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST
    addr = static_cast<unsigned int>(__cvta_generic_to_shared((void *)(A_reindex_reindex_shared + ((((ax0_ax1_ax2_ax3_ax4_fused_0 * 2048) + (((int)threadIdx.y) * 1024)) + (((int)threadIdx.z) * 512)) + (((int)threadIdx.x) * 16)))));
#else
    __asm__ __volatile__(
      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
      : "=r"(addr)
      : "l"((void *)(A_reindex_reindex_shared + ((((ax0_ax1_ax2_ax3_ax4_fused_0 * 2048) + (((int)threadIdx.y) * 1024)) + (((int)threadIdx.z) * 512)) + (((int)threadIdx.x) * 16))))
    );
#endif
    __asm__ __volatile__(
      #if TVM_ENABLE_L2_PREFETCH
        "cp.async.cg.shared.global.L2::128B [%0], [%1], %2;"
      #else
        "cp.async.cg.shared.global [%0], [%1], %2;"
      #endif
        :: "r"(addr), "l"((void*)(A + (((((((int)blockIdx.y) * 131072) + (ax0_ax1_ax2_ax3_ax4_fused_0 * 16384)) + (((int)threadIdx.y) * 1024)) + (((int)threadIdx.z) * 512)) + (((int)threadIdx.x) * 16)))), "n"(16)
    );
  }
  }
  #pragma unroll
  for (int ax0_ax1_ax2_ax3_fused_0_0_0_0 = 0; ax0_ax1_ax2_ax3_fused_0_0_0_0 < 1; ++ax0_ax1_ax2_ax3_fused_0_0_0_0) {

  {
        unsigned int addr;
#if TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST
    addr = static_cast<unsigned int>(__cvta_generic_to_shared((void *)(B_shared + (((((int)threadIdx.z) * 1024) + (((int)threadIdx.y) * 512)) + (((int)threadIdx.x) * 16)))));
#else
    __asm__ __volatile__(
      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
      : "=r"(addr)
      : "l"((void *)(B_shared + (((((int)threadIdx.z) * 1024) + (((int)threadIdx.y) * 512)) + (((int)threadIdx.x) * 16))))
    );
#endif
    __asm__ __volatile__(
      #if TVM_ENABLE_L2_PREFETCH
        "cp.async.cg.shared.global.L2::128B [%0], [%1], %2;"
      #else
        "cp.async.cg.shared.global [%0], [%1], %2;"
      #endif
        :: "r"(addr), "l"((void*)(B + ((((((int)blockIdx.x) * 16384) + (((int)threadIdx.z) * 8192)) + (((int)threadIdx.y) * 512)) + (((int)threadIdx.x) * 16)))), "n"(16)
    );
  }
  }
__asm__ __volatile__("cp.async.commit_group;");

  for (int ax3_0_0 = 0; ax3_0_0 < 7; ++ax3_0_0) {
    __syncthreads();
    #pragma unroll
    for (int ax0_ax1_ax2_ax3_ax4_fused_0_1 = 0; ax0_ax1_ax2_ax3_ax4_fused_0_1 < 8; ++ax0_ax1_ax2_ax3_ax4_fused_0_1) {

  {
        unsigned int addr;
#if TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST
    addr = static_cast<unsigned int>(__cvta_generic_to_shared((void *)(A_reindex_reindex_shared + (((((((ax3_0_0 + 1) & 1) * 16384) + (ax0_ax1_ax2_ax3_ax4_fused_0_1 * 2048)) + (((int)threadIdx.y) * 1024)) + (((int)threadIdx.z) * 512)) + (((int)threadIdx.x) * 16)))));
#else
    __asm__ __volatile__(
      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
      : "=r"(addr)
      : "l"((void *)(A_reindex_reindex_shared + (((((((ax3_0_0 + 1) & 1) * 16384) + (ax0_ax1_ax2_ax3_ax4_fused_0_1 * 2048)) + (((int)threadIdx.y) * 1024)) + (((int)threadIdx.z) * 512)) + (((int)threadIdx.x) * 16))))
    );
#endif
    __asm__ __volatile__(
      #if TVM_ENABLE_L2_PREFETCH
        "cp.async.cg.shared.global.L2::128B [%0], [%1], %2;"
      #else
        "cp.async.cg.shared.global [%0], [%1], %2;"
      #endif
        :: "r"(addr), "l"((void*)(A + (((((((((int)blockIdx.y) * 131072) + (ax0_ax1_ax2_ax3_ax4_fused_0_1 * 16384)) + (ax3_0_0 * 2048)) + (((int)threadIdx.y) * 1024)) + (((int)threadIdx.z) * 512)) + (((int)threadIdx.x) * 16)) + 2048))), "n"(16)
    );
  }
    }
    #pragma unroll
    for (int ax0_ax1_ax2_ax3_fused_0_0_0_0_1 = 0; ax0_ax1_ax2_ax3_fused_0_0_0_0_1 < 1; ++ax0_ax1_ax2_ax3_fused_0_0_0_0_1) {

  {
        unsigned int addr;
#if TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST
    addr = static_cast<unsigned int>(__cvta_generic_to_shared((void *)(B_shared + ((((((ax3_0_0 + 1) & 1) * 2048) + (((int)threadIdx.z) * 1024)) + (((int)threadIdx.y) * 512)) + (((int)threadIdx.x) * 16)))));
#else
    __asm__ __volatile__(
      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
      : "=r"(addr)
      : "l"((void *)(B_shared + ((((((ax3_0_0 + 1) & 1) * 2048) + (((int)threadIdx.z) * 1024)) + (((int)threadIdx.y) * 512)) + (((int)threadIdx.x) * 16))))
    );
#endif
    __asm__ __volatile__(
      #if TVM_ENABLE_L2_PREFETCH
        "cp.async.cg.shared.global.L2::128B [%0], [%1], %2;"
      #else
        "cp.async.cg.shared.global [%0], [%1], %2;"
      #endif
        :: "r"(addr), "l"((void*)(B + ((((((((int)blockIdx.x) * 16384) + (((int)threadIdx.z) * 8192)) + (ax3_0_0 * 1024)) + (((int)threadIdx.y) * 512)) + (((int)threadIdx.x) * 16)) + 1024))), "n"(16)
    );
  }
    }
__asm__ __volatile__("cp.async.commit_group;");

__asm__ __volatile__("cp.async.wait_group 1;");

    __syncthreads();
    for (int ax1_ax2_ax3_ax4_0_fused_0 = 0; ax1_ax2_ax3_ax4_0_fused_0 < 2; ++ax1_ax2_ax3_ax4_0_fused_0) {
      *(int2*)(B_local + 0) = *(int2*)(B_shared + ((((((ax3_0_0 & 1) * 2048) + (ax1_ax2_ax3_ax4_0_fused_0 * 1024)) + (((int)threadIdx.y) * 512)) + (((int)threadIdx.z) * 256)) + (((int)threadIdx.x) * 8)));
      for (int ax4 = 0; ax4 < 16; ++ax4) {
        B_decode_reindex_local[ax4] = (((signed char)((((uint)B_local[(ax4 >> 1)]) >> (((uint)(ax4 & 1)) * (uint)4)) & (uint)15)) - (signed char)8);
      }
      *(int4*)(B_decode_reindex_shared + ((((ax1_ax2_ax3_ax4_0_fused_0 * 2048) + (((int)threadIdx.y) * 1024)) + (((int)threadIdx.z) * 512)) + (((int)threadIdx.x) * 16))) = *(int4*)(B_decode_reindex_local + 0);
    }
    __syncthreads();
    for (int ax3_0_1 = 0; ax3_0_1 < 4; ++ax3_0_1) {
      for (int ax1 = 0; ax1 < 4; ++ax1) {

  {
    unsigned int addr;
#if TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST
    addr = static_cast<unsigned int>(__cvta_generic_to_shared((void *)((&(A_reindex_reindex_shared[(((((ax3_0_0 & 1) * 16384) + (((int)threadIdx.y) * 8192)) + (ax1 * 2048)) + (ax3_0_1 * 512))])) + (((int)threadIdx.x) * 16))));
#else
    __asm__ __volatile__(
      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
      : "=r"(addr)
      : "l"((void *)((&(A_reindex_reindex_shared[(((((ax3_0_0 & 1) * 16384) + (((int)threadIdx.y) * 8192)) + (ax1 * 2048)) + (ax3_0_1 * 512))])) + (((int)threadIdx.x) * 16)))
    );
#endif
    __asm__ __volatile__(
      "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
      "{%0, %1, %2, %3}, [%4];\n"
      : "=r"(((unsigned *)(A_reindex_reindex_shared_warp + (ax1 * 16)))[0]), "=r"(((unsigned *)(A_reindex_reindex_shared_warp + (ax1 * 16)))[1]), "=r"(((unsigned *)(A_reindex_reindex_shared_warp + (ax1 * 16)))[2]), "=r"(((unsigned *)(A_reindex_reindex_shared_warp + (ax1 * 16)))[3])
      : "r"(addr)
    );
  }
      }

  {
    unsigned int addr;
#if TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST
    addr = static_cast<unsigned int>(__cvta_generic_to_shared((void *)((&(B_decode_reindex_shared[((((int)threadIdx.z) * 2048) + (ax3_0_1 * 512))])) + (((int)threadIdx.x) * 16))));
#else
    __asm__ __volatile__(
      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
      : "=r"(addr)
      : "l"((void *)((&(B_decode_reindex_shared[((((int)threadIdx.z) * 2048) + (ax3_0_1 * 512))])) + (((int)threadIdx.x) * 16)))
    );
#endif
    __asm__ __volatile__(
      "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
      "{%0, %1, %2, %3}, [%4];\n"
      : "=r"(((unsigned *)(B_decode_reindex_shared_warp + 0))[0]), "=r"(((unsigned *)(B_decode_reindex_shared_warp + 0))[1]), "=r"(((unsigned *)(B_decode_reindex_shared_warp + 0))[2]), "=r"(((unsigned *)(B_decode_reindex_shared_warp + 0))[3])
      : "r"(addr)
    );
  }
      for (int ax1_0_3 = 0; ax1_0_3 < 4; ++ax1_0_3) {

  {
    __asm__ __volatile__(
      "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32"
      "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
      :  "=r"(((int *)(C_reindex_shared_warp + (ax1_0_3 * 8)))[0]), "=r"(((int *)(C_reindex_shared_warp + (ax1_0_3 * 8)))[1]), "=r"(((int *)(C_reindex_shared_warp + (ax1_0_3 * 8)))[2]), "=r"(((int *)(C_reindex_shared_warp + (ax1_0_3 * 8)))[3])
      : "r"(((unsigned *)(A_reindex_reindex_shared_warp + (ax1_0_3 * 16)))[0]), "r"(((unsigned *)(A_reindex_reindex_shared_warp + (ax1_0_3 * 16)))[1]), "r"(((unsigned *)(A_reindex_reindex_shared_warp + (ax1_0_3 * 16)))[2]), "r"(((unsigned *)(A_reindex_reindex_shared_warp + (ax1_0_3 * 16)))[3]), "r"(((unsigned *)(B_decode_reindex_shared_warp + 0))[0]), "r"(((unsigned *)(B_decode_reindex_shared_warp + 0))[1]), "r"(((int *)(C_reindex_shared_warp + (ax1_0_3 * 8)))[0]), "r"(((int *)(C_reindex_shared_warp + (ax1_0_3 * 8)))[1]), "r"(((int *)(C_reindex_shared_warp + (ax1_0_3 * 8)))[2]), "r"(((int *)(C_reindex_shared_warp + (ax1_0_3 * 8)))[3]));
  }

  {
    __asm__ __volatile__(
      "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32"
      "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
      :  "=r"(((int *)(C_reindex_shared_warp + ((ax1_0_3 * 8) + 4)))[0]), "=r"(((int *)(C_reindex_shared_warp + ((ax1_0_3 * 8) + 4)))[1]), "=r"(((int *)(C_reindex_shared_warp + ((ax1_0_3 * 8) + 4)))[2]), "=r"(((int *)(C_reindex_shared_warp + ((ax1_0_3 * 8) + 4)))[3])
      : "r"(((unsigned *)(A_reindex_reindex_shared_warp + (ax1_0_3 * 16)))[0]), "r"(((unsigned *)(A_reindex_reindex_shared_warp + (ax1_0_3 * 16)))[1]), "r"(((unsigned *)(A_reindex_reindex_shared_warp + (ax1_0_3 * 16)))[2]), "r"(((unsigned *)(A_reindex_reindex_shared_warp + (ax1_0_3 * 16)))[3]), "r"(((unsigned *)(B_decode_reindex_shared_warp + 8))[0]), "r"(((unsigned *)(B_decode_reindex_shared_warp + 8))[1]), "r"(((int *)(C_reindex_shared_warp + ((ax1_0_3 * 8) + 4)))[0]), "r"(((int *)(C_reindex_shared_warp + ((ax1_0_3 * 8) + 4)))[1]), "r"(((int *)(C_reindex_shared_warp + ((ax1_0_3 * 8) + 4)))[2]), "r"(((int *)(C_reindex_shared_warp + ((ax1_0_3 * 8) + 4)))[3]));
  }
      }
    }
  }
__asm__ __volatile__("cp.async.wait_group 0;");

  __syncthreads();
  for (int ax1_ax2_ax3_ax4_0_fused_0_1 = 0; ax1_ax2_ax3_ax4_0_fused_0_1 < 2; ++ax1_ax2_ax3_ax4_0_fused_0_1) {
    *(int2*)(B_local_1 + 0) = *(int2*)(B_shared + (((((ax1_ax2_ax3_ax4_0_fused_0_1 * 1024) + (((int)threadIdx.y) * 512)) + (((int)threadIdx.z) * 256)) + (((int)threadIdx.x) * 8)) + 2048));
    for (int ax4_1 = 0; ax4_1 < 16; ++ax4_1) {
      B_decode_reindex_local_1[ax4_1] = (((signed char)((((uint)B_local_1[(ax4_1 >> 1)]) >> (((uint)(ax4_1 & 1)) * (uint)4)) & (uint)15)) - (signed char)8);
    }
    *(int4*)(B_decode_reindex_shared + ((((ax1_ax2_ax3_ax4_0_fused_0_1 * 2048) + (((int)threadIdx.y) * 1024)) + (((int)threadIdx.z) * 512)) + (((int)threadIdx.x) * 16))) = *(int4*)(B_decode_reindex_local_1 + 0);
  }
  __syncthreads();
  for (int ax3_0_1_1 = 0; ax3_0_1_1 < 4; ++ax3_0_1_1) {
    for (int ax1_1 = 0; ax1_1 < 4; ++ax1_1) {

  {
    unsigned int addr;
#if TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST
    addr = static_cast<unsigned int>(__cvta_generic_to_shared((void *)((&(A_reindex_reindex_shared[((((((int)threadIdx.y) * 8192) + (ax1_1 * 2048)) + (ax3_0_1_1 * 512)) + 16384)])) + (((int)threadIdx.x) * 16))));
#else
    __asm__ __volatile__(
      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
      : "=r"(addr)
      : "l"((void *)((&(A_reindex_reindex_shared[((((((int)threadIdx.y) * 8192) + (ax1_1 * 2048)) + (ax3_0_1_1 * 512)) + 16384)])) + (((int)threadIdx.x) * 16)))
    );
#endif
    __asm__ __volatile__(
      "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
      "{%0, %1, %2, %3}, [%4];\n"
      : "=r"(((unsigned *)(A_reindex_reindex_shared_warp_1 + (ax1_1 * 16)))[0]), "=r"(((unsigned *)(A_reindex_reindex_shared_warp_1 + (ax1_1 * 16)))[1]), "=r"(((unsigned *)(A_reindex_reindex_shared_warp_1 + (ax1_1 * 16)))[2]), "=r"(((unsigned *)(A_reindex_reindex_shared_warp_1 + (ax1_1 * 16)))[3])
      : "r"(addr)
    );
  }
    }

  {
    unsigned int addr;
#if TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST
    addr = static_cast<unsigned int>(__cvta_generic_to_shared((void *)((&(B_decode_reindex_shared[((((int)threadIdx.z) * 2048) + (ax3_0_1_1 * 512))])) + (((int)threadIdx.x) * 16))));
#else
    __asm__ __volatile__(
      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
      : "=r"(addr)
      : "l"((void *)((&(B_decode_reindex_shared[((((int)threadIdx.z) * 2048) + (ax3_0_1_1 * 512))])) + (((int)threadIdx.x) * 16)))
    );
#endif
    __asm__ __volatile__(
      "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
      "{%0, %1, %2, %3}, [%4];\n"
      : "=r"(((unsigned *)(B_decode_reindex_shared_warp_1 + 0))[0]), "=r"(((unsigned *)(B_decode_reindex_shared_warp_1 + 0))[1]), "=r"(((unsigned *)(B_decode_reindex_shared_warp_1 + 0))[2]), "=r"(((unsigned *)(B_decode_reindex_shared_warp_1 + 0))[3])
      : "r"(addr)
    );
  }
    for (int ax1_0_3_1 = 0; ax1_0_3_1 < 4; ++ax1_0_3_1) {

  {
    __asm__ __volatile__(
      "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32"
      "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
      :  "=r"(((int *)(C_reindex_shared_warp + (ax1_0_3_1 * 8)))[0]), "=r"(((int *)(C_reindex_shared_warp + (ax1_0_3_1 * 8)))[1]), "=r"(((int *)(C_reindex_shared_warp + (ax1_0_3_1 * 8)))[2]), "=r"(((int *)(C_reindex_shared_warp + (ax1_0_3_1 * 8)))[3])
      : "r"(((unsigned *)(A_reindex_reindex_shared_warp_1 + (ax1_0_3_1 * 16)))[0]), "r"(((unsigned *)(A_reindex_reindex_shared_warp_1 + (ax1_0_3_1 * 16)))[1]), "r"(((unsigned *)(A_reindex_reindex_shared_warp_1 + (ax1_0_3_1 * 16)))[2]), "r"(((unsigned *)(A_reindex_reindex_shared_warp_1 + (ax1_0_3_1 * 16)))[3]), "r"(((unsigned *)(B_decode_reindex_shared_warp_1 + 0))[0]), "r"(((unsigned *)(B_decode_reindex_shared_warp_1 + 0))[1]), "r"(((int *)(C_reindex_shared_warp + (ax1_0_3_1 * 8)))[0]), "r"(((int *)(C_reindex_shared_warp + (ax1_0_3_1 * 8)))[1]), "r"(((int *)(C_reindex_shared_warp + (ax1_0_3_1 * 8)))[2]), "r"(((int *)(C_reindex_shared_warp + (ax1_0_3_1 * 8)))[3]));
  }

  {
    __asm__ __volatile__(
      "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32"
      "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
      :  "=r"(((int *)(C_reindex_shared_warp + ((ax1_0_3_1 * 8) + 4)))[0]), "=r"(((int *)(C_reindex_shared_warp + ((ax1_0_3_1 * 8) + 4)))[1]), "=r"(((int *)(C_reindex_shared_warp + ((ax1_0_3_1 * 8) + 4)))[2]), "=r"(((int *)(C_reindex_shared_warp + ((ax1_0_3_1 * 8) + 4)))[3])
      : "r"(((unsigned *)(A_reindex_reindex_shared_warp_1 + (ax1_0_3_1 * 16)))[0]), "r"(((unsigned *)(A_reindex_reindex_shared_warp_1 + (ax1_0_3_1 * 16)))[1]), "r"(((unsigned *)(A_reindex_reindex_shared_warp_1 + (ax1_0_3_1 * 16)))[2]), "r"(((unsigned *)(A_reindex_reindex_shared_warp_1 + (ax1_0_3_1 * 16)))[3]), "r"(((unsigned *)(B_decode_reindex_shared_warp_1 + 8))[0]), "r"(((unsigned *)(B_decode_reindex_shared_warp_1 + 8))[1]), "r"(((int *)(C_reindex_shared_warp + ((ax1_0_3_1 * 8) + 4)))[0]), "r"(((int *)(C_reindex_shared_warp + ((ax1_0_3_1 * 8) + 4)))[1]), "r"(((int *)(C_reindex_shared_warp + ((ax1_0_3_1 * 8) + 4)))[2]), "r"(((int *)(C_reindex_shared_warp + ((ax1_0_3_1 * 8) + 4)))[3]));
  }
    }
  }
  for (int ax0 = 0; ax0 < 4; ++ax0) {
    __syncthreads();
    for (int local_id = 0; local_id < 8; ++local_id) {
(&(((int*)A_reindex_reindex_shared)[((((int)threadIdx.y) * 2048) + (((int)threadIdx.z) * 256))]))[((((((local_id % 4) / 2) * 8) + (threadIdx.x / 4)) * 16) + ((((local_id / 4) * 8) + ((threadIdx.x % 4) * 2)) + (local_id % 2)))] = C_reindex_shared_warp[(ax0 * 8) + local_id];
}
;
    __syncthreads();
    #pragma unroll
    for (int ax0_ax1_ax2_ax3_ax4_fused_0_2 = 0; ax0_ax1_ax2_ax3_ax4_fused_0_2 < 2; ++ax0_ax1_ax2_ax3_ax4_fused_0_2) {
      float4 __1;
      int4 v_ = *(int4*)(((int*)A_reindex_reindex_shared) + ((((((int)threadIdx.y) * 2048) + (((int)threadIdx.z) * 256)) + (ax0_ax1_ax2_ax3_ax4_fused_0_2 * 128)) + (((int)threadIdx.x) * 4)));
      __1.x = (float)(v_.x);
      __1.y = (float)(v_.y);
      __1.z = (float)(v_.z);
      __1.w = (float)(v_.w);
      *(float4*)(D + ((((((((((int)blockIdx.y) * 131072) + (((int)threadIdx.y) * 65536)) + (ax0 * 16384)) + (ax0_ax1_ax2_ax3_ax4_fused_0_2 * 8192)) + ((((int)threadIdx.x) >> 2) * 1024)) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.z) * 16)) + ((((int)threadIdx.x) & 3) * 4))) = __1;
    }
  }
}
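
The decode step in the kernel above (the loops that fill `B_decode_reindex_local`) can be read as the following pure-Python sketch: each byte holds two 4-bit values, low nibble first, and 8 is subtracted to obtain a signed value that is then fed to the s8 `mma` instruction. The `encode_int4` helper and the offset-by-8 packing convention are assumptions inferred from that decode expression, added only so the example round-trips.

```python
def decode_int4(packed):
    """Unpack two 4-bit values per byte (low nibble first) and subtract 8."""
    return [((packed[i >> 1] >> ((i & 1) * 4)) & 0xF) - 8
            for i in range(2 * len(packed))]


def encode_int4(values):
    """Hypothetical inverse for this sketch: store v + 8 in each nibble."""
    assert len(values) % 2 == 0 and all(-8 <= v <= 7 for v in values)
    return bytes((values[i] + 8) | ((values[i + 1] + 8) << 4)
                 for i in range(0, len(values), 2))


weights = [-8, -1, 0, 7, 3, -4]
packed = encode_int4(weights)           # 6 int4 values stored in 3 bytes
print(list(packed), decode_int4(packed))
```

The decoded values stay in the int4 range [-8, 7]; they are only widened to int8 storage, not multiplied by 16.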

brisker commented 2 months ago

@LeiWang1999 Your answer is a little complex.

I just want to know whether your w4a8 pipeline simply multiplies the w4 values by 16 to convert them to w8 and then performs a w8a8 GEMM.

LeiWang1999 commented 2 months ago

@brisker, nope. When W4 is converted to W8, it simply keeps the W4 precision. Appending a shift rescaling in BitBLAS would be straightforward, but we have not run into this situation so far, so it is not implemented yet.

brisker commented 2 months ago

@LeiWang1999

  1. So do you mean that your w4a8 kernel simply performs a GEMM between int4 values in [-7, 7] and int8 values in [-127, 127]?

  2. Will that be faster than a pipeline that multiplies w4 by 16 to get w8, performs a w8a8 GEMM, and then divides the result by 16?

LeiWang1999 commented 2 months ago

For the first question, definitely. Regarding the second question, in my experience there is no difference between the two implementations if the rescaling is done within the kernel. However, we do support int8xint2 for the 1.58-bit model, which uses a floating-point rescaling factor rather than a fixed x16 scaling, so we did not implement this case.
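
The two variants discussed in this exchange give the same integer result, which the following pure-Python sketch checks on toy values (the numbers are arbitrary and the code says nothing about kernel speed). It compares using the int4 values directly against multiplying them by 16, running the dot product, and dividing by 16.

```python
def dot(a, w):
    """Integer dot product, standing in for one output element of the GEMM."""
    return sum(x * y for x, y in zip(a, w))


a = [127, -128, 5, -17]        # int8 activations
w4 = [7, -8, -3, 2]            # int4 weights in [-8, 7]

direct = dot(a, w4)            # int4 values used as-is in int8 lanes
w8 = [v * 16 for v in w4]      # "multiply by 16" variant, still within int8
shifted = dot(a, w8) // 16     # exact, since every product is a multiple of 16
print(direct, shifted)
```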

brisker commented 1 month ago

@LeiWang1999

Regarding the first question, I do not quite understand. Many GEMMs on GPUs require the weights and activations to have the same dtype, so many w4a8 pipelines in other frameworks (for example, this one) first convert w4 to w8 and then run a w8a8 GEMM to implement the w4a8 pipeline.

So why can BitBLAS do a w4a8 GEMM directly, without the w4->w8 conversion?

Does that mean your w4a8 will be a lot faster than others (like this one)?