vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Bug]: CUDA error when running mistral-7b + lora with tensor_para=8 #4756

Open sfc-gh-zhwang opened 5 months ago

sfc-gh-zhwang commented 5 months ago

Your current environment

PyTorch version: 2.3.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.28.3
Libc version: glibc-2.35

Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.10.213-201.855.amzn2.x86_64-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.4.99
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-40GB
GPU 1: NVIDIA A100-SXM4-40GB
GPU 2: NVIDIA A100-SXM4-40GB
GPU 3: NVIDIA A100-SXM4-40GB
GPU 4: NVIDIA A100-SXM4-40GB
GPU 5: NVIDIA A100-SXM4-40GB
GPU 6: NVIDIA A100-SXM4-40GB
GPU 7: NVIDIA A100-SXM4-40GB

Nvidia driver version: 535.161.08
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.0.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      46 bits physical, 48 bits virtual
Byte Order:                         Little Endian
CPU(s):                             96
On-line CPU(s) list:                0-95
Vendor ID:                          GenuineIntel
Model name:                         Intel(R) Xeon(R) Platinum 8275CL CPU @ 3.00GHz
CPU family:                         6
Model:                              85
Thread(s) per core:                 2
Core(s) per socket:                 24
Socket(s):                          2
Stepping:                           7
BogoMIPS:                           5999.99
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves ida arat pku ospke
Hypervisor vendor:                  KVM
Virtualization type:                full
L1d cache:                          1.5 MiB (48 instances)
L1i cache:                          1.5 MiB (48 instances)
L2 cache:                           48 MiB (48 instances)
L3 cache:                           71.5 MiB (2 instances)
NUMA node(s):                       2
NUMA node0 CPU(s):                  0-23,48-71
NUMA node1 CPU(s):                  24-47,72-95
Vulnerability Gather data sampling: Unknown: Dependent on hypervisor status
Vulnerability Itlb multihit:        KVM: Mitigation: VMX unsupported
Vulnerability L1tf:                 Mitigation; PTE Inversion
Vulnerability Mds:                  Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Meltdown:             Mitigation; PTI
Vulnerability Mmio stale data:      Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Retbleed:             Vulnerable
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass:    Vulnerable
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Retpolines, STIBP disabled, RSB filling
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected

Versions of relevant libraries:
[pip3] numpy==1.24.4
[pip3] nvidia-nccl-cu12==2.20.5
[pip3] onnx==1.15.0rc2
[pip3] optree==0.10.0
[pip3] pytorch-quantization==2.1.2
[pip3] pytorch-triton==2.2.0+e28a256d7
[pip3] torch==2.3.0
[pip3] torch-tensorrt==2.3.0a0
[pip3] torchdata==0.7.1a0
[pip3] torchtext==0.17.0a0
[pip3] torchvision==0.18.0a0
[pip3] triton==2.3.0
[pip3] vllm-nccl-cu12==2.18.1.0.4.0
[conda] Could not collect
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.4.1
vLLM Build Flags:
CUDA Archs: 5.2 6.0 6.1 7.0 7.2 7.5 8.0 8.6 8.7 9.0+PTX; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0    GPU1    GPU2    GPU3    GPU4    GPU5    GPU6    GPU7    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X  NV12    NV12    NV12    NV12    NV12    NV12    NV12    0-23,48-71  0       N/A
GPU1    NV12     X  NV12    NV12    NV12    NV12    NV12    NV12    0-23,48-71  0       N/A
GPU2    NV12    NV12     X  NV12    NV12    NV12    NV12    NV12    0-23,48-71  0       N/A
GPU3    NV12    NV12    NV12     X  NV12    NV12    NV12    NV12    0-23,48-71  0       N/A
GPU4    NV12    NV12    NV12    NV12     X  NV12    NV12    NV12    24-47,72-95 1       N/A
GPU5    NV12    NV12    NV12    NV12    NV12     X  NV12    NV12    24-47,72-95 1       N/A
GPU6    NV12    NV12    NV12    NV12    NV12    NV12     X  NV12    24-47,72-95 1       N/A
GPU7    NV12    NV12    NV12    NV12    NV12    NV12    NV12     X  24-47,72-95 1       N/A

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

🐛 Describe the bug

Running the code below, which enables LoRA for the mistral-7b model with tensor_parallel_size=8, throws a CUDA error. Full log is here

from vllm import LLM, SamplingParams

llm = LLM(
    model="/models/mistral-7b",
    enable_lora=True,
    tensor_parallel_size=8,
)

mgoin commented 5 months ago

Hi @sfc-gh-zhwang, FWIW I was able to run this with TP=2 on 2xA6000 using vllm==0.4.2.

from vllm import LLM

llm = LLM(
    model="mistralai/Mistral-7B-Instruct-v0.2",
    enable_lora=True,
    tensor_parallel_size=2,
)

print(llm.generate("Hello"))

Output:

2024-05-22 06:43:47,998 INFO worker.py:1749 -- Started a local Ray instance.
INFO 05-22 06:43:49 llm_engine.py:100] Initializing an LLM engine (v0.4.2) with config: model='mistralai/Mistral-7B-Instruct-v0.2', speculative_config=None, tokenizer='mistralai/Mistral-7B-Instruct-v0.2', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=32768, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=2, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), seed=0, served_model_name=mistralai/Mistral-7B-Instruct-v0.2)
INFO 05-22 06:43:53 utils.py:660] Found nccl from library /root/.config/vllm/nccl/cu12/libnccl.so.2.18.1
(RayWorkerWrapper pid=523946) INFO 05-22 06:43:53 utils.py:660] Found nccl from library /root/.config/vllm/nccl/cu12/libnccl.so.2.18.1
INFO 05-22 06:43:54 selector.py:81] Cannot use FlashAttention-2 backend because the flash_attn package is not found. Please install it for better performance.
INFO 05-22 06:43:54 selector.py:32] Using XFormers backend.
(RayWorkerWrapper pid=523946) INFO 05-22 06:43:54 selector.py:81] Cannot use FlashAttention-2 backend because the flash_attn package is not found. Please install it for better performance.
(RayWorkerWrapper pid=523946) INFO 05-22 06:43:54 selector.py:32] Using XFormers backend.
INFO 05-22 06:43:56 pynccl_utils.py:43] vLLM is using nccl==2.18.1
(RayWorkerWrapper pid=523946) INFO 05-22 06:43:56 pynccl_utils.py:43] vLLM is using nccl==2.18.1
INFO 05-22 06:43:57 utils.py:132] reading GPU P2P access cache from /root/.config/vllm/gpu_p2p_access_cache_for_0,1.json
(RayWorkerWrapper pid=523946) INFO 05-22 06:43:57 utils.py:132] reading GPU P2P access cache from /root/.config/vllm/gpu_p2p_access_cache_for_0,1.json
INFO 05-22 06:43:58 weight_utils.py:199] Using model weights format ['*.safetensors']
(RayWorkerWrapper pid=523946) INFO 05-22 06:43:58 weight_utils.py:199] Using model weights format ['*.safetensors']
INFO 05-22 06:44:00 model_runner.py:175] Loading model weights took 6.7544 GB
(RayWorkerWrapper pid=523946) INFO 05-22 06:44:01 model_runner.py:175] Loading model weights took 6.7544 GB
INFO 05-22 06:44:07 distributed_gpu_executor.py:45] # GPU blocks: 33401, # CPU blocks: 4096
INFO 05-22 06:44:08 model_runner.py:937] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 05-22 06:44:08 model_runner.py:941] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
(RayWorkerWrapper pid=523946) INFO 05-22 06:44:08 model_runner.py:937] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
(RayWorkerWrapper pid=523946) INFO 05-22 06:44:08 model_runner.py:941] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
INFO 05-22 06:44:13 custom_all_reduce.py:246] Registering 2275 cuda graph addresses
INFO 05-22 06:44:13 model_runner.py:1017] Graph capturing finished in 5 secs.
(RayWorkerWrapper pid=523946) INFO 05-22 06:44:13 custom_all_reduce.py:246] Registering 2275 cuda graph addresses
(RayWorkerWrapper pid=523946) INFO 05-22 06:44:13 model_runner.py:1017] Graph capturing finished in 5 secs.
Processed prompts: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  3.56it/s]
[RequestOutput(request_id=0, prompt='Hello', prompt_token_ids=[1, 22557], prompt_logprobs=None, outputs=[CompletionOutput(index=0, text=", I'm quite new to Linux so I'm really sorry if this", token_ids=[28725, 315, 28742, 28719, 3448, 633, 298, 19486, 579, 315, 28742, 28719, 1528, 7371, 513, 456], cumulative_logprob=-27.695576645433903, logprobs=None, finish_reason=length, stop_reason=None)], finished=True, metrics=RequestMetrics(arrival_time=1716360253.5880315, last_token_time=1716360253.5880315, first_scheduled_time=1716360253.5906928, first_token_time=1716360253.6384084, time_in_queue=0.0026612281799316406, finished_time=1716360253.8712075), lora_request=None)]

sfc-gh-zhwang commented 5 months ago

@mgoin it's only tp=8 that doesn't work.

sfc-gh-zhwang commented 5 months ago

@FurtherAI, in case you have any ideas 😃

sfc-gh-zhwang commented 5 months ago

I narrowed it down further to this line:

punica_kernels.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx, 1.0)

where, for tp=8 (which errors out), the tensor sizes are:

buffer: [32768, 16]
x: [32768, 512]
wa_t_all: [1, 1, 16, 512]

while for tp=4 (which works), the tensor sizes are:

buffer: [32768, 16]
x: [32768, 1024]
wa_t_all: [1, 1, 16, 1024]

I'm still trying to figure out what is special about going from 1024 to 512.
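A quick way to see where the 1024 and 512 widths could come from: assuming Mistral-7B's hidden_size is 4096 and that x is the input of a row-parallel projection (such as o_proj), whose input features are split evenly across tensor-parallel ranks, each rank sees hidden_size / tp features. The helper below is hypothetical and only does that arithmetic.

```python
# ASSUMPTION: Mistral-7B hidden_size is 4096, and x is the input of a
# row-parallel layer whose input dimension is sharded across TP ranks.
HIDDEN_SIZE = 4096


def per_rank_feat_in(hidden_size: int, tp_size: int) -> int:
    """Input width (feat_in) that each TP rank sees for a row-parallel layer."""
    assert hidden_size % tp_size == 0, "hidden size must divide evenly across ranks"
    return hidden_size // tp_size


for tp in (2, 4, 8):
    # tp=4 gives the working width 1024; tp=8 gives the failing width 512.
    print(f"tp={tp}: feat_in={per_rank_feat_in(HIDDEN_SIZE, tp)}")
```

Under this assumption, tp=4 and tp=8 reproduce exactly the x widths reported above.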

FurtherAI commented 5 months ago

I tracked it a little further. It seems to depend on the sequence length, though I'm not sure why: from a brief glance, the kernel shouldn't care about the sequence length. I found that 65536, 8192, and 4096 work while 32768 and 16384 fail; I didn't test other values. So for now, @sfc-gh-zhwang, run with a different sequence length.

Here's some code to reproduce:

import torch
import vllm._punica_C as punica_kernels  # vLLM's compiled Punica LoRA kernels

seq_length, rank = 32768, 16
# LoRA "shrink" step: project 512-wide bf16 inputs down to rank 16 (fp32 output buffer)
buffer = torch.randn((seq_length, rank), device='cuda', dtype=torch.float32)
x = torch.randn((seq_length, 512), device='cuda', dtype=torch.bfloat16)
wa_t_all = torch.randn((1, 1, rank, 512), device='cuda', dtype=torch.bfloat16)
indicies = torch.full((seq_length,), 1, device='cuda', dtype=torch.int64)
punica_kernels.dispatch_bgmv(buffer, x, wa_t_all, indicies, 0, 1.0)

# Kernel launches are asynchronous; synchronize so the CUDA error surfaces here.
torch.cuda.synchronize()

sfc-gh-zhwang commented 4 months ago

I think I found the root cause: the code here reads out of bounds of X for certain tensor shapes. One fix would be to add a condition like if (threadIdx.y * tx * vec_size < feat_in). But maybe we should fold this into the loop at line 84 and just change for (tile_idx = 1; to for (tile_idx = 0;? ROCm is doing this anyway: https://github.com/vllm-project/vllm/blob/a377f0bd5e1fa0ca069e3dbf28f4de5af64d0bb1/csrc/punica/bgmv/bgmv_impl.cuh#L196
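The out-of-bounds read described in the last comment can be modeled in plain Python. This is a sketch under stated assumptions, not the kernel itself: it assumes the shrink kernel loads its first tile of X unconditionally, with each thread copying vec_size contiguous elements starting at batch_idx * feat_in + (threadIdx.y * tx + threadIdx.x) * vec_size, and it assumes a launch configuration of tx=32, ty=4, vec_size=8 for bf16 inputs (a tile of 1024 elements). None of these values is confirmed in this thread, and the helper names max_x_index_read and overflows are hypothetical. Under these assumptions, feat_in=1024 fills exactly one tile, feat_in=512 reads past the end of the last row, and the proposed threadIdx.y guard removes the overflow.

```python
# ASSUMED launch configuration for bf16 inputs (not confirmed in the thread):
# tile_size = TX * TY * VEC_SIZE = 1024 elements.
TX, TY, VEC_SIZE = 32, 4, 8


def max_x_index_read(seq_len: int, feat_in: int, guarded: bool = False) -> int:
    """Largest element index of X touched by the first tile load of the last row."""
    batch_idx = seq_len - 1  # the last token row is the one that can run off the end
    max_idx = -1
    for thread_y in range(TY):
        # Proposed guard from the thread: skip warps whose slice starts past feat_in.
        if guarded and not (thread_y * TX * VEC_SIZE < feat_in):
            continue
        for thread_x in range(TX):
            start = batch_idx * feat_in + (thread_y * TX + thread_x) * VEC_SIZE
            max_idx = max(max_idx, start + VEC_SIZE - 1)
    return max_idx


def overflows(seq_len: int, feat_in: int, guarded: bool = False) -> bool:
    """True if the first tile load reads past the last element of X."""
    numel = seq_len * feat_in
    return max_x_index_read(seq_len, feat_in, guarded) >= numel


for feat_in in (1024, 512):
    print(f"feat_in={feat_in}: unguarded overflow={overflows(32768, feat_in)}, "
          f"guarded overflow={overflows(32768, feat_in, guarded=True)}")
```

In this model the last row overflows for every sequence length when feat_in=512. A plausible but unverified reason the crash only appears at some lengths is that a read just past X's allocation only faults when the memory after it is not mapped, which depends on how the allocator lays out buffers.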