NVIDIA / TensorRT-LLM

TensorRT-LLM provides users with an easy-to-use Python API to define Large Language Models (LLMs) and build TensorRT engines that contain state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs. TensorRT-LLM also contains components to create Python and C++ runtimes that execute those TensorRT engines.
https://nvidia.github.io/TensorRT-LLM
Apache License 2.0

Is W4A(FP)8 quant not supported with bf16 datatype? #1843

Closed. wxsms closed this issue 2 months ago

wxsms commented 3 months ago

System Info

Ubuntu, with Ada GPUs. TensorRT-LLM version: 0.11.0.dev2024061800

Who can help?

@Tracin

Reproduction

use examples/quantization/quantize.py to quantize a model like this (I am using Llama):

python3 ./quantization/quantize.py \
        --model_dir /mnt/models/source \
        --dtype bfloat16 \
        --qformat w4a8_awq \
        --output_dir /tmp/checkpoint \
        --calib_tp_size 4 \
        --tp_size 1

Expected behavior

the quantization should work

actual behavior

It fails with the error: FP8 is unsupported on with BF16 scales and zero-points!

additional notes

I notice that in tensorrt_llm/cpp/tensorrt_llm/plugins/weightOnlyGroupwiseQuantMatmulPlugin/weightOnlyGroupwiseQuantMatmulPlugin.cpp there is a snippet of code like this:

#if defined(ENABLE_BF16)
    else if (mType == nvinfer1::DataType::kBF16)
    {
        if (quant_algo & FP8_ALPHA)
        {
            // FP8 requires at least sm89 devices
            if (mArch < 89)
            {
                TLLM_THROW("W4A(fp)8 kernel is unsupported on pre-Ada (sm<89) architectures!");
            }
            TLLM_THROW("FP8 is unsupported on with BF16 scales and zero-points!");
        }
        else
        {
            if (quant_algo & ZERO)
            {
                // has zeros
                m_weightOnlyGroupwiseGemmRunner
                    = std::make_shared<tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<__nv_bfloat16,
                        cutlass::uint4b_t, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS>>();
            }
            else
            {
                // no zeros
                m_weightOnlyGroupwiseGemmRunner
                    = std::make_shared<tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<__nv_bfloat16,
                        cutlass::uint4b_t, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>>();
            }
        }
        mCudaKernelEnabled = tensorrt_llm::kernels::weight_only::is_supported(
            mArch, tensorrt_llm::kernels::weight_only::KernelType::BF16Int4Groupwise);
        mCudaKernelType = tensorrt_llm::kernels::weight_only::KernelType::BF16Int4Groupwise;
    }
#endif

I'm not very sure, but is this a mistake? The error message mentions zero-points, yet the throw happens without any check of the zero condition (that check only appears in the next block, I think?).
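
To make the point concrete, here is a small standalone sketch (not TensorRT-LLM code; the flag values and function name are made up for illustration) that mirrors how I read that branch: with BF16 activations, the FP8_ALPHA error is thrown unconditionally, even when ZERO is not set.

#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <string>

// Hypothetical flag values, for illustration only.
constexpr uint32_t ZERO = 1u << 0;
constexpr uint32_t FP8_ALPHA = 1u << 1;

// Mirrors my reading of the kBF16 branch quoted above.
std::string selectBf16Runner(uint32_t quant_algo, int arch)
{
    if (quant_algo & FP8_ALPHA)
    {
        if (arch < 89)
        {
            throw std::runtime_error("W4A(fp)8 kernel is unsupported on pre-Ada (sm<89) architectures!");
        }
        // Thrown unconditionally: the ZERO flag is never consulted on this path.
        throw std::runtime_error("FP8 is unsupported on with BF16 scales and zero-points!");
    }
    return (quant_algo & ZERO) ? "FINEGRAINED_SCALE_AND_ZEROS" : "FINEGRAINED_SCALE_ONLY";
}

int main()
{
    try
    {
        // w4a8_awq with bfloat16 on an Ada GPU (sm 89) and no zero-points: still throws.
        selectBf16Runner(FP8_ALPHA, 89);
    }
    catch (const std::exception& e)
    {
        std::cout << e.what() << std::endl;
    }
    return 0;
}

Until this is supported, my plan is to fall back to --dtype float16 in the quantize.py command above, which I assume (not confirmed) avoids this error, since the guard is specific to the BF16 path.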

Barry-Delaney commented 3 months ago

@wxsms thanks for the feedback. w4a8_awq with the BF16 data type is not supported yet; we will add it in an upcoming update.

nv-guomingz commented 2 months ago

Hi @wxsms could we close this ticket now?

wxsms commented 2 months ago

> Hi @wxsms could we close this ticket now?

It's okay; alternatively, we can close this issue once this feature is fully supported. You may close it at your discretion. Thanks

nv-guomingz commented 2 months ago

Thanks @wxsms. Please feel free to reopen it if needed.