Closed ruofan-wu closed 5 months ago
We currently do not have an op implementation of int8xint8 with scaling, because int8 kernels usually do not apply scaling to the weights. See, for example, the BitNet forward implementation:
Hi @LeiWang1999,
How does the int8xint8 matmul dequantize? Could you give a PyTorch example similar to the one below? Thanks very much!
rescaling_tensor = torch.zeros_like(weight_tensor, dtype=torch.float16).cuda()
# Compute reference result with manual scaling and zero-point adjustment
# rescale = (weight - zeros) * scaling
for i in range(in_features // group_size):
    for j in range(group_size):
        rescaling_tensor[:, i * group_size + j] = (
            weight_tensor[:, i * group_size + j].to(torch.float16) - zeros[:, i]
        ) * scaling[:, i]
ref_result = torch.matmul(input_tensor, rescaling_tensor.t().to(torch.float16))
# Assert that the results are close within a specified tolerance
print("Ref output:", ref_result)
print("BitBLAS output:", output_tensor)
torch.testing.assert_close(output_tensor, ref_result, rtol=1e-2, atol=1e-2)
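For reference, the group-wise rescaling loop above can also be written without Python loops. This is a minimal sketch, assuming weight_tensor has shape [out_features, in_features] and that zeros and scaling have shape [out_features, in_features // group_size], as the indexing in the loop implies. The small sizes and random data are placeholders, not values from the original script.

```python
import torch

torch.manual_seed(0)
out_features, in_features, group_size = 8, 32, 16  # placeholder sizes
n_groups = in_features // group_size

# Placeholder inputs with the shapes implied by the loop above.
weight_tensor = torch.randint(0, 16, (out_features, in_features), dtype=torch.int8)
zeros = torch.randint(0, 16, (out_features, n_groups)).to(torch.float16)
scaling = torch.rand(out_features, n_groups).to(torch.float16)

# Expand each per-group value across the group_size columns it covers.
zeros_full = zeros.repeat_interleave(group_size, dim=1)
scaling_full = scaling.repeat_interleave(group_size, dim=1)

# rescale = (weight - zeros) * scaling, computed for all columns at once
rescaling_tensor = (weight_tensor.to(torch.float16) - zeros_full) * scaling_full
```

The result should match the double loop element for element, since the same subtraction and multiplication are applied to each column.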
Hi @GisellWu, when both A_dtype and W_dtype are set to int8, no dequantization is involved; the kernel simply computes C[i, j] = Accum(A[i, k] * B[j, k], 'accum_dtype').astype('out_dtype').
In your case, the accum_dtype is int32 and the out_dtype is float16, so the accumulated result is simply cast rather than scaled. We currently do not support applying scaling to the output inside the kernel, so if you need scaling, we recommend implementing it with a simple PyTorch expression.
Since PyTorch doesn't support int8 matmul, you can simulate it in float32:
ref_tensor = torch.matmul(input_tensor.to(torch.float32), weight_tensor.to(torch.float32).T)
fp16_output_tensor = ref_tensor.to(torch.float16)
This approach allows you to simulate the operation without native int8 support.
Hi @LeiWang1999,
Thanks for your explanation. I'd like to double-check that this implementation is specifically for pure int8 matmul without the quantization and dequantization processes, right?
yeah, absolutely.
@LeiWang1999 In the w4a8 setting in this repo, is per-token dynamic int8 activation quantization supported?
Hi @brisker, we currently do not support scaling for int GEMM within the kernel; instead, we handle this by rescaling the output tensor, as shown here: https://github1s.com/microsoft/BitBLAS/blob/main/integration/BitNet/utils_quant.py. However, it is straightforward to integrate the rescale directly into the kernel by modifying the expression in: bitblas/ops/general_matmul/tirscript
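To make "rescaling the output tensor" concrete, here is a minimal PyTorch sketch of per-token dynamic activation quantization followed by an integer GEMM and an output rescale. It is an illustration only, not the code in utils_quant.py: the function names, the symmetric absmax scheme, and the per-channel weight scale w_scale are all assumptions, and the integer GEMM is simulated in float32 as suggested earlier in this thread.

```python
import torch

def quantize_per_token(x):
    """Symmetric per-token (per-row) int8 quantization, so that x ~= q * scale."""
    # One scale per row, chosen so the largest magnitude maps to 127.
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-5) / 127.0
    q = torch.round(x / scale).clamp(-127, 127).to(torch.int8)
    return q, scale

def int_gemm_then_rescale(x, w_int, w_scale):
    """x: [M, K] float, w_int: [N, K] int8 ("nt" layout), w_scale: [N] float."""
    q, a_scale = quantize_per_token(x)
    # Stand-in for the integer kernel: accumulate the raw integer products.
    acc = torch.matmul(q.to(torch.float32), w_int.to(torch.float32).T)
    # Rescale outside the kernel: per-token scale [M, 1], per-channel scale [N].
    return acc * a_scale * w_scale

torch.manual_seed(0)
x = torch.randn(4, 64)
w_int = torch.randint(-8, 8, (16, 64), dtype=torch.int8)  # int4-range weights
w_scale = torch.rand(16) * 0.05 + 0.01
out = int_gemm_then_rescale(x, w_int, w_scale)
```

Because both scales are constant along K, they factor out of the accumulation, which is why the rescale can be applied to the output after the integer GEMM.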
@LeiWang1999 Thanks for your prompt reply! Also, what is the basic pipeline of w4a8 in this repo?
Do you convert the int4 weights into int8 first and then perform a w8a8 GEMM (which can reuse the w8a8 GEMM kernel)?
If so, and the w4 weights are per-channel quantized with groups (meaning that for a [cin, cout] linear layer there is a scale of shape [cin/group_size, cout] for w4), how is w4a8 handled in the BitBLAS kernel?
Currently, we do not implement int4 to int8 dequantization with group-wise scaling. Are you using smooth quantization?
@LeiWang1999 I want to accelerate a w4a8-g128 (group size = 128) quantized model.
So do you mean that only per-channel w4 quantization is supported in this repo?
@brisker Yeah, we currently do not implement group-wise rescaling for int8xint4.
@LeiWang1999 For your per-channel quantized w4a8, is the w8a8 GEMM kernel reused? How is it used?
Hi @brisker, you can analyze the generated CUDA source with:
import bitblas
import torch
# enabling debug output
bitblas.set_log_level("Debug")
matmul_config = bitblas.MatmulConfig(
    M=1024,  # M dimension
    N=1024,  # N dimension
    K=1024,  # K dimension
    A_dtype="int8",  # activation A dtype
    W_dtype="int4",  # weight W dtype
    accum_dtype="int32",  # accumulation dtype
    out_dtype="float32",  # output dtype
    layout="nt",  # matrix layout: "nt" means A is non-transposed and W is transposed
    with_bias=False,  # bias
    # configs for weight-only quantization
    group_size=None,  # setting for grouped quantization
    with_scaling=False,  # setting for scaling factor
    with_zeros=False,  # setting for zeros
    zeros_mode=None,  # setting for how zeros are calculated
    fast_decoding=False,  # setting for fast decoding
)
matmul = bitblas.Matmul(config=matmul_config)
print(matmul.get_source())
extern "C" __global__ void __launch_bounds__(128) main_kernel(signed char* __restrict__ A, signed char* __restrict__ B, float* __restrict__ D) {
int C_reindex_shared_warp[32];
__shared__ signed char A_reindex_reindex_shared[32768];
__shared__ signed char B_shared[4096];
__shared__ signed char B_decode_reindex_shared[4096];
signed char B_local[8];
signed char B_decode_reindex_local[16];
signed char A_reindex_reindex_shared_warp[64];
signed char B_decode_reindex_shared_warp[16];
signed char B_local_1[8];
signed char B_decode_reindex_local_1[16];
signed char A_reindex_reindex_shared_warp_1[64];
signed char B_decode_reindex_shared_warp_1[16];
for (int ax1_0_3_init = 0; ax1_0_3_init < 4; ++ax1_0_3_init) {
for (int i = 0; i < 8; ++i) {
C_reindex_shared_warp[(ax1_0_3_init * 8) + i] = 0.0;}
;
}
#pragma unroll
for (int ax0_ax1_ax2_ax3_ax4_fused_0 = 0; ax0_ax1_ax2_ax3_ax4_fused_0 < 8; ++ax0_ax1_ax2_ax3_ax4_fused_0) {
{
unsigned int addr;
#if TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST
addr = static_cast<unsigned int>(__cvta_generic_to_shared((void *)(A_reindex_reindex_shared + ((((ax0_ax1_ax2_ax3_ax4_fused_0 * 2048) + (((int)threadIdx.y) * 1024)) + (((int)threadIdx.z) * 512)) + (((int)threadIdx.x) * 16)))));
#else
__asm__ __volatile__(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
: "=r"(addr)
: "l"((void *)(A_reindex_reindex_shared + ((((ax0_ax1_ax2_ax3_ax4_fused_0 * 2048) + (((int)threadIdx.y) * 1024)) + (((int)threadIdx.z) * 512)) + (((int)threadIdx.x) * 16))))
);
#endif
__asm__ __volatile__(
#if TVM_ENABLE_L2_PREFETCH
"cp.async.cg.shared.global.L2::128B [%0], [%1], %2;"
#else
"cp.async.cg.shared.global [%0], [%1], %2;"
#endif
:: "r"(addr), "l"((void*)(A + (((((((int)blockIdx.y) * 131072) + (ax0_ax1_ax2_ax3_ax4_fused_0 * 16384)) + (((int)threadIdx.y) * 1024)) + (((int)threadIdx.z) * 512)) + (((int)threadIdx.x) * 16)))), "n"(16)
);
}
}
#pragma unroll
for (int ax0_ax1_ax2_ax3_fused_0_0_0_0 = 0; ax0_ax1_ax2_ax3_fused_0_0_0_0 < 1; ++ax0_ax1_ax2_ax3_fused_0_0_0_0) {
{
unsigned int addr;
#if TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST
addr = static_cast<unsigned int>(__cvta_generic_to_shared((void *)(B_shared + (((((int)threadIdx.z) * 1024) + (((int)threadIdx.y) * 512)) + (((int)threadIdx.x) * 16)))));
#else
__asm__ __volatile__(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
: "=r"(addr)
: "l"((void *)(B_shared + (((((int)threadIdx.z) * 1024) + (((int)threadIdx.y) * 512)) + (((int)threadIdx.x) * 16))))
);
#endif
__asm__ __volatile__(
#if TVM_ENABLE_L2_PREFETCH
"cp.async.cg.shared.global.L2::128B [%0], [%1], %2;"
#else
"cp.async.cg.shared.global [%0], [%1], %2;"
#endif
:: "r"(addr), "l"((void*)(B + ((((((int)blockIdx.x) * 16384) + (((int)threadIdx.z) * 8192)) + (((int)threadIdx.y) * 512)) + (((int)threadIdx.x) * 16)))), "n"(16)
);
}
}
__asm__ __volatile__("cp.async.commit_group;");
for (int ax3_0_0 = 0; ax3_0_0 < 7; ++ax3_0_0) {
__syncthreads();
#pragma unroll
for (int ax0_ax1_ax2_ax3_ax4_fused_0_1 = 0; ax0_ax1_ax2_ax3_ax4_fused_0_1 < 8; ++ax0_ax1_ax2_ax3_ax4_fused_0_1) {
{
unsigned int addr;
#if TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST
addr = static_cast<unsigned int>(__cvta_generic_to_shared((void *)(A_reindex_reindex_shared + (((((((ax3_0_0 + 1) & 1) * 16384) + (ax0_ax1_ax2_ax3_ax4_fused_0_1 * 2048)) + (((int)threadIdx.y) * 1024)) + (((int)threadIdx.z) * 512)) + (((int)threadIdx.x) * 16)))));
#else
__asm__ __volatile__(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
: "=r"(addr)
: "l"((void *)(A_reindex_reindex_shared + (((((((ax3_0_0 + 1) & 1) * 16384) + (ax0_ax1_ax2_ax3_ax4_fused_0_1 * 2048)) + (((int)threadIdx.y) * 1024)) + (((int)threadIdx.z) * 512)) + (((int)threadIdx.x) * 16))))
);
#endif
__asm__ __volatile__(
#if TVM_ENABLE_L2_PREFETCH
"cp.async.cg.shared.global.L2::128B [%0], [%1], %2;"
#else
"cp.async.cg.shared.global [%0], [%1], %2;"
#endif
:: "r"(addr), "l"((void*)(A + (((((((((int)blockIdx.y) * 131072) + (ax0_ax1_ax2_ax3_ax4_fused_0_1 * 16384)) + (ax3_0_0 * 2048)) + (((int)threadIdx.y) * 1024)) + (((int)threadIdx.z) * 512)) + (((int)threadIdx.x) * 16)) + 2048))), "n"(16)
);
}
}
#pragma unroll
for (int ax0_ax1_ax2_ax3_fused_0_0_0_0_1 = 0; ax0_ax1_ax2_ax3_fused_0_0_0_0_1 < 1; ++ax0_ax1_ax2_ax3_fused_0_0_0_0_1) {
{
unsigned int addr;
#if TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST
addr = static_cast<unsigned int>(__cvta_generic_to_shared((void *)(B_shared + ((((((ax3_0_0 + 1) & 1) * 2048) + (((int)threadIdx.z) * 1024)) + (((int)threadIdx.y) * 512)) + (((int)threadIdx.x) * 16)))));
#else
__asm__ __volatile__(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
: "=r"(addr)
: "l"((void *)(B_shared + ((((((ax3_0_0 + 1) & 1) * 2048) + (((int)threadIdx.z) * 1024)) + (((int)threadIdx.y) * 512)) + (((int)threadIdx.x) * 16))))
);
#endif
__asm__ __volatile__(
#if TVM_ENABLE_L2_PREFETCH
"cp.async.cg.shared.global.L2::128B [%0], [%1], %2;"
#else
"cp.async.cg.shared.global [%0], [%1], %2;"
#endif
:: "r"(addr), "l"((void*)(B + ((((((((int)blockIdx.x) * 16384) + (((int)threadIdx.z) * 8192)) + (ax3_0_0 * 1024)) + (((int)threadIdx.y) * 512)) + (((int)threadIdx.x) * 16)) + 1024))), "n"(16)
);
}
}
__asm__ __volatile__("cp.async.commit_group;");
__asm__ __volatile__("cp.async.wait_group 1;");
__syncthreads();
for (int ax1_ax2_ax3_ax4_0_fused_0 = 0; ax1_ax2_ax3_ax4_0_fused_0 < 2; ++ax1_ax2_ax3_ax4_0_fused_0) {
*(int2*)(B_local + 0) = *(int2*)(B_shared + ((((((ax3_0_0 & 1) * 2048) + (ax1_ax2_ax3_ax4_0_fused_0 * 1024)) + (((int)threadIdx.y) * 512)) + (((int)threadIdx.z) * 256)) + (((int)threadIdx.x) * 8)));
for (int ax4 = 0; ax4 < 16; ++ax4) {
B_decode_reindex_local[ax4] = (((signed char)((((uint)B_local[(ax4 >> 1)]) >> (((uint)(ax4 & 1)) * (uint)4)) & (uint)15)) - (signed char)8);
}
*(int4*)(B_decode_reindex_shared + ((((ax1_ax2_ax3_ax4_0_fused_0 * 2048) + (((int)threadIdx.y) * 1024)) + (((int)threadIdx.z) * 512)) + (((int)threadIdx.x) * 16))) = *(int4*)(B_decode_reindex_local + 0);
}
__syncthreads();
for (int ax3_0_1 = 0; ax3_0_1 < 4; ++ax3_0_1) {
for (int ax1 = 0; ax1 < 4; ++ax1) {
{
unsigned int addr;
#if TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST
addr = static_cast<unsigned int>(__cvta_generic_to_shared((void *)((&(A_reindex_reindex_shared[(((((ax3_0_0 & 1) * 16384) + (((int)threadIdx.y) * 8192)) + (ax1 * 2048)) + (ax3_0_1 * 512))])) + (((int)threadIdx.x) * 16))));
#else
__asm__ __volatile__(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
: "=r"(addr)
: "l"((void *)((&(A_reindex_reindex_shared[(((((ax3_0_0 & 1) * 16384) + (((int)threadIdx.y) * 8192)) + (ax1 * 2048)) + (ax3_0_1 * 512))])) + (((int)threadIdx.x) * 16)))
);
#endif
__asm__ __volatile__(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
"{%0, %1, %2, %3}, [%4];\n"
: "=r"(((unsigned *)(A_reindex_reindex_shared_warp + (ax1 * 16)))[0]), "=r"(((unsigned *)(A_reindex_reindex_shared_warp + (ax1 * 16)))[1]), "=r"(((unsigned *)(A_reindex_reindex_shared_warp + (ax1 * 16)))[2]), "=r"(((unsigned *)(A_reindex_reindex_shared_warp + (ax1 * 16)))[3])
: "r"(addr)
);
}
}
{
unsigned int addr;
#if TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST
addr = static_cast<unsigned int>(__cvta_generic_to_shared((void *)((&(B_decode_reindex_shared[((((int)threadIdx.z) * 2048) + (ax3_0_1 * 512))])) + (((int)threadIdx.x) * 16))));
#else
__asm__ __volatile__(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
: "=r"(addr)
: "l"((void *)((&(B_decode_reindex_shared[((((int)threadIdx.z) * 2048) + (ax3_0_1 * 512))])) + (((int)threadIdx.x) * 16)))
);
#endif
__asm__ __volatile__(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
"{%0, %1, %2, %3}, [%4];\n"
: "=r"(((unsigned *)(B_decode_reindex_shared_warp + 0))[0]), "=r"(((unsigned *)(B_decode_reindex_shared_warp + 0))[1]), "=r"(((unsigned *)(B_decode_reindex_shared_warp + 0))[2]), "=r"(((unsigned *)(B_decode_reindex_shared_warp + 0))[3])
: "r"(addr)
);
}
for (int ax1_0_3 = 0; ax1_0_3 < 4; ++ax1_0_3) {
{
__asm__ __volatile__(
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
: "=r"(((int *)(C_reindex_shared_warp + (ax1_0_3 * 8)))[0]), "=r"(((int *)(C_reindex_shared_warp + (ax1_0_3 * 8)))[1]), "=r"(((int *)(C_reindex_shared_warp + (ax1_0_3 * 8)))[2]), "=r"(((int *)(C_reindex_shared_warp + (ax1_0_3 * 8)))[3])
: "r"(((unsigned *)(A_reindex_reindex_shared_warp + (ax1_0_3 * 16)))[0]), "r"(((unsigned *)(A_reindex_reindex_shared_warp + (ax1_0_3 * 16)))[1]), "r"(((unsigned *)(A_reindex_reindex_shared_warp + (ax1_0_3 * 16)))[2]), "r"(((unsigned *)(A_reindex_reindex_shared_warp + (ax1_0_3 * 16)))[3]), "r"(((unsigned *)(B_decode_reindex_shared_warp + 0))[0]), "r"(((unsigned *)(B_decode_reindex_shared_warp + 0))[1]), "r"(((int *)(C_reindex_shared_warp + (ax1_0_3 * 8)))[0]), "r"(((int *)(C_reindex_shared_warp + (ax1_0_3 * 8)))[1]), "r"(((int *)(C_reindex_shared_warp + (ax1_0_3 * 8)))[2]), "r"(((int *)(C_reindex_shared_warp + (ax1_0_3 * 8)))[3]));
}
{
__asm__ __volatile__(
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
: "=r"(((int *)(C_reindex_shared_warp + ((ax1_0_3 * 8) + 4)))[0]), "=r"(((int *)(C_reindex_shared_warp + ((ax1_0_3 * 8) + 4)))[1]), "=r"(((int *)(C_reindex_shared_warp + ((ax1_0_3 * 8) + 4)))[2]), "=r"(((int *)(C_reindex_shared_warp + ((ax1_0_3 * 8) + 4)))[3])
: "r"(((unsigned *)(A_reindex_reindex_shared_warp + (ax1_0_3 * 16)))[0]), "r"(((unsigned *)(A_reindex_reindex_shared_warp + (ax1_0_3 * 16)))[1]), "r"(((unsigned *)(A_reindex_reindex_shared_warp + (ax1_0_3 * 16)))[2]), "r"(((unsigned *)(A_reindex_reindex_shared_warp + (ax1_0_3 * 16)))[3]), "r"(((unsigned *)(B_decode_reindex_shared_warp + 8))[0]), "r"(((unsigned *)(B_decode_reindex_shared_warp + 8))[1]), "r"(((int *)(C_reindex_shared_warp + ((ax1_0_3 * 8) + 4)))[0]), "r"(((int *)(C_reindex_shared_warp + ((ax1_0_3 * 8) + 4)))[1]), "r"(((int *)(C_reindex_shared_warp + ((ax1_0_3 * 8) + 4)))[2]), "r"(((int *)(C_reindex_shared_warp + ((ax1_0_3 * 8) + 4)))[3]));
}
}
}
}
__asm__ __volatile__("cp.async.wait_group 0;");
__syncthreads();
for (int ax1_ax2_ax3_ax4_0_fused_0_1 = 0; ax1_ax2_ax3_ax4_0_fused_0_1 < 2; ++ax1_ax2_ax3_ax4_0_fused_0_1) {
*(int2*)(B_local_1 + 0) = *(int2*)(B_shared + (((((ax1_ax2_ax3_ax4_0_fused_0_1 * 1024) + (((int)threadIdx.y) * 512)) + (((int)threadIdx.z) * 256)) + (((int)threadIdx.x) * 8)) + 2048));
for (int ax4_1 = 0; ax4_1 < 16; ++ax4_1) {
B_decode_reindex_local_1[ax4_1] = (((signed char)((((uint)B_local_1[(ax4_1 >> 1)]) >> (((uint)(ax4_1 & 1)) * (uint)4)) & (uint)15)) - (signed char)8);
}
*(int4*)(B_decode_reindex_shared + ((((ax1_ax2_ax3_ax4_0_fused_0_1 * 2048) + (((int)threadIdx.y) * 1024)) + (((int)threadIdx.z) * 512)) + (((int)threadIdx.x) * 16))) = *(int4*)(B_decode_reindex_local_1 + 0);
}
__syncthreads();
for (int ax3_0_1_1 = 0; ax3_0_1_1 < 4; ++ax3_0_1_1) {
for (int ax1_1 = 0; ax1_1 < 4; ++ax1_1) {
{
unsigned int addr;
#if TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST
addr = static_cast<unsigned int>(__cvta_generic_to_shared((void *)((&(A_reindex_reindex_shared[((((((int)threadIdx.y) * 8192) + (ax1_1 * 2048)) + (ax3_0_1_1 * 512)) + 16384)])) + (((int)threadIdx.x) * 16))));
#else
__asm__ __volatile__(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
: "=r"(addr)
: "l"((void *)((&(A_reindex_reindex_shared[((((((int)threadIdx.y) * 8192) + (ax1_1 * 2048)) + (ax3_0_1_1 * 512)) + 16384)])) + (((int)threadIdx.x) * 16)))
);
#endif
__asm__ __volatile__(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
"{%0, %1, %2, %3}, [%4];\n"
: "=r"(((unsigned *)(A_reindex_reindex_shared_warp_1 + (ax1_1 * 16)))[0]), "=r"(((unsigned *)(A_reindex_reindex_shared_warp_1 + (ax1_1 * 16)))[1]), "=r"(((unsigned *)(A_reindex_reindex_shared_warp_1 + (ax1_1 * 16)))[2]), "=r"(((unsigned *)(A_reindex_reindex_shared_warp_1 + (ax1_1 * 16)))[3])
: "r"(addr)
);
}
}
{
unsigned int addr;
#if TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST
addr = static_cast<unsigned int>(__cvta_generic_to_shared((void *)((&(B_decode_reindex_shared[((((int)threadIdx.z) * 2048) + (ax3_0_1_1 * 512))])) + (((int)threadIdx.x) * 16))));
#else
__asm__ __volatile__(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
: "=r"(addr)
: "l"((void *)((&(B_decode_reindex_shared[((((int)threadIdx.z) * 2048) + (ax3_0_1_1 * 512))])) + (((int)threadIdx.x) * 16)))
);
#endif
__asm__ __volatile__(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
"{%0, %1, %2, %3}, [%4];\n"
: "=r"(((unsigned *)(B_decode_reindex_shared_warp_1 + 0))[0]), "=r"(((unsigned *)(B_decode_reindex_shared_warp_1 + 0))[1]), "=r"(((unsigned *)(B_decode_reindex_shared_warp_1 + 0))[2]), "=r"(((unsigned *)(B_decode_reindex_shared_warp_1 + 0))[3])
: "r"(addr)
);
}
for (int ax1_0_3_1 = 0; ax1_0_3_1 < 4; ++ax1_0_3_1) {
{
__asm__ __volatile__(
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
: "=r"(((int *)(C_reindex_shared_warp + (ax1_0_3_1 * 8)))[0]), "=r"(((int *)(C_reindex_shared_warp + (ax1_0_3_1 * 8)))[1]), "=r"(((int *)(C_reindex_shared_warp + (ax1_0_3_1 * 8)))[2]), "=r"(((int *)(C_reindex_shared_warp + (ax1_0_3_1 * 8)))[3])
: "r"(((unsigned *)(A_reindex_reindex_shared_warp_1 + (ax1_0_3_1 * 16)))[0]), "r"(((unsigned *)(A_reindex_reindex_shared_warp_1 + (ax1_0_3_1 * 16)))[1]), "r"(((unsigned *)(A_reindex_reindex_shared_warp_1 + (ax1_0_3_1 * 16)))[2]), "r"(((unsigned *)(A_reindex_reindex_shared_warp_1 + (ax1_0_3_1 * 16)))[3]), "r"(((unsigned *)(B_decode_reindex_shared_warp_1 + 0))[0]), "r"(((unsigned *)(B_decode_reindex_shared_warp_1 + 0))[1]), "r"(((int *)(C_reindex_shared_warp + (ax1_0_3_1 * 8)))[0]), "r"(((int *)(C_reindex_shared_warp + (ax1_0_3_1 * 8)))[1]), "r"(((int *)(C_reindex_shared_warp + (ax1_0_3_1 * 8)))[2]), "r"(((int *)(C_reindex_shared_warp + (ax1_0_3_1 * 8)))[3]));
}
{
__asm__ __volatile__(
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
: "=r"(((int *)(C_reindex_shared_warp + ((ax1_0_3_1 * 8) + 4)))[0]), "=r"(((int *)(C_reindex_shared_warp + ((ax1_0_3_1 * 8) + 4)))[1]), "=r"(((int *)(C_reindex_shared_warp + ((ax1_0_3_1 * 8) + 4)))[2]), "=r"(((int *)(C_reindex_shared_warp + ((ax1_0_3_1 * 8) + 4)))[3])
: "r"(((unsigned *)(A_reindex_reindex_shared_warp_1 + (ax1_0_3_1 * 16)))[0]), "r"(((unsigned *)(A_reindex_reindex_shared_warp_1 + (ax1_0_3_1 * 16)))[1]), "r"(((unsigned *)(A_reindex_reindex_shared_warp_1 + (ax1_0_3_1 * 16)))[2]), "r"(((unsigned *)(A_reindex_reindex_shared_warp_1 + (ax1_0_3_1 * 16)))[3]), "r"(((unsigned *)(B_decode_reindex_shared_warp_1 + 8))[0]), "r"(((unsigned *)(B_decode_reindex_shared_warp_1 + 8))[1]), "r"(((int *)(C_reindex_shared_warp + ((ax1_0_3_1 * 8) + 4)))[0]), "r"(((int *)(C_reindex_shared_warp + ((ax1_0_3_1 * 8) + 4)))[1]), "r"(((int *)(C_reindex_shared_warp + ((ax1_0_3_1 * 8) + 4)))[2]), "r"(((int *)(C_reindex_shared_warp + ((ax1_0_3_1 * 8) + 4)))[3]));
}
}
}
for (int ax0 = 0; ax0 < 4; ++ax0) {
__syncthreads();
for (int local_id = 0; local_id < 8; ++local_id) {
(&(((int*)A_reindex_reindex_shared)[((((int)threadIdx.y) * 2048) + (((int)threadIdx.z) * 256))]))[((((((local_id % 4) / 2) * 8) + (threadIdx.x / 4)) * 16) + ((((local_id / 4) * 8) + ((threadIdx.x % 4) * 2)) + (local_id % 2)))] = C_reindex_shared_warp[(ax0 * 8) + local_id];
}
;
__syncthreads();
#pragma unroll
for (int ax0_ax1_ax2_ax3_ax4_fused_0_2 = 0; ax0_ax1_ax2_ax3_ax4_fused_0_2 < 2; ++ax0_ax1_ax2_ax3_ax4_fused_0_2) {
float4 __1;
int4 v_ = *(int4*)(((int*)A_reindex_reindex_shared) + ((((((int)threadIdx.y) * 2048) + (((int)threadIdx.z) * 256)) + (ax0_ax1_ax2_ax3_ax4_fused_0_2 * 128)) + (((int)threadIdx.x) * 4)));
__1.x = (float)(v_.x);
__1.y = (float)(v_.y);
__1.z = (float)(v_.z);
__1.w = (float)(v_.w);
*(float4*)(D + ((((((((((int)blockIdx.y) * 131072) + (((int)threadIdx.y) * 65536)) + (ax0 * 16384)) + (ax0_ax1_ax2_ax3_ax4_fused_0_2 * 8192)) + ((((int)threadIdx.x) >> 2) * 1024)) + (((int)blockIdx.x) * 32)) + (((int)threadIdx.z) * 16)) + ((((int)threadIdx.x) & 3) * 4))) = __1;
}
}
}
@LeiWang1999 Your answer is a little complex.
I just want to know whether your w4a8 pipeline simply multiplies w4 by 16 to convert it into w8 and then performs a w8a8 GEMM.
@brisker, nope. When converting W4 to W8, it simply keeps the W4 precision (the values are not multiplied by 16). Appending a shift rescaling in BitBLAS would be straightforward, but we have not run into this situation so far, so it is not implemented yet.
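The decode step in the kernel source above (the loop that fills B_decode_reindex_local) can be mirrored in plain Python. This is a sketch of the arithmetic in that one expression only: each storage byte holds two 4-bit values, low nibble first, and 8 is subtracted to obtain a signed value. The function name is hypothetical.

```python
def decode_int4_pairs(packed):
    """Unpack bytes holding two 4-bit values each into signed ints in [-8, 7]."""
    out = []
    for i in range(2 * len(packed)):
        byte = packed[i >> 1] & 0xFF           # treat the storage byte as unsigned
        nibble = (byte >> ((i & 1) * 4)) & 15  # even i: low nibble, odd i: high nibble
        out.append(nibble - 8)                 # offset-8 encoding -> signed value
    return out

print(decode_int4_pairs(bytes([0x0F, 0x80])))  # [7, -8, -8, 0]
```

The decoded values stay in the int4 range and are only widened to an 8-bit container, which is what lets the kernel feed them to the s8 x s8 mma instruction without any x16 shift.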
@LeiWang1999
So do you mean that in your w4a8 kernel, the computation is just a GEMM between int4 numbers in [-7,7] and int8 numbers in [-127,127]?
Would that be faster than a pipeline that multiplies w4 by 16 to get w8, does a w8a8 GEMM, and divides by 16 afterwards?
For the first question, definitely. Regarding the second question, in my experience there is no difference between the two implementations if the rescaling is done within the kernel. However, we do support int8xint2 for the 1.58-bit model, which uses a floating-point rescaling factor rather than a fixed x16 scaling, so we didn't implement this case.
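The two pipelines being compared give the same integer result, which a small pure-Python check can illustrate. This is a sketch of the arithmetic only, with random placeholder data; it says nothing about kernel performance.

```python
import random

random.seed(0)
K = 32
a = [random.randint(-127, 127) for _ in range(K)]  # int8 activations
w4 = [random.randint(-8, 7) for _ in range(K)]     # int4 weights

# Pipeline 1: keep the int4 values as they are and accumulate.
acc_direct = sum(x * w for x, w in zip(a, w4))

# Pipeline 2: multiply weights by 16, accumulate, then divide by 16 (shift right by 4).
w8 = [w * 16 for w in w4]  # still fits in int8: [-128, 112]
acc_shifted = sum(x * w for x, w in zip(a, w8)) >> 4

assert acc_direct == acc_shifted
```

The x16 factor is constant along K, so it factors out of the sum exactly. One difference is that the shifted variant uses four more bits of accumulator range for the same inputs.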
@LeiWang1999
Regarding the first question, I don't quite understand. Many GPU GEMMs require the weight and activation to have the same dtype, so many w4a8 pipelines in other frameworks (for example, this one) first convert w4 into w8 and then do a w8a8 GEMM to realize the w4a8 pipeline.
So why can BitBLAS do the w4a8 GEMM directly, without the w4->w8 conversion?
Does that mean your w4a8 will be a lot faster than others (like this one)?
Hi, thanks for the great work!
Does BitBLAS currently support int8xint8 matmul with scaling? I tried the following code; it produced some config results but eventually hit a segmentation fault.
The outputs are: