microsoft / BitBLAS

BitBLAS is a library to support mixed-precision matrix multiplications, especially for quantized LLM deployment.
MIT License

Is int1 x float16 supported? #35

Closed. chromecast56 closed this issue 1 month ago

chromecast56 commented 1 month ago

Thanks for the great project - I was wondering if the repo supports int1 x float16 matmul, and if so, how I should go about it? For reference, I'm trying to strengthen the kernel in this repo.

My attempt:

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import bitblas
import torch

os.environ["PATH"] = os.environ["PATH"]+":/usr/local/cuda/bin/"

bitblas.set_log_level("DEBUG")
matmul_config = bitblas.MatmulConfig(
    M=1,  # M dimension
    N=1024,  # N dimension
    K=1024,  # K dimension
    A_dtype="float16",  # activation A dtype
    W_dtype="int1",  # weight W dtype
    accum_dtype="float16",  # accumulation dtype
    out_dtype="float16",  # output dtype
    layout="nt",  # matrix layout, "nt" indicates the layout of A is non-transpose and the layout of W is transpose
    with_bias=False,  # bias
    # configs for weight only quantization
    group_size=None,  # setting for grouped quantization
    with_scaling=False,  # setting for scaling factor
    with_zeros=False,  # setting for zeros
    zeros_mode=None,  # setting for how to calculating zeros
)

matmul = bitblas.Matmul(config=matmul_config)

# Create input matrices
input_tensor = torch.rand((1, 1024), dtype=torch.float16).cuda()
weight_tensor = torch.randint(0, 7, (1024, 1024), dtype=torch.int8).cuda()

# Pack the weight tensor into BitBLAS's int1 storage format
weight_tensor_int4 = matmul.transform_weight(weight_tensor)

# Perform mixed-precision matrix multiplication
output_tensor = matmul(input_tensor, weight_tensor_int4)

# Reference result using PyTorch matmul for comparison
ref_result = torch.matmul(input_tensor, weight_tensor.t().to(torch.float16))
# Assert that the results are close within a specified tolerance; the accumulated values are fairly large, so we set atol to 1.0
print("Ref output:", ref_result)
print("BitBLAS output:", output_tensor)
torch.testing.assert_close(output_tensor, ref_result, rtol=1e-2, atol=1e-0)

Output:

2024-05-03 11:57:37 [BitBLAS:DEBUG]: Cannot find the appropriate index map for tensorcore
2024-05-03 11:57:37 [BitBLAS:DEBUG]: [BitBLAS][Error] applying rule <bitblas.gpu.gemv.GEMV object at 0x7ff5038880d0> failed
2024-05-03 11:57:40 [BitBLAS:DEBUG]: Cannot find the appropriate index map for tensorcore
2024-05-03 11:57:40 [BitBLAS:DEBUG]: Apply config {'block': [1], 'thread': [1], 'rstep': [1024], 'reduce_thread': [128], 'vectorize': {'A': 8, 'B_decode': 8}}
2024-05-03 11:57:40 [BitBLAS:DEBUG]: Apply config {'block': [4], 'thread': [4], 'rstep': [1024], 'reduce_thread': [32], 'vectorize': {'A': 8, 'B_decode': 8}}
2024-05-03 11:57:40 [BitBLAS:DEBUG]: Apply config {'block': [2], 'thread': [2], 'rstep': [1024], 'reduce_thread': [64], 'vectorize': {'A': 8, 'B_decode': 8}}
2024-05-03 11:57:40 [BitBLAS:DEBUG]: Apply config {'block': [8], 'thread': [8], 'rstep': [1024], 'reduce_thread': [16], 'vectorize': {'A': 8, 'B_decode': 8}}
2024-05-03 11:57:40 [BitBLAS:DEBUG]: Apply schedule failed: Traceback (most recent call last):
  3: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<void (tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)>::AssignTypedLambda<tvm::tir::{lambda(tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)#14}>(tvm::tir::{lambda(tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)#14}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  2: tvm::tir::TracedScheduleNode::Tensorize(tvm::tir::LoopRV const&, tvm::runtime::String const&, bool)
  1: tvm::tir::ConcreteScheduleNode::Tensorize(tvm::tir::LoopRV const&, tvm::runtime::String const&, bool)
  0: tvm::tir::TensorIntrin::Get(tvm::runtime::String, bool)
  File "/home/t-leiwang/mlc_workspace/BitBLAS/3rdparty/tvm/src/tir/ir/function.cc", line 151
ValueError: TensorIntrin 'lop3_fast_decode_i1_to_int8_to_f16_l8_' is not registered
2024-05-03 11:57:40 [BitBLAS:DEBUG]: Apply config {'block': [16], 'thread': [16], 'rstep': [512], 'reduce_thread': [8], 'vectorize': {'A': 4, 'B_decode': 8}}
2024-05-03 11:57:40 [BitBLAS:DEBUG]: Apply schedule failed: Traceback (most recent call last):
  3: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<void (tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)>::AssignTypedLambda<tvm::tir::{lambda(tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)#14}>(tvm::tir::{lambda(tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)#14}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  2: tvm::tir::TracedScheduleNode::Tensorize(tvm::tir::LoopRV const&, tvm::runtime::String const&, bool)
  1: tvm::tir::ConcreteScheduleNode::Tensorize(tvm::tir::LoopRV const&, tvm::runtime::String const&, bool)
  0: tvm::tir::TensorIntrin::Get(tvm::runtime::String, bool)
  File "/home/t-leiwang/mlc_workspace/BitBLAS/3rdparty/tvm/src/tir/ir/function.cc", line 151
ValueError: TensorIntrin 'lop3_fast_decode_i1_to_int8_to_f16_l8_' is not registered
2024-05-03 11:57:40 [BitBLAS:DEBUG]: Apply config {'block': [32], 'thread': [32], 'rstep': [256], 'reduce_thread': [4], 'vectorize': {'A': 2, 'B_decode': 8}}
2024-05-03 11:57:40 [BitBLAS:DEBUG]: Apply schedule failed: Traceback (most recent call last):
  3: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<void (tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)>::AssignTypedLambda<tvm::tir::{lambda(tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)#14}>(tvm::tir::{lambda(tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)#14}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  2: tvm::tir::TracedScheduleNode::Tensorize(tvm::tir::LoopRV const&, tvm::runtime::String const&, bool)
  1: tvm::tir::ConcreteScheduleNode::Tensorize(tvm::tir::LoopRV const&, tvm::runtime::String const&, bool)
  0: tvm::tir::TensorIntrin::Get(tvm::runtime::String, bool)
...
ValueError: Check failed: lanes <= 4 (8 vs. 4) : Ramp of more than 4 lanes is not allowed.
LeiWang1999 commented 1 month ago

@chromecast56 Thanks for your attention!

Yes, absolutely - we support float16 x int1/uint1, where int1 represents {-1, 1} and uint1 represents {0, 1}.

In this case, the ValueError is indeed not a real issue, as certain tile configurations may simply not be suitable for this particular shape and are skipped during tuning.
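
As a minimal sketch (an illustration, not code from the repo), weight tensors whose values actually match each encoding could be generated like this - note that the randint(0, 7, ...) range in the snippet above is carried over from the int4 example and matches neither encoding:

import torch

# Illustrative only: the tensor passed to matmul.transform_weight(...) should
# contain values that match the declared W_dtype.
int1_weight = torch.randint(0, 2, (1024, 1024), dtype=torch.int8) * 2 - 1  # values in {-1, 1}
uint1_weight = torch.randint(0, 2, (1024, 1024), dtype=torch.int8)         # values in {0, 1}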

chromecast56 commented 1 month ago

Can you elaborate on the ValueError / which tile configurations are not suitable? I also tried adapting the benchmarking script here. It works for many W_dtype values, including uint1, but for some reason W_dtype="int1" results in the same TensorIntrin 'lop3_fast_decode_i1_to_int8_to_f16_l8_scale_' is not registered ValueError. However, setting A_dtype = accum_dtype = out_dtype = 'int8' with W_dtype='int1' works.

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from bitblas.utils.target_detector import auto_detect_nvidia_target
from bitblas import Matmul, MatmulConfig

# Assign arguments to variables  
target = auto_detect_nvidia_target() 
group_size = None 
A_dtype = "float16" 
W_dtype = "int1" 
accum_dtype = "float16"
out_dtype = "float16"
layout = "nt"  # "nt" indicates the layout of A is non-transpose and the layout of W is transpose
with_bias = False   
with_scaling = True 
with_zeros = False
zeros_mode = None

test_shapes = [
    # square test
    (MatmulConfig, Matmul, (1, 16384, 16384, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    # # BLOOM-176B
    # (MatmulConfig, Matmul, (1, 43008, 14336, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    # (MatmulConfig, Matmul, (1, 14336, 14336, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    # (MatmulConfig, Matmul, (1, 57344, 14336, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    # (MatmulConfig, Matmul, (1, 14336, 57344, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    # # # OPT-65B
    # (MatmulConfig, Matmul, (1, 9216, 9216, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    # (MatmulConfig, Matmul, (1, 36864, 9216, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    # (MatmulConfig, Matmul, (1, 9216, 36864, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    # (MatmulConfig, Matmul, (1, 22016, 8192, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    # # # LLAMA-70B/65B
    # (MatmulConfig, Matmul, (1, 8192, 22016, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    # (MatmulConfig, Matmul, (1, 8192, 8192, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    # (MatmulConfig, Matmul, (1, 28672, 8192, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    # (MatmulConfig, Matmul, (1, 8192, 28672, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),

    # # square test
    # (MatmulConfig, Matmul, (16384, 16384, 16384, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    # # BLOOM-176B
    # (MatmulConfig, Matmul, (8192, 43008, 14336, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    # (MatmulConfig, Matmul, (8192, 14336, 14336, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    # (MatmulConfig, Matmul, (8192, 57344, 14336, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    # (MatmulConfig, Matmul, (8192, 14336, 57344, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    # # # OPT-65B
    # (MatmulConfig, Matmul, (8192, 9216, 9216, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    # (MatmulConfig, Matmul, (8192, 36864, 9216, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    # (MatmulConfig, Matmul, (8192, 9216, 36864, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    # (MatmulConfig, Matmul, (8192, 22016, 8192, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    # # # LLAMA-70B/65B
    # (MatmulConfig, Matmul, (8192, 8192, 22016, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    # (MatmulConfig, Matmul, (8192, 8192, 8192, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    # (MatmulConfig, Matmul, (8192, 28672, 8192, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    # (MatmulConfig, Matmul, (8192, 8192, 28672, A_dtype, W_dtype, out_dtype, accum_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode)),

]

benchmark_sets = []
benchmark_sets.extend(test_shapes)

# fmt:on

benchmark_results = {}
for config, operator, input_args in benchmark_sets:
    config = config(*input_args)
    matmul = operator(config, target=target, enable_tuning=True)
    kernel_latency = matmul.profile_latency()
    if matmul.input_transform is not None:
        kernel_latency += matmul.ladder_permutate_a.profile_latency()

    print("Time cost is: {:.3f} ms".format(kernel_latency))

    profile_config = {
        f"{operator.__name__}-{'-'.join([str(i) for i in input_args])}": {
            "BitBLAS_top20_latency": kernel_latency,
        }
    }

    benchmark_results.update(profile_config)

# Define headers for the table  
headers = [
    "PrimFunc",
    "Input Arguments",
    "BitBLAS Top20 Latency",
]

col_widths = [0, 0, 0]
for config, values in benchmark_results.items():
    args = config.split("-")
    func_name = args[0]
    input_args = "-".join(args[1:])
    col_widths[0] = max((max(len(str(headers[0])), len(func_name)) + 2), col_widths[0])
    col_widths[1] = max((max(len(str(headers[1])), len(input_args)) + 2, col_widths[1]))
    col_widths[2] = max(max(len(str(headers[2])), len(f"{values['BitBLAS_top20_latency']:.3f} ms")) + 2, col_widths[2])
    break

for i, header in enumerate(headers):
    headers[i] = header.ljust(col_widths[i])

print("".join(headers))

print("-" * sum(col_widths))

for config, values in benchmark_results.items():
    args = config.split("-")
    func_name = args[0]
    input_args = "-".join(args[1:])
    row = [
        func_name,
        input_args,
        f"{values['BitBLAS_top20_latency']:.3f} ms",
    ]
    print("".join([str(i).ljust(col_widths[j]) for j, i in enumerate(row)]) + "\n")
LeiWang1999 commented 1 month ago

I see - you set with_scaling=True, and currently we do not support int8-related computation with scaling (the with_scaling option means the weight values are rescaled; the int8 x int1/int2/... kernels apply that scaling on their outputs to rescale them to float32/float16).
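
To make the semantics concrete, here is a rough reference of what with_scaling is meant to do on the float16 path (a sketch assuming simple per-group weight rescaling; the shapes and helper name are illustrative, not the kernel's actual implementation):

import torch

# Assumed semantics for illustration: each group of `group_size` consecutive
# weights along K shares one scale, and the dequantized weight is wq * scale.
def dequant_reference(wq, scales, group_size):
    # wq: (N, K) quantized weights, scales: (N, K // group_size)
    s = scales.repeat_interleave(group_size, dim=1)  # expand scales over each group
    return wq.to(torch.float16) * s.to(torch.float16)

# For the "nt" layout the fp16 reference would then be roughly:
#   output ~= A @ dequant_reference(Wq, scales, group_size).t()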

chromecast56 commented 1 month ago

I set with_scaling=False and I get the same error:

2024-05-04 23:29:09 [BitBLAS:DEBUG]: Cannot find the appropriate index map for tensorcore
2024-05-04 23:29:09 [BitBLAS:DEBUG]: [BitBLAS][Error] applying rule <bitblas.gpu.gemv.GEMV object at 0x7f8083a66490> failed
2024-05-04 23:29:11 [BitBLAS:DEBUG]: Cannot find the appropriate index map for tensorcore
2024-05-04 23:29:11 [BitBLAS:DEBUG]: Apply config {'block': [2], 'thread': [2], 'rstep': [4096], 'reduce_thread': [64], 'vectorize': {'A': 8, 'B_decode': 8}}
2024-05-04 23:29:11 [BitBLAS:DEBUG]: Apply config {'block': [1], 'thread': [1], 'rstep': [4096], 'reduce_thread': [128], 'vectorize': {'A': 8, 'B_decode': 8}}
2024-05-04 23:29:11 [BitBLAS:DEBUG]: Apply config {'block': [8], 'thread': [8], 'rstep': [1024], 'reduce_thread': [16], 'vectorize': {'A': 8, 'B_decode': 8}}
2024-05-04 23:29:11 [BitBLAS:DEBUG]: Apply config {'block': [4], 'thread': [4], 'rstep': [2048], 'reduce_thread': [32], 'vectorize': {'A': 8, 'B_decode': 8}}
2024-05-04 23:29:11 [BitBLAS:DEBUG]: Apply schedule failed: Traceback (most recent call last):
  3: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<void (tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)>::AssignTypedLambda<tvm::tir::{lambda(tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)#14}>(tvm::tir::{lambda(tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)#14}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  2: tvm::tir::TracedScheduleNode::Tensorize(tvm::tir::LoopRV const&, tvm::runtime::String const&, bool)
  1: tvm::tir::ConcreteScheduleNode::Tensorize(tvm::tir::LoopRV const&, tvm::runtime::String const&, bool)
  0: tvm::tir::TensorIntrin::Get(tvm::runtime::String, bool)
  File "/home/t-leiwang/mlc_workspace/BitBLAS/3rdparty/tvm/src/tir/ir/function.cc", line 151
ValueError: TensorIntrin 'lop3_fast_decode_i1_to_int8_to_f16_l8_' is not registered
2024-05-04 23:29:11 [BitBLAS:DEBUG]: Apply config {'block': [16], 'thread': [16], 'rstep': [512], 'reduce_thread': [8], 'vectorize': {'A': 4, 'B_decode': 8}}
2024-05-04 23:29:11 [BitBLAS:DEBUG]: Apply schedule failed: Traceback (most recent call last):
  3: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<void (tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)>::AssignTypedLambda<tvm::tir::{lambda(tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)#14}>(tvm::tir::{lambda(tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)#14}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  2: tvm::tir::TracedScheduleNode::Tensorize(tvm::tir::LoopRV const&, tvm::runtime::String const&, bool)
  1: tvm::tir::ConcreteScheduleNode::Tensorize(tvm::tir::LoopRV const&, tvm::runtime::String const&, bool)
  0: tvm::tir::TensorIntrin::Get(tvm::runtime::String, bool)
  File "/home/t-leiwang/mlc_workspace/BitBLAS/3rdparty/tvm/src/tir/ir/function.cc", line 151
ValueError: TensorIntrin 'lop3_fast_decode_i1_to_int8_to_f16_l8_' is not registered
2024-05-04 23:29:11 [BitBLAS:DEBUG]: Apply config {'block': [64], 'thread': [64], 'rstep': [128], 'reduce_thread': [2], 'vectorize': {'B_decode': 8}}
2024-05-04 23:29:11 [BitBLAS:DEBUG]: Apply schedule failed: Traceback (most recent call last):
  3: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<void (tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)>::AssignTypedLambda<tvm::tir::{lambda(tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)#14}>(tvm::tir::{lambda(tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)#14}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  2: tvm::tir::TracedScheduleNode::Tensorize(tvm::tir::LoopRV const&, tvm::runtime::String const&, bool)
  1: tvm::tir::ConcreteScheduleNode::Tensorize(tvm::tir::LoopRV const&, tvm::runtime::String const&, bool)
  0: tvm::tir::TensorIntrin::Get(tvm::runtime::String, bool)
  File "/home/t-leiwang/mlc_workspace/BitBLAS/3rdparty/tvm/src/tir/ir/function.cc", line 151
ValueError: TensorIntrin 'lop3_fast_decode_i1_to_int8_to_f16_l8_' is not registered
2024-05-04 23:29:11 [BitBLAS:DEBUG]: Apply config {'block': [32], 'thread': [32], 'rstep': [256], 'reduce_thread': [4], 'vectorize': {'A': 2, 'B_decode': 8}}
2024-05-04 23:29:11 [BitBLAS:DEBUG]: Apply schedule failed: Traceback (most recent call last):
  3: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<void (tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)>::AssignTypedLambda<tvm::tir::{lambda(tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)#14}>(tvm::tir::{lambda(tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)#14}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  2: tvm::tir::TracedScheduleNode::Tensorize(tvm::tir::LoopRV const&, tvm::runtime::String const&, bool)
  1: tvm::tir::ConcreteScheduleNode::Tensorize(tvm::tir::LoopRV const&, tvm::runtime::String const&, bool)
  0: tvm::tir::TensorIntrin::Get(tvm::runtime::String, bool)
  File "/home/t-leiwang/mlc_workspace/BitBLAS/3rdparty/tvm/src/tir/ir/function.cc", line 151
ValueError: TensorIntrin 'lop3_fast_decode_i1_to_int8_to_f16_l8_' is not registered
2024-05-04 23:29:11 [BitBLAS:DEBUG]: Apply config {'block': [128], 'thread': [128], 'rstep': [128], 'vectorize': {'B_decode': 8}}
2024-05-04 23:29:11 [BitBLAS:DEBUG]: Apply schedule failed: Traceback (most recent call last):
  3: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<void (tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)>::AssignTypedLambda<tvm::tir::{lambda(tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)#14}>(tvm::tir::{lambda(tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)#14}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  2: tvm::tir::TracedScheduleNode::Tensorize(tvm::tir::LoopRV const&, tvm::runtime::String const&, bool)
  1: tvm::tir::ConcreteScheduleNode::Tensorize(tvm::tir::LoopRV const&, tvm::runtime::String const&, bool)
  0: tvm::tir::TensorIntrin::Get(tvm::runtime::String, bool)
  File "/home/t-leiwang/mlc_workspace/BitBLAS/3rdparty/tvm/src/tir/ir/function.cc", line 151
ValueError: TensorIntrin 'lop3_fast_decode_i1_to_int8_to_f16_l8_' is not registered
2024-05-04 23:29:11 [BitBLAS:DEBUG]: Warning: block config [128] is not valid for matmul, skip.
2024-05-04 23:29:11 [BitBLAS:DEBUG]: Warning: block config [128] is not valid for matmul, skip.
2024-05-04 23:29:11 [BitBLAS:DEBUG]: Apply schedule failed: Traceback (most recent call last):
  3: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<void (tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)>::AssignTypedLambda<tvm::tir::{lambda(tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)#14}>(tvm::tir::{lambda(tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)#14}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  2: tvm::tir::TracedScheduleNode::Tensorize(tvm::tir::LoopRV const&, tvm::runtime::String const&, bool)
  1: tvm::tir::ConcreteScheduleNode::Tensorize(tvm::tir::LoopRV const&, tvm::runtime::String const&, bool)
  0: tvm::tir::TensorIntrin::Get(tvm::runtime::String, bool)
  File "/home/t-leiwang/mlc_workspace/BitBLAS/3rdparty/tvm/src/tir/ir/function.cc", line 151
ValueError: TensorIntrin 'lop3_fast_decode_i1_to_int8_to_f16_l8_' is not registered
2024-05-04 23:29:11 [BitBLAS:DEBUG]: Apply schedule failed: Traceback (most recent call last):
  3: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<void (tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)>::AssignTypedLambda<tvm::tir::{lambda(tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)#14}>(tvm::tir::{lambda(tvm::tir::Schedule, tvm::runtime::ObjectRef, tvm::runtime::String, bool)#14}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  2: tvm::tir::TracedScheduleNode::Tensorize(tvm::tir::LoopRV const&, tvm::runtime::String const&, bool)
  1: tvm::tir::ConcreteScheduleNode::Tensorize(tvm::tir::LoopRV const&, tvm::runtime::String const&, bool)
  0: tvm::tir::TensorIntrin::Get(tvm::runtime::String, bool)
  File "/home/t-leiwang/mlc_workspace/BitBLAS/3rdparty/tvm/src/tir/ir/function.cc", line 151
ValueError: TensorIntrin 'lop3_fast_decode_i1_to_int8_to_f16_l8_' is not registered
2024-05-04 23:29:15 [BitBLAS:DEBUG]: Artifact path is None
2024-05-04 23:29:15 [BitBLAS:DEBUG]: Artifact path is None
2024-05-04 23:29:15 [BitBLAS:DEBUG]: Artifact path is None
2024-05-04 23:29:15 [BitBLAS:DEBUG]: Artifact path is None
2024-05-04 23:29:15 [BitBLAS:DEBUG]: Artifact path is None
2024-05-04 23:29:15 [BitBLAS:DEBUG]: Artifact path is None
2024-05-04 23:29:15 [BitBLAS:DEBUG]: Artifact path is None
2024-05-04 23:29:15 [BitBLAS:DEBUG]: LocalBuilder: An exception occurred Traceback (most recent call last):
  File "/home/tianle/anaconda3/envs/bitdelta/lib/python3.9/site-packages/bitblas/3rdparty/tvm/python/tvm/exec/popen_worker.py", line 87, in main
    result = fn(*args, **kwargs)
  File "/home/tianle/anaconda3/envs/bitdelta/lib/python3.9/site-packages/bitblas/base/utils.py", line 211, in _build
    rt_mod = tvm.build(mod, target=arch.target)
  File "/home/tianle/anaconda3/envs/bitdelta/lib/python3.9/site-packages/bitblas/3rdparty/tvm/python/tvm/driver/build_module.py", line 297, in build
    rt_mod_host = _driver_ffi.tir_to_runtime(annotated_mods, target_host)
  File "/home/tianle/anaconda3/envs/bitdelta/lib/python3.9/site-packages/bitblas/3rdparty/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/home/tianle/anaconda3/envs/bitdelta/lib/python3.9/site-packages/bitblas/3rdparty/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
ValueError: Traceback (most recent call last):
  62: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::runtime::Module (tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target)>::AssignTypedLambda<tvm::{lambda(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target)#6}>(tvm::{lambda(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target)#6}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  61: tvm::TIRToRuntime(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target const&)
  60: tvm::codegen::Build(tvm::IRModule, tvm::Target)
  59: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::runtime::Module (tvm::IRModule, tvm::Target)>::AssignTypedLambda<tvm::runtime::Module (*)(tvm::IRModule, tvm::Target)>(tvm::runtime::Module (*)(tvm::IRModule, tvm::Target), std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  58: tvm::codegen::BuildCUDA(tvm::IRModule, tvm::Target)
  57: tvm::codegen::CodeGenC::AddFunction(tvm::GlobalVar const&, tvm::tir::PrimFunc const&)
  56: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::DeclBufferNode const*)
  55: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::DeclBufferNode const*)
  54: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::DeclBufferNode const*)
  53: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::DeclBufferNode const*)
  52: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::DeclBufferNode const*)
  51: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::DeclBufferNode const*)
  50: tvm::codegen::CodeGenCUDA::VisitStmt_(tvm::tir::AttrStmtNode const*)
  49: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::AttrStmtNode const*)
  48: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  47: tvm::codegen::CodeGenCUDA::VisitStmt_(tvm::tir::AllocateNode const*)
  46: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  45: tvm::codegen::CodeGenCUDA::VisitStmt_(tvm::tir::AllocateNode const*)
  44: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  43: tvm::codegen::CodeGenCUDA::VisitStmt_(tvm::tir::AllocateNode const*)
  42: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  41: tvm::codegen::CodeGenCUDA::VisitStmt_(tvm::tir::AttrStmtNode const*)
  40: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::AttrStmtNode const*)
  39: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  38: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::SeqStmtNode const*)
  37: tvm::codegen::CodeGenCUDA::VisitStmt_(tvm::tir::ForNode const*)
  36: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::ForNode const*)
  35: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  34: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::SeqStmtNode const*)
  33: tvm::codegen::CodeGenCUDA::VisitStmt_(tvm::tir::ForNode const*)
  32: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::ForNode const*)
  31: tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  30: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::BufferStoreNode const*)
  29: tvm::codegen::CodeGenC::PrintExpr[abi:cxx11](tvm::PrimExpr const&)
  28: tvm::codegen::CodeGenC::PrintExpr(tvm::PrimExpr const&, std::ostream&)
  27: tvm::codegen::CodeGenCUDA::VisitExpr_(tvm::tir::CastNode const*, std::ostream&)
  26: tvm::codegen::CodeGenC::PrintExpr[abi:cxx11](tvm::PrimExpr const&)
  25: tvm::codegen::CodeGenC::PrintExpr(tvm::PrimExpr const&, std::ostream&)
  24: tvm::codegen::CodeGenCUDA::VisitExpr_(tvm::tir::CallNode const*, std::ostream&)
  23: tvm::codegen::CodeGenC::VisitExpr_(tvm::tir::CallNode const*, std::ostream&)
  22: tvm::codegen::PrintBinaryIntrinsic(tvm::tir::CallNode const*, char const*, std::ostream&, tvm::codegen::CodeGenC*)
  21: tvm::codegen::CodeGenCUDA::PrintVecBinaryOp(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::DataType, tvm::PrimExpr, tvm::PrimExpr, std::ostream&)
  20: tvm::codegen::CodeGenC::PrintExpr[abi:cxx11](tvm::PrimExpr const&)
  19: tvm::codegen::CodeGenC::PrintExpr(tvm::PrimExpr const&, std::ostream&)
  18: tvm::codegen::CodeGenCUDA::VisitExpr_(tvm::tir::CallNode const*, std::ostream&)
  17: tvm::codegen::CodeGenC::VisitExpr_(tvm::tir::CallNode const*, std::ostream&)
  16: tvm::codegen::PrintBinaryIntrinsic(tvm::tir::CallNode const*, char const*, std::ostream&, tvm::codegen::CodeGenC*)
  15: tvm::codegen::CodeGenCUDA::PrintVecBinaryOp(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::DataType, tvm::PrimExpr, tvm::PrimExpr, std::ostream&)
  14: tvm::codegen::CodeGenC::PrintExpr[abi:cxx11](tvm::PrimExpr const&)
  13: tvm::codegen::CodeGenC::PrintExpr(tvm::PrimExpr const&, std::ostream&)
  12: tvm::codegen::CodeGenCUDA::VisitExpr_(tvm::tir::CallNode const*, std::ostream&)
  11: tvm::codegen::CodeGenC::VisitExpr_(tvm::tir::CallNode const*, std::ostream&)
  10: tvm::codegen::PrintBinaryIntrinsic(tvm::tir::CallNode const*, char const*, std::ostream&, tvm::codegen::CodeGenC*)
  9: tvm::codegen::CodeGenCUDA::PrintVecBinaryOp(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::DataType, tvm::PrimExpr, tvm::PrimExpr, std::ostream&)
  8: tvm::codegen::CodeGenC::PrintExpr[abi:cxx11](tvm::PrimExpr const&)
  7: tvm::codegen::CodeGenC::PrintExpr(tvm::PrimExpr const&, std::ostream&)
  6: tvm::codegen::CodeGenCUDA::VisitExpr_(tvm::tir::CallNode const*, std::ostream&)
  5: tvm::codegen::CodeGenC::VisitExpr_(tvm::tir::CallNode const*, std::ostream&)
  4: tvm::codegen::PrintBinaryIntrinsic(tvm::tir::CallNode const*, char const*, std::ostream&, tvm::codegen::CodeGenC*)
  3: tvm::codegen::CodeGenCUDA::PrintVecBinaryOp(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::DataType, tvm::PrimExpr, tvm::PrimExpr, std::ostream&)
  2: tvm::codegen::CodeGenC::PrintExpr[abi:cxx11](tvm::PrimExpr const&)
  1: tvm::codegen::CodeGenC::PrintExpr(tvm::PrimExpr const&, std::ostream&)
  0: tvm::codegen::CodeGenCUDA::VisitExpr_(tvm::tir::RampNode const*, std::ostream&)
  File "/home/t-leiwang/mlc_workspace/BitBLAS/3rdparty/tvm/src/target/source/codegen_cuda.cc", line 1224
ValueError: Check failed: lanes <= 4 (8 vs. 4) : Ramp of more than 4 lanes is not allowed.
LeiWang1999 commented 1 month ago

Hi @chromecast56, apologies for any confusion. I thought you were attempting to use int8 x int1 rather than float16 x int1.

A few things should be pointed out:

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import bitblas
import torch

os.environ["PATH"] = os.environ["PATH"]+":/usr/local/cuda/bin/"

bitblas.set_log_level("DEBUG")
matmul_config = bitblas.MatmulConfig(
    M=1,  # M dimension
    N=1024,  # N dimension
    K=1024,  # K dimension
    A_dtype="float16",  # activation A dtype
    W_dtype="uint1",  # weight W dtype
    accum_dtype="float16",  # accumulation dtype
    out_dtype="float16",  # output dtype
    layout="nt",  # matrix layout, "nt" indicates the layout of A is non-transpose and the layout of W is transpose
    with_bias=False,  # bias
    # configs for weight only quantization
    group_size=None,  # setting for grouped quantization
    with_scaling=False,  # setting for scaling factor
    with_zeros=False,  # setting for zeros
    zeros_mode=None,  # setting for how to calculating zeros
)

matmul = bitblas.Matmul(config=matmul_config)

# Create input matrices
input_tensor = torch.rand((1, 1024), dtype=torch.float16).cuda()
weight_tensor = torch.randint(0, 2, (1024, 1024), dtype=torch.int8).cuda()

# Pack the weight tensor into BitBLAS's uint1 storage format
weight_tensor_int1 = matmul.transform_weight(weight_tensor)

# Perform mixed-precision matrix multiplication
output_tensor = matmul(input_tensor, weight_tensor_int1)

# Reference result using PyTorch matmul for comparison
ref_result = torch.matmul(input_tensor, weight_tensor.t().to(torch.float16))
# Assert that the results are close within a specified tolerance; the accumulated values are fairly large, so we set atol to 1.0
print("Ref output:", ref_result)
print("BitBLAS output:", output_tensor)
torch.testing.assert_close(output_tensor, ref_result, rtol=1e-2, atol=1e-0)

'''
Ref output: tensor([[265.7500, 273.2500, 268.2500,  ..., 259.0000, 257.5000, 260.2500]],
       device='cuda:0', dtype=torch.float16)
BitBLAS output: tensor([[265.5000, 273.5000, 268.5000,  ..., 259.0000, 257.5000, 260.0000]],
       device='cuda:0', dtype=torch.float16)
'''
chromecast56 commented 1 month ago

This makes sense, thank you so much for the clarification @LeiWang1999! So for benchmarking purposes, I assume it would be reasonable to treat float16 x int1 as float16 x uint1 with zero points.
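
For what it's worth, that mapping is just arithmetic: with u = (b + 1) / 2, a {-1, 1} weight b satisfies A @ b.T = 2 * (A @ u.T) - A.sum(dim=1, keepdim=True), so an fp16 x uint1 kernel plus a cheap per-row correction reproduces the fp16 x int1 result. A quick sanity check of the identity in plain PyTorch (independent of BitBLAS, done in float64 to keep rounding noise out of the comparison):

import torch

# Identity check only - no BitBLAS involved.
A = torch.rand(1, 1024, dtype=torch.float64)
b = (torch.randint(0, 2, (1024, 1024)) * 2 - 1).to(torch.float64)  # int1-style weights in {-1, 1}
u = (b + 1) / 2                                                     # uint1-style weights in {0, 1}

ref = A @ b.t()
via_uint1 = 2 * (A @ u.t()) - A.sum(dim=1, keepdim=True)
torch.testing.assert_close(ref, via_uint1)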

chromecast56 commented 1 month ago

I don't want to bother you too much @LeiWang1999, but in the code you just linked, torch.testing.assert_close fails for me with W_dtype='uint1' but works for other W_dtype values - I was wondering if you had any idea how to resolve this, or whether this is expected?

Ref output: tensor([[235.8750, 243.6250, 249.6250,  ..., 232.5000, 256.0000, 250.3750]],
       device='cuda:0', dtype=torch.float16)
BitBLAS output: tensor([[196.6250, 207.1250, 200.7500,  ..., 187.3750, 206.5000, 207.5000]],
       device='cuda:0', dtype=torch.float16)
LeiWang1999 commented 1 month ago

Hi @chromecast56, would you mind providing more information, like your device and a reproduction script? I'll be glad to help you resolve this :)

chromecast56 commented 1 month ago

On an A100 80GB machine, the reproduction script:

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import triton
import triton.language as tl

import bitblas
import torch

os.environ["PATH"] = os.environ["PATH"]+":/usr/local/cuda/bin/"

bitblas.set_log_level("DEBUG")
matmul_config = bitblas.MatmulConfig(
    M=1,  # M dimension
    N=1024,  # N dimension
    K=1024,  # K dimension
    A_dtype="float16",  # activation A dtype
    W_dtype="uint1",  # weight W dtype
    accum_dtype="float16",  # accumulation dtype
    out_dtype="float16",  # output dtype
    layout="nt",  # matrix layout, "nt" indicates the layout of A is non-transpose and the layout of W is transpose
    with_bias=False,  # bias
    # configs for weight only quantization
    group_size=None,  # setting for grouped quantization
    with_scaling=False,  # setting for scaling factor
    with_zeros=False,  # setting for zeros
    zeros_mode=None,  # setting for how to calculating zeros
)

matmul = bitblas.Matmul(config=matmul_config)

# Create input matrices
input_tensor = torch.rand((1, 1024), dtype=torch.float16).cuda()
weight_tensor = torch.randint(0, 2, (1024, 1024), dtype=torch.int8).cuda()

# Pack the weight tensor into BitBLAS's uint1 storage format
weight_tensor_int4 = matmul.transform_weight(weight_tensor)

# Perform mixed-precision matrix multiplication
output_tensor = matmul(input_tensor, weight_tensor_int4)

# Reference result using PyTorch matmul for comparison
ref_result = torch.matmul(input_tensor, weight_tensor.t().to(torch.float16))
# Assert that the results are close within a specified tolerance; the accumulated values are fairly large, so we set atol to 1.0
print("Ref output:", ref_result)
print("BitBLAS output:", output_tensor)
torch.testing.assert_close(output_tensor, ref_result, rtol=1e-2, atol=1e-0)
LeiWang1999 commented 1 month ago

@chromecast56, the issue is related to the recent CUDA graph support. This error occurs if you use the 0.0.1.dev3 BitBLAS PyPI package; I was using the upstream source from GitHub, so my output is correct. I've submitted dev4 to PyPI - please check it out:

pip install bitblas==0.0.1.dev4
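
A generic way to confirm which build your environment actually picks up (standard library only, nothing BitBLAS-specific):

from importlib.metadata import version
print(version("bitblas"))  # should report 0.0.1.dev4 (or newer) after upgrading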
chromecast56 commented 1 month ago

awesome, thanks!

LeiWang1999 commented 1 month ago

@chromecast56 btw dude, I found that fast_decoding was disabled by default when I put together the dev4 release, which could lead to reduced performance. For latency benchmarking, consider building from the upstream source :)
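
If you do build from source, the usual flow is a recursive clone followed by a pip install - please double-check the repository's installation docs, as these exact commands are an assumption on my part, and the build compiles the bundled TVM, so it can take a while:

git clone --recursive https://github.com/microsoft/BitBLAS.git
cd BitBLAS
pip install .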

LeiWang1999 commented 1 month ago

Check out:

pip install bitblas==0.0.1.dev5
chromecast56 commented 1 month ago

@LeiWang1999 How should I cite this btw? Is there an associated paper or should I just link the repo?

LeiWang1999 commented 1 month ago

Hi @chromecast56, thanks! Our paper will be published in July 2024; the citation BibTeX is below. You can also cite the BitBLAS repo before the publication.

@inproceedings{ladder,
  author    = {Lei Wang and Lingxiao Ma and Shijie Cao and Quanlu Zhang and Jilong Xue and Yining Shi and Ningxin Zheng and Ziming Miao and Fan Yang and Ting Cao and Yuqing Yang and Mao Yang},
  title     = {Ladder: Enabling Efficient Low-Precision Deep Learning Computing through Hardware-aware Tensor Transformation},
  booktitle = {18th USENIX Symposium on Operating Systems Design and Implementation (OSDI 24)},
  year      = {2024}
}