vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Bug]: "apply_gptq_marlin_linear" Error When TP > 1 #6889

Open sitabulaixizawaluduo opened 1 month ago

sitabulaixizawaluduo commented 1 month ago

Your current environment

PyTorch version: 2.3.1+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.30.1
Libc version: glibc-2.35

Python version: 3.10.14 (main, May  6 2024, 19:42:50) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.16.20-3.el7.bzl.x86_64-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.1.105
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA L40
GPU 1: NVIDIA L40
GPU 2: NVIDIA L40
GPU 3: NVIDIA L40
GPU 4: NVIDIA L40
GPU 5: NVIDIA L40
GPU 6: NVIDIA L40
GPU 7: NVIDIA L40

Nvidia driver version: 535.104.12
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Address sizes:                   46 bits physical, 57 bits virtual
Byte Order:                      Little Endian
CPU(s):                          128
On-line CPU(s) list:             0-127
Vendor ID:                       GenuineIntel
Model name:                      Intel(R) Xeon(R) Platinum 8358P CPU @ 2.60GHz
CPU family:                      6
Model:                           106
Thread(s) per core:              2
Core(s) per socket:              32
Socket(s):                       2
Stepping:                        6
CPU max MHz:                     3400.0000
CPU min MHz:                     800.0000
BogoMIPS:                        5200.00
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 invpcid_single intel_ppin ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect wbnoinvd dtherm ida arat pln pts avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq rdpid fsrm md_clear pconfig flush_l1d arch_capabilities
Virtualization:                  VT-x
L1d cache:                       3 MiB (64 instances)
L1i cache:                       2 MiB (64 instances)
L2 cache:                        80 MiB (64 instances)
L3 cache:                        96 MiB (2 instances)
NUMA node(s):                    2
NUMA node0 CPU(s):               0-31,64-95
NUMA node1 CPU(s):               32-63,96-127
Vulnerability Itlb multihit:     Not affected
Vulnerability L1tf:              Not affected
Vulnerability Mds:               Not affected
Vulnerability Meltdown:          Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Vulnerable: eIBRS with unprivileged eBPF
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Not affected

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-nccl-cu12==2.20.5
[pip3] optree==0.11.0
[pip3] sentence-transformers==2.2.2
[pip3] torch==2.3.1
[pip3] torchaudio==2.3.1
[pip3] torchelastic==0.2.2
[pip3] torchvision==0.18.1
[pip3] transformers==4.43.1
[pip3] triton==2.3.1
[conda] blas                      1.0                         mkl  
[conda] ffmpeg                    4.3                  hf484d3e_0    pytorch
[conda] libjpeg-turbo             2.0.0                h9bf148f_0    pytorch
[conda] mkl                       2023.1.0         h213fc3f_46344  
[conda] mkl-service               2.4.0           py310h5eee18b_1  
[conda] mkl_fft                   1.3.8           py310h5eee18b_0  
[conda] mkl_random                1.2.4           py310hdb19cb5_0  
[conda] numpy                     1.26.4          py310h5f9d8c6_0  
[conda] numpy-base                1.26.4          py310hb5e798b_0  
[conda] nvidia-nccl-cu12          2.20.5                   pypi_0    pypi
[conda] optree                    0.11.0                   pypi_0    pypi
[conda] pytorch-cuda              12.1                 ha16c6d3_5    pytorch
[conda] pytorch-mutex             1.0                        cuda    pytorch
[conda] sentence-transformers     2.2.2                    pypi_0    pypi
[conda] torch                     2.3.1                    pypi_0    pypi
[conda] torchaudio                2.3.1               py310_cu121    pytorch
[conda] torchelastic              0.2.2                    pypi_0    pypi
[conda] torchvision               0.18.1                   pypi_0    pypi
[conda] transformers              4.43.1                   pypi_0    pypi
[conda] triton                    2.3.1                    pypi_0    pypi
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.5.3.post1
vLLM Build Flags:
CUDA Archs: 7.0 7.5 8.0 8.6 8.9 9.0; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0    GPU1    GPU2    GPU3    GPU4    GPU5    GPU6    GPU7    NIC0    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      PIX     PXB     PXB     SYS     SYS     SYS     SYS     PXB     0-31,64-95      0               N/A
GPU1    PIX      X      PXB     PXB     SYS     SYS     SYS     SYS     PXB     0-31,64-95      0               N/A
GPU2    PXB     PXB      X      PXB     SYS     SYS     SYS     SYS     PXB     0-31,64-95      0               N/A
GPU3    PXB     PXB     PXB      X      SYS     SYS     SYS     SYS     PIX     0-31,64-95      0               N/A
GPU4    SYS     SYS     SYS     SYS      X      PIX     PXB     PXB     SYS     32-63,96-127    1               N/A
GPU5    SYS     SYS     SYS     SYS     PIX      X      PXB     PXB     SYS     32-63,96-127    1               N/A
GPU6    SYS     SYS     SYS     SYS     PXB     PXB      X      PXB     SYS     32-63,96-127    1               N/A
GPU7    SYS     SYS     SYS     SYS     PXB     PXB     PXB      X      SYS     32-63,96-127    1               N/A
NIC0    PXB     PXB     PXB     PIX     SYS     SYS     SYS     SYS      X 

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

NIC Legend:

  NIC0: mlx5_bond_0

🐛 Describe the bug

The Code

from transformers import AutoTokenizer

from vllm.entrypoints.llm import LLM
from vllm.sampling_params import SamplingParams

model_path = '/models/Llama-3-8B_w8a16_packed_quantize'

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
prompts = [
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.0,
                                 max_tokens=256,
                                 stop=['<|end_of_text|>'],
                                 skip_special_tokens=False)

# Create an LLM. With tensor_parallel_size > 1 the error is raised here,
# during profile_run, after the weights have loaded.
llm = LLM(model=model_path,
          tensor_parallel_size=2,
          disable_custom_all_reduce=True,
          trust_remote_code=True,
          worker_use_ray=True,
          quantization='compressed-tensors',
          enable_chunked_prefill=False,
          dtype='bfloat16')

# Generate completions (only reached when engine start-up succeeds).
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(output.outputs[0].text)

The model config with compressed-tensors

  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "compression_config": {
    "config_groups": {
      "group_1": {
        "input_activations": null,
        "output_activations": null,
        "targets": [
          "Linear"
        ],
        "weights": {
          "block_structure": null,
          "dynamic": false,
          "group_size": null,
          "num_bits": 8,
          "observer": "minmax",
          "observer_kwargs": {},
          "strategy": "channel",
          "symmetric": true,
          "type": "int"
        }
      }
    },
    "format": "pack-quantized",
    "global_compression_ratio": null,
    "ignore": [
      "lm_head"
    ],
    "quant_method": "compressed-tensors",
    "quantization_status": "calibration"
  },
  "eos_token_id": 128001,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 8192,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 500000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.42.3",
  "use_cache": true,
  "vocab_size": 128256
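
To make the quantization scheme in this config easier to read at a glance, here is a minimal sketch that parses a trimmed copy of the `compression_config` block and prints a one-line summary per group. The helper name `summarize_scheme` is hypothetical, and treating missing activation quantization as 16-bit is an assumption based on the model's `bfloat16` dtype; this is not how vLLM itself reads the config.

```python
import json

# Trimmed copy of the "compression_config" block from the config above.
CONFIG_FRAGMENT = """
{
  "compression_config": {
    "config_groups": {
      "group_1": {
        "input_activations": null,
        "targets": ["Linear"],
        "weights": {
          "group_size": null,
          "num_bits": 8,
          "strategy": "channel",
          "symmetric": true,
          "type": "int"
        }
      }
    },
    "format": "pack-quantized",
    "ignore": ["lm_head"],
    "quant_method": "compressed-tensors"
  }
}
"""


def summarize_scheme(config: dict) -> list:
    """Return one readable line per quantization group (hypothetical helper)."""
    cc = config["compression_config"]
    lines = []
    for name, group in cc["config_groups"].items():
        w = group["weights"]
        acts = group.get("input_activations")
        # Assumption: no activation quantization means 16-bit activations.
        act_bits = acts["num_bits"] if acts else 16
        lines.append(
            f"{name}: w{w['num_bits']}a{act_bits} {w['type']}, "
            f"strategy={w['strategy']}, symmetric={w['symmetric']}, "
            f"format={cc['format']}"
        )
    return lines


summary = summarize_scheme(json.loads(CONFIG_FRAGMENT))
for line in summary:
    print(line)
```

For the config in this report, the summary shows 8-bit symmetric integer weights with per-channel scales in the pack-quantized format, which is the channelwise case discussed later in the thread.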

With tp=1 everything works fine, but with tp=2 or tp=4 an NCCL error occurs. The model loads normally; the error occurs during profile_run. With vLLM 0.5.1 I don't have this problem. Hopefully this helps to see what's wrong and how I should fix it. Thanks!

sitabulaixizawaluduo commented 1 month ago

The error message

[rank1]:[E ProcessGroupNCCL.cpp:1414] [PG 2 Rank 1] Process group watchdog thread terminated with exception: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Exception raised from c10_cuda_check_implementation at ../c10/cuda/CUDAException.cpp:43 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x57 (0x7f97fe0cf897 in /opt/conda/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7f97fe07fb25 in /opt/conda/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x118 (0x7f97fe1a7718 in /opt/conda/lib/python3.10/site-packages/torch/lib/libc10_cuda.so)
frame #3: c10d::ProcessGroupNCCL::WorkNCCL::finishedGPUExecutionInternal() const + 0x56 (0x7f97b20738e6 in /opt/conda/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #4: c10d::ProcessGroupNCCL::WorkNCCL::isCompleted() + 0x58 (0x7f97b20779e8 in /opt/conda/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #5: c10d::ProcessGroupNCCL::watchdogHandler() + 0x77c (0x7f97b207d05c in /opt/conda/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #6: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x10c (0x7f97b207ddcc in /opt/conda/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so)
frame #7: <unknown function> + 0xdbbf4 (0x7f980e87fbf4 in /opt/conda/lib/python3.10/site-packages/../../libstdc++.so.6)
frame #8: <unknown function> + 0x94ac3 (0x7f980f9f6ac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #9: clone + 0x44 (0x7f980fa87bf4 in /usr/lib/x86_64-linux-gnu/libc.so.6)
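
The trace itself suggests `CUDA_LAUNCH_BLOCKING=1`, because asynchronous CUDA errors are often reported far from the kernel that caused them. A minimal sketch of building such a debug environment follows. The helper `debug_env` and the script name `repro.py` are hypothetical; `NCCL_DEBUG=INFO` is NCCL's standard logging variable. Whether Ray workers inherit these variables in this setup is an assumption that has not been verified here.

```python
import os


def debug_env(base=None):
    """Return a copy of the environment with CUDA/NCCL debug settings added."""
    env = dict(os.environ if base is None else base)
    # Make kernel launches synchronous so the failing call is reported in place.
    env["CUDA_LAUNCH_BLOCKING"] = "1"
    # Ask NCCL for verbose logs about communicator setup and collectives.
    env["NCCL_DEBUG"] = "INFO"
    return env


env = debug_env({"PATH": "/usr/bin"})
print(env)

# Hypothetical usage: rerun the reproduction script with the debug settings.
# import subprocess, sys
# subprocess.run([sys.executable, "repro.py"], env=debug_env())
```

The variables need to be in place before CUDA is initialized in the process, which is why the sketch passes them to a fresh subprocess rather than setting them mid-run.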

robertgshaw2-neuralmagic commented 1 month ago

Thanks for the detailed report. I introduced a bug in loading channelwise weights with TP>1 for compressed-tensors.

The following PR should resolve your issue:

It will be in the next release. Do you need any help building vllm from source?
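
As general background on why channelwise scales need care under tensor parallelism, the sketch below shows which per-output-channel scales each rank would need when a weight is split along its output or its input dimension. This is an illustration only: it is not vLLM's loader code, the function names are hypothetical, and whether it matches what the PR actually changes is an assumption. A scale tensor whose length disagrees with the local weight shard is one plausible way for a kernel to read out of bounds.

```python
def shard(values, rank, world_size):
    """Return the contiguous slice of `values` owned by `rank`."""
    assert len(values) % world_size == 0, "length must divide evenly"
    size = len(values) // world_size
    return values[rank * size:(rank + 1) * size]


def scales_for_rank(scales, rank, world_size, parallel_dim):
    """Per-output-channel scales one rank needs (illustrative, not vLLM code).

    "output": the weight is split across output channels, so the scales
              are split the same way.
    "input":  every rank keeps all output channels (only the input
              dimension is split), so the scales must not be sliced.
    """
    if parallel_dim == "output":
        return shard(scales, rank, world_size)
    if parallel_dim == "input":
        return list(scales)
    raise ValueError(f"unknown parallel_dim: {parallel_dim!r}")


scales = [0.1, 0.2, 0.3, 0.4]  # one scale per output channel
col = [scales_for_rank(scales, r, 2, "output") for r in range(2)]
row = [scales_for_rank(scales, r, 2, "input") for r in range(2)]
print(col)
print(row)
```

With TP=1 both cases return the full scale list, which is consistent with the report that the problem only shows up once the model is split across more than one GPU.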

robertgshaw2-neuralmagic commented 1 month ago

Also, thanks for using llm-compressor.

If you have any feedback on the UX or anything else, please don't hesitate to raise an issue in the llm-compressor repo so that we can continue to improve the experience.

sitabulaixizawaluduo commented 1 month ago

Thanks!