vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0
28.86k stars 4.28k forks source link

[Bug]: CUDA out of memory when setting prompt_logprobs with larger batch_size #5424

Open qaz-wsx-1 opened 4 months ago

qaz-wsx-1 commented 4 months ago

Your current environment

Collecting environment information...
PyTorch version: 2.3.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 3.29.3
Libc version: glibc-2.31

Python version: 3.11.5 (main, Sep 11 2023, 13:54:46) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.4.0-182-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA A100 80GB PCIe
Nvidia driver version: 535.171.04
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Byte Order:                         Little Endian
Address sizes:                      48 bits physical, 48 bits virtual
CPU(s):                             256
On-line CPU(s) list:                0-255
Thread(s) per core:                 2
Core(s) per socket:                 64
Socket(s):                          2
NUMA node(s):                       2
Vendor ID:                          AuthenticAMD
CPU family:                         25
Model:                              1
Model name:                         AMD EPYC 7713 64-Core Processor
Stepping:                           1
Frequency boost:                    enabled
CPU MHz:                            1499.385
CPU max MHz:                        2000.0000
CPU min MHz:                        1500.0000
BogoMIPS:                           4000.18
Virtualization:                     AMD-V
L1d cache:                          4 MiB
L1i cache:                          4 MiB
L2 cache:                           64 MiB
L3 cache:                           512 MiB
NUMA node0 CPU(s):                  0-63,128-191
NUMA node1 CPU(s):                  64-127,192-255
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Not affected
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold v_vmsave_vmload vgif umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca

Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] nvidia-nccl-cu12==2.20.5
[pip3] torch==2.3.0
[pip3] transformers==4.41.2
[pip3] triton==2.3.0
[conda] numpy                     1.26.4                   pypi_0    pypi
[conda] nvidia-nccl-cu12          2.20.5                   pypi_0    pypi
[conda] torch                     2.3.0                    pypi_0    pypi
[conda] transformers              4.41.2                   pypi_0    pypi
[conda] triton                    2.3.0                    pypi_0    pypi
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.4.3
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      73,201  1               N/A

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

🐛 Describe the bug

torch.cuda.OutOfMemoryError: CUDA out of memory when setting prompt_logprobs=1 for a large batch


from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
llm = LLM(model=model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

messages = [
    {"role": "user", "content": "What is the capital of France? Repeat 50 times."},
    {"role": "assistant", "content": "The capital of France is Paris." * 50},
]
batch_size = 10
prompt_token_ids = [tokenizer.apply_chat_template(messages)] * batch_size
print("Prompt token ids:", len(prompt_token_ids[0]))

sampling_params = SamplingParams(temperature=0, max_tokens=1, prompt_logprobs=1)
outs = llm.generate(prompt_token_ids=prompt_token_ids, sampling_params=sampling_params)
moussaKam commented 10 hours ago

Hi, same problem here, any idea how to fix it?

tjohnson31415 commented 5 hours ago

It is known that requesting prompt_logprobs can cause a spike in memory usage and lead to an crash. OOMs due to logprobs are being tracked on: https://github.com/vllm-project/vllm/issues/5907

There is no fix yet, but the current workaround is to use --enable-chunked-prefill. This reduces the meomry spike due to the limit on the number of tokens processed in a batch, but is not supported yet supported for all other parameter configurations.