vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Bug]: The VRAM usage of calculating log_probs is not considered in profile run #5067

Open Conless opened 4 months ago

Conless commented 4 months ago

Your current environment

The output of `python collect_env.py`

Collecting environment information...
PyTorch version: 2.3.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.29.3
Libc version: glibc-2.35

Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-6.9.1-arch1-1-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.4.131
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 4070 Ti SUPER
Nvidia driver version: 550.78
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                         x86_64
CPU op-mode(s):                       32-bit, 64-bit
Address sizes:                        46 bits physical, 48 bits virtual
Byte Order:                           Little Endian
CPU(s):                               24
On-line CPU(s) list:                  0-23
Vendor ID:                            GenuineIntel
Model name:                           13th Gen Intel(R) Core(TM) i7-13700K
CPU family:                           6
Model:                                183
Thread(s) per core:                   2
Core(s) per socket:                   16
Socket(s):                            1
Stepping:                             1
CPU max MHz:                          5400.0000
CPU min MHz:                          800.0000
BogoMIPS:                             6837.00
Flags:                                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt clwb intel_pt sha_ni xsaveopt xsavec xgetbv1 xsaves split_lock_detect user_shstk avx_vnni dtherm ida arat pln pts hwp hwp_notify hwp_act_window hwp_epp hwp_pkg_req hfi vnmi umip pku ospke waitpkg gfni vaes vpclmulqdq tme rdpid movdiri movdir64b fsrm md_clear serialize pconfig arch_lbr ibt flush_l1d arch_capabilities
Virtualization:                       VT-x
L1d cache:                            640 KiB (16 instances)
L1i cache:                            768 KiB (16 instances)
L2 cache:                             24 MiB (10 instances)
L3 cache:                             30 MiB (1 instance)
NUMA node(s):                         1
NUMA node0 CPU(s):                    0-23
Vulnerability Gather data sampling:   Not affected
Vulnerability Itlb multihit:          Not affected
Vulnerability L1tf:                   Not affected
Vulnerability Mds:                    Not affected
Vulnerability Meltdown:               Not affected
Vulnerability Mmio stale data:        Not affected
Vulnerability Reg file data sampling: Mitigation; Clear Register File
Vulnerability Retbleed:               Not affected
Vulnerability Spec rstack overflow:   Not affected
Vulnerability Spec store bypass:      Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:             Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:             Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI BHI_DIS_S
Vulnerability Srbds:                  Not affected
Vulnerability Tsx async abort:        Not affected

Versions of relevant libraries:
[pip3] mypy==1.9.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] nvidia-nccl-cu12==2.20.5
[pip3] torch==2.3.0
[pip3] triton==2.3.0
[pip3] vllm-nccl-cu12==2.18.1.0.4.0
[conda] Could not collect
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.4.2
vLLM Build Flags:
CUDA Archs: 7.0 7.5 8.0 8.6 8.9 9.0+PTX; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X  0-23    0       N/A

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

🐛 Describe the bug

I encountered an unexpected CUDA out-of-memory error while adding a new LoRA feature to vLLM. After experimenting with different settings, I found that the bug only appears when prompt_logprobs in SamplingParams is set to a non-zero value and a long prompt is used, as mentioned in #1532. I then tried to locate the bug and found the following (some unimportant tracebacks are omitted):

[rank0]: Traceback (most recent call last):
[rank0]:   File "/workspace/vllm/worker/model_runner.py", line 721, in execute_model
[rank0]:     output = self.model.sample(
[rank0]:   File "/workspace/vllm/model_executor/models/llama.py", line 381, in sample
[rank0]:     next_tokens = self.sampler(logits, sampling_metadata)
[rank0]:   File "/workspace/vllm/model_executor/layers/sampler.py", line 72, in forward
[rank0]:     logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
[rank0]:   File "/workspace/vllm/model_executor/layers/sampler.py", line 208, in _apply_penalties
[rank0]:     output_bin_counts, output_mask = _get_bin_counts_and_mask(
[rank0]:   File "/workspace/vllm/model_executor/layers/sampler.py", line 143, in _get_bin_counts_and_mask
[rank0]:     bin_counts = torch.zeros((num_seqs, vocab_size + 1),
[rank0]: torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 3.41 GiB. GPU

The bug is caused by calculations in _get_bin_counts_and_mask during the sampling phase. When prompt_logprobs is enabled, the log probabilities of all tokens in the prompt (up to 8192 for Llama 3, which I am using) are calculated, leading to a memory usage of up to

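$$\text{num\_tokens} \times \text{vocab\_size} \times 8\ \text{Bytes} = 8192 \times 128256 \times 8\ \text{Bytes} \approx 7.8\ \text{GiB}$$

for a single int64 (torch.long) tensor such as bin_counts. With 4-byte float32 elements, such as the logits themselves, a tensor of the same shape takes about 3.9 GiB.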
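To reproduce this estimate for different element widths, here is a small Python helper. `dense_tensor_gib` is an illustrative name, not part of vLLM; it sizes one dense (num_tokens, vocab_size) tensor and, like the formula, ignores the extra padding column that _get_bin_counts_and_mask allocates (vocab_size + 1).

```python
def dense_tensor_gib(num_tokens: int, vocab_size: int, bytes_per_elem: int) -> float:
    """Size in GiB of a dense (num_tokens, vocab_size) tensor (hypothetical helper)."""
    return num_tokens * vocab_size * bytes_per_elem / 2**30


# Llama 3: an 8192-token prompt and a 128256-entry vocabulary.
fp32_gib = dense_tensor_gib(8192, 128256, 4)   # 4-byte elements, e.g. float32 logits
int64_gib = dense_tensor_gib(8192, 128256, 8)  # 8-byte elements, e.g. torch.long counts
print(f"float32: {fp32_gib:.2f} GiB, int64: {int64_gib:.2f} GiB")
# float32: 3.91 GiB, int64: 7.83 GiB
```

The point is that the cost scales with the number of logit rows, so a single long prompt with prompt_logprobs can dominate the sampler's memory.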

However, this memory usage is not predicted in profile_run(), where the sampling parameters are set as:

# Enable top-k sampling to reflect the accurate memory usage.
sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)

This only accounts for computing log probabilities for up to 256 tokens, i.e. one sampled token per dummy sequence, where 256 is the maximum number of batched sequences (max_num_seqs).
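The gap can be illustrated with a rough row count. This is a sketch under a simplifying assumption: during prefill, a sequence without prompt_logprobs contributes one logit row to the sampler, while a sequence with prompt_logprobs contributes roughly one row per prompt token. `sampler_rows` is a hypothetical helper, and the 32-token dummy length assumes max_num_batched_tokens = 8192 split over 256 sequences.

```python
from typing import List, Optional


def sampler_rows(prompt_lens: List[int], prompt_logprobs: List[Optional[int]]) -> int:
    """Approximate logit rows the sampler sees in one prefill step (hypothetical).

    Simplified model: without prompt_logprobs only the last position is sampled;
    with prompt_logprobs every prompt position produces a row.
    """
    rows = 0
    for length, logprobs in zip(prompt_lens, prompt_logprobs):
        rows += length if logprobs is not None else 1
    return rows


# Profile-run style batch: 256 dummy sequences of 32 tokens, no prompt_logprobs.
profile_rows = sampler_rows([32] * 256, [None] * 256)
# One real 8192-token prompt with prompt_logprobs enabled.
real_rows = sampler_rows([8192], [1])
print(profile_rows, real_rows)  # 256 8192
```

Both batches contain 8192 tokens, yet the second makes the sampler process 32 times as many rows.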

Conless commented 4 months ago

A trivial way to solve this is to add a constraint $\text{log\_prob\_tokens} < \text{max\_num\_seqs}$ to the following function in scheduler.py:

def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int):
    assert num_new_tokens != 0
    assert num_new_seqs != 0
    # Admit new work only if it fits both the token budget and the sequence limit.
    return (self.num_batched_tokens + num_new_tokens <= self.token_budget
            and self.num_curr_seqs + num_new_seqs <= self.max_num_seqs)

If this solution is acceptable, I can submit a pull request later.
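A minimal sketch of what an extra logprob check in can_schedule might look like, assuming the scheduler can tell how many logit rows each request's prompt_logprobs would add. `LogprobAwareBudget`, `max_num_logprob_tokens`, and `num_new_logprob_tokens` are hypothetical names for illustration, not vLLM's actual SchedulingBudget API.

```python
from dataclasses import dataclass


@dataclass
class LogprobAwareBudget:
    """Hypothetical scheduling budget with an extra cap on prompt-logprob rows."""
    token_budget: int
    max_num_seqs: int
    max_num_logprob_tokens: int
    num_batched_tokens: int = 0
    num_curr_seqs: int = 0
    num_curr_logprob_tokens: int = 0

    def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int,
                     num_new_logprob_tokens: int = 0) -> bool:
        assert num_new_tokens != 0
        assert num_new_seqs != 0
        # The existing two limits, plus a third one on logprob rows.
        return (self.num_batched_tokens + num_new_tokens <= self.token_budget
                and self.num_curr_seqs + num_new_seqs <= self.max_num_seqs
                and self.num_curr_logprob_tokens + num_new_logprob_tokens
                <= self.max_num_logprob_tokens)

    def add(self, *, num_new_tokens: int, num_new_seqs: int,
            num_new_logprob_tokens: int = 0) -> None:
        # Record the work once it has been admitted.
        self.num_batched_tokens += num_new_tokens
        self.num_curr_seqs += num_new_seqs
        self.num_curr_logprob_tokens += num_new_logprob_tokens


budget = LogprobAwareBudget(token_budget=8192, max_num_seqs=256,
                            max_num_logprob_tokens=256)
# A 4096-token prompt without prompt_logprobs fits.
print(budget.can_schedule(num_new_tokens=4096, num_new_seqs=1))  # True
# The same prompt with prompt_logprobs needs about 4096 logprob rows.
print(budget.can_schedule(num_new_tokens=4096, num_new_seqs=1,
                          num_new_logprob_tokens=4096))  # False
```

With a hard cap like this, a single prompt that needs more logprob rows than the cap could never be scheduled, so the cap has to be at least as large as the longest prompt that may request prompt_logprobs.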

robertgshaw2-neuralmagic commented 4 months ago

Agree this is an issue that needs to be fixed.

I don't quite see how log_prob_tokens < max_seqs is the right solution, though... isn't this a bit too coarse-grained?

Conless commented 4 months ago

@robertgshaw2-neuralmagic I agree with you. Another solution I came up with, which is more fine-grained, is to add a new argument max_num_logprobs to EngineArgs (defaulting to the value of max_num_seqs). However, I'm concerned that this argument might rarely be used.

What do you think about it?
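The proposed default could be expressed as in the sketch below. `EngineArgsSketch` is a hypothetical stand-in, not vLLM's EngineArgs; it only shows max_num_logprobs falling back to max_num_seqs when the user does not set it.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class EngineArgsSketch:
    """Hypothetical stand-in for EngineArgs with the proposed max_num_logprobs."""
    max_num_seqs: int = 256
    max_num_logprobs: Optional[int] = None

    def __post_init__(self) -> None:
        # Proposed default: reuse max_num_seqs when not set explicitly.
        if self.max_num_logprobs is None:
            self.max_num_logprobs = self.max_num_seqs
        if self.max_num_logprobs < 1:
            raise ValueError("max_num_logprobs must be positive")


print(EngineArgsSketch().max_num_logprobs)                       # 256
print(EngineArgsSketch(max_num_seqs=64).max_num_logprobs)        # 64
print(EngineArgsSketch(max_num_logprobs=8192).max_num_logprobs)  # 8192
```

Defaulting to max_num_seqs keeps today's behavior for users who never touch the flag, while users running logprob-heavy workloads can raise it explicitly.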

robertgshaw2-neuralmagic commented 4 months ago

> @robertgshaw2-neuralmagic I agree with you. Another solution I came up with, which is more fine-grained, is to add a new argument max_num_logprobs to EngineArgs (defaulting to the value of max_num_seqs). However, I'm concerned that this argument might rarely be used.
>
> What do you think about it?

I think we should have a user-controlled max_num_logprobs with a sensible default. Let me ask the rest of the group.

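Then we will need to:

  • update profiling logic to take this into account
  • update scheduler logic to take this into account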

Conless commented 4 months ago

@robertgshaw2-neuralmagic Thank you for considering this.

> Then we will need to:
>
>   • update profiling logic to take this into account
>   • update scheduler logic to take this into account

This modification should not be too difficult. Given the concise structure of the current code, it should be feasible to implement by adding the logic to the dummy-prompt generation in profile_run and to can_schedule in the scheduler.
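For the profiling side, one possible approach is to let some of the dummy sequences request prompt logprobs so that the sampler is exercised at the worst case the scheduler will later admit. The sketch below is hypothetical (`plan_profile_batch` is not a vLLM function) and uses a simplified row model: a dummy sequence with prompt_logprobs contributes one row per token, otherwise one row.

```python
import math
from typing import List, Optional, Tuple

Plan = List[Tuple[int, Optional[int]]]  # (dummy seq_len, prompt_logprobs)


def plan_profile_batch(max_num_batched_tokens: int, max_num_seqs: int,
                       max_num_logprobs: int) -> Plan:
    """Hypothetical dummy-batch plan for a profiling run.

    Splits the token budget evenly across max_num_seqs dummy sequences and
    enables prompt_logprobs on just enough of them for the sampler to see
    at least max_num_logprobs rows (bounded by the tokens in the batch).
    """
    seq_len = max_num_batched_tokens // max_num_seqs
    target_rows = min(max(max_num_seqs, max_num_logprobs), max_num_seqs * seq_len)
    extra_rows = target_rows - max_num_seqs
    # Switching one dummy sequence to prompt_logprobs adds seq_len - 1 rows.
    if extra_rows <= 0 or seq_len <= 1:
        n_logprob = 0
    else:
        n_logprob = min(math.ceil(extra_rows / (seq_len - 1)), max_num_seqs)
    return [(seq_len, 1 if i < n_logprob else None) for i in range(max_num_seqs)]


def sampler_rows(plan: Plan) -> int:
    # Simplified row model: seq_len rows with prompt_logprobs, else one row.
    return sum(n if lp is not None else 1 for n, lp in plan)


for cap in (256, 1024, 8192):
    plan = plan_profile_batch(8192, 256, cap)
    print(cap, sampler_rows(plan))
# 256 256
# 1024 1031
# 8192 8192
```

Rounding up slightly overshoots the cap (1031 rows for a cap of 1024), which errs on the side of reserving too much memory rather than too little.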

rangehow commented 4 months ago

Just ran into this problem :(

mgoin commented 4 months ago

@Conless would you be willing to tackle the support for this? I agree it would be a nice improvement; I have run into this when running LLM evaluations on MMLU, which requires a lot of logprobs.

Conless commented 4 months ago

@mgoin No problem, I'd be delighted to tackle it.

likaixin2000 commented 3 months ago

Had the same problem :( Are there any quick fixes to this?