vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Bug]: OOM when setting prompt_logprobs=1 #5550

Open: janphilippfranken opened this issue 3 months ago

janphilippfranken commented 3 months ago

Your current environment

Collecting environment information...
PyTorch version: 2.3.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 3.29.5
Libc version: glibc-2.31

Python version: 3.10.0 (default, Mar  3 2022, 09:58:08) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.4.0-162-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA A100-SXM4-80GB
Nvidia driver version: 535.54.03
cuDNN version: Probably one of the following:
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn.so.8.4.1
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.4.1
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.4.1
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.4.1
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.4.1
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.4.1
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.4.1
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn.so.8.4.1
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.4.1
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.4.1
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.4.1
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.4.1
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.4.1
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.4.1
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn.so.8.4.1
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.4.1
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.4.1
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.4.1
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.4.1
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.4.1
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.4.1
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn.so.8.4.1
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.4.1
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.4.1
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.4.1
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.4.1
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.4.1
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.4.1
/usr/local/cuda-11.4/targets/x86_64-linux/lib/libcudnn.so.8.4.1
/usr/local/cuda-11.4/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.4.1
/usr/local/cuda-11.4/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.4.1
/usr/local/cuda-11.4/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.4.1
/usr/local/cuda-11.4/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.4.1
/usr/local/cuda-11.4/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.4.1
/usr/local/cuda-11.4/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.4.1
/usr/local/cuda-11.5/targets/x86_64-linux/lib/libcudnn.so.8.4.1
/usr/local/cuda-11.5/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.4.1
/usr/local/cuda-11.5/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.4.1
/usr/local/cuda-11.5/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.4.1
/usr/local/cuda-11.5/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.4.1
/usr/local/cuda-11.5/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.4.1
/usr/local/cuda-11.5/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.4.1
/usr/local/cuda-11.6/targets/x86_64-linux/lib/libcudnn.so.8.4.1
/usr/local/cuda-11.6/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.4.1
/usr/local/cuda-11.6/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.4.1
/usr/local/cuda-11.6/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.4.1
/usr/local/cuda-11.6/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.4.1
/usr/local/cuda-11.6/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.4.1
/usr/local/cuda-11.6/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.4.1
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn.so.8.4.1
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.4.1
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.4.1
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.4.1
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.4.1
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.4.1
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.4.1
/usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudnn.so.8.4.1
/usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.4.1
/usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.4.1
/usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.4.1
/usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.4.1
/usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.4.1
/usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.4.1
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Byte Order:                         Little Endian
Address sizes:                      48 bits physical, 48 bits virtual
CPU(s):                             128
On-line CPU(s) list:                0-127
Thread(s) per core:                 2
Core(s) per socket:                 32
Socket(s):                          2
NUMA node(s):                       2
Vendor ID:                          AuthenticAMD
CPU family:                         25
Model:                              1
Model name:                         AMD EPYC 7543 32-Core Processor
Stepping:                           1
Frequency boost:                    enabled
CPU MHz:                            1657.430
CPU max MHz:                        2800.0000
CPU min MHz:                        1500.0000
BogoMIPS:                           5599.81
Virtualization:                     AMD-V
L1d cache:                          2 MiB
L1i cache:                          2 MiB
L2 cache:                           32 MiB
L3 cache:                           512 MiB
NUMA node0 CPU(s):                  0-31,64-95
NUMA node1 CPU(s):                  32-63,96-127
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Not affected
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold v_vmsave_vmload vgif umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-nccl-cu12==2.20.5
[pip3] torch==2.3.0
[pip3] transformers==4.41.2
[pip3] triton==2.3.0
[conda] numpy                     1.26.4                   pypi_0    pypi
[conda] nvidia-nccl-cu12          2.20.5                   pypi_0    pypi
[conda] torch                     2.3.0                    pypi_0    pypi
[conda] transformers              4.41.2                   pypi_0    pypi
[conda] triton                    2.3.0                    pypi_0    pypi
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.5.0.post1
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
        GPU0    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      0-7,64-71       0-1             N/A

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

🐛 Describe the bug

I have a standard setup like:

from vllm import LLM

model = LLM(
    model=model,
    download_dir=download_dir,
    dtype=dtype,
    tensor_parallel_size=tensor_parallel_size,
    quantization=quantization if quantization != "none" else None,
)

And running a function like:

from typing import List, Optional

from vllm import SamplingParams

def batch_prompt(
    self,
    prompts: List[str],
    max_new_tokens: Optional[int] = 500,
    do_sample: Optional[bool] = True,  # unused: SamplingParams has no do_sample; temperature=0 means greedy
    top_p: Optional[float] = 0.9,
    top_k: Optional[int] = -1,
    temperature: Optional[float] = 0.1,
    num_return_sequences: Optional[int] = 1,
    best_of: Optional[int] = 1,
    use_beam_search: Optional[bool] = False,
    presence_penalty: Optional[float] = 0.0,
    frequency_penalty: Optional[float] = 0.0,
) -> List[str]:
    """Batched text generation."""
    sampling_params = SamplingParams(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        max_tokens=max_new_tokens,
        n=num_return_sequences,
        best_of=best_of,  # vLLM requires best_of >= n
        use_beam_search=use_beam_search,
        presence_penalty=presence_penalty,
        frequency_penalty=frequency_penalty,
    )

    outputs = self.model.generate(
        prompts=prompts,
        sampling_params=sampling_params,
    )

    # Flatten: one string per returned sequence, across all prompts.
    generations = []
    for output in outputs:
        for generated_sequence in output.outputs:
            generations.append(generated_sequence.text)

    return generations

This works fine, even with very long prompts and a very large batch size.

However, as soon as I do something like

sampling_params = SamplingParams(
    temperature=0,
    max_tokens=1,
    n=1,
    prompt_logprobs=1,
    spaces_between_special_tokens=False,
)

That is, with prompt_logprobs=1 instead of the default (None), I immediately get an OOM error for the exact same prompts, which does not make sense to me. Shouldn't it just return the logprobs in addition to the generations without affecting anything else?

My OOM error:

Processed prompts:   0%|    | 0/5 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
Error executing job with overrides: []
Traceback (most recent call last):
  File "/sailhome/jphilipp/research_projects/star-gate-human-eval/experiments/v1/logprobs.py", line 47, in main
    batch_logprobs = model.prompt_logprobs(
  File "/sailhome/jphilipp/research_projects/star-gate-human-eval/src/stargate/vllm_inference_model.py", line 43, in prompt_logprobs
    output_responses = self.model.generate(
  File "/scr/jphilipp/miniconda3/envs/stargate/lib/python3.10/site-packages/vllm/utils.py", line 691, in inner
    return fn(*args, **kwargs)
  File "/scr/jphilipp/miniconda3/envs/stargate/lib/python3.10/site-packages/vllm/entrypoints/llm.py", line 304, in generate
    outputs = self._run_engine(use_tqdm=use_tqdm)
  File "/scr/jphilipp/miniconda3/envs/stargate/lib/python3.10/site-packages/vllm/entrypoints/llm.py", line 556, in _run_engine
    step_outputs = self.llm_engine.step()
  File "/scr/jphilipp/miniconda3/envs/stargate/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 776, in step
    output = self.model_executor.execute_model(
  File "/scr/jphilipp/miniconda3/envs/stargate/lib/python3.10/site-packages/vllm/executor/gpu_executor.py", line 91, in execute_model
    output = self.driver_worker.execute_model(execute_model_req)
  File "/scr/jphilipp/miniconda3/envs/stargate/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/scr/jphilipp/miniconda3/envs/stargate/lib/python3.10/site-packages/vllm/worker/worker.py", line 280, in execute_model
    output = self.model_runner.execute_model(seq_group_metadata_list,
  File "/scr/jphilipp/miniconda3/envs/stargate/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/scr/jphilipp/miniconda3/envs/stargate/lib/python3.10/site-packages/vllm/worker/model_runner.py", line 765, in execute_model
    output = self.model.sample(
  File "/scr/jphilipp/miniconda3/envs/stargate/lib/python3.10/site-packages/vllm/model_executor/models/llama.py", line 386, in sample
    next_tokens = self.sampler(logits, sampling_metadata)
  File "/scr/jphilipp/miniconda3/envs/stargate/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/scr/jphilipp/miniconda3/envs/stargate/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scr/jphilipp/miniconda3/envs/stargate/lib/python3.10/site-packages/vllm/model_executor/layers/sampler.py", line 91, in forward
    probs = torch.softmax(logits, dim=-1, dtype=torch.float)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 3.61 GiB. GPU 

Note that the above works if len(prompts) == 1.
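A rough back-of-the-envelope estimate of why memory jumps (this is not from the thread). The traceback fails inside torch.softmax(logits, dim=-1, dtype=torch.float), so assume the sampler materializes a float32 tensor of shape (number of scored positions, vocabulary size). Without prompt_logprobs, only the last position of each prompt needs scoring; with it, every prompt position does. The 5 × 2,000-token batch and the 32,000-token vocabulary below are hypothetical, and real peak usage is likely higher because logits, probabilities, and log-probabilities may coexist.

```python
GIB = 1024 ** 3


def logits_bytes(num_scored_tokens: int, vocab_size: int, bytes_per_elem: int = 4) -> int:
    """Size of one (num_scored_tokens, vocab_size) tensor; float32 uses 4 bytes per element."""
    return num_scored_tokens * vocab_size * bytes_per_elem


# Hypothetical workload: 5 prompts of 2,000 tokens each, 32,000-token vocabulary (assumed).
NUM_PROMPTS, PROMPT_LEN, VOCAB = 5, 2_000, 32_000

# prompt_logprobs=None: roughly one scored position per prompt (the last one).
last_position_only = logits_bytes(NUM_PROMPTS, VOCAB)
# prompt_logprobs set: every prompt position is scored.
every_prompt_position = logits_bytes(NUM_PROMPTS * PROMPT_LEN, VOCAB)

print(f"last position only:    {last_position_only / GIB:.4f} GiB")
print(f"every prompt position: {every_prompt_position / GIB:.2f} GiB")
```

Under these assumptions the tensor grows from about 0.6 MB to about 1.19 GiB, i.e. by a factor equal to the prompt length. That is consistent with a single prompt fitting while a batch of five does not.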

zifeitong commented 3 months ago

There is a pending PR trying to address this problem: https://github.com/vllm-project/vllm/pull/5355.

Meanwhile, you can try the chunked prefill feature, which worked for me as a workaround: https://docs.vllm.ai/en/latest/models/performance.html#chunked-prefill.
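The chunked prefill idea can be pictured with a simplified, greedy packing model. This is an illustration only, not vLLM's scheduler, which also interleaves decode tokens and applies other limits. Each engine step processes at most `budget` prompt tokens, so long prompts are split across steps. The function name `chunk_prefill` is hypothetical.

```python
from typing import List, Tuple


def chunk_prefill(prompt_lens: List[int], budget: int) -> List[List[Tuple[int, int]]]:
    """Greedily pack prompt tokens into steps of at most `budget` tokens.

    Returns one list per step of (prompt_index, num_tokens) chunks; a prompt
    longer than the remaining room in a step continues in the next step.
    """
    if budget <= 0:
        raise ValueError("budget must be positive")
    steps: List[List[Tuple[int, int]]] = []
    current: List[Tuple[int, int]] = []
    room = budget
    for idx, length in enumerate(prompt_lens):
        remaining = length
        while remaining > 0:
            take = min(remaining, room)
            current.append((idx, take))
            remaining -= take
            room -= take
            if room == 0:  # step is full: start a new one
                steps.append(current)
                current, room = [], budget
    if current:
        steps.append(current)
    return steps


print(chunk_prefill([700, 300], 512))
```

In this model, with a 512-token budget the per-step float32 logits for a hypothetical 32,000-token vocabulary stay at or below 512 × 32,000 × 4 bytes = 62.5 MiB, no matter how many prompts are queued.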

janphilippfranken commented 3 months ago

Would you mind sharing your code? Let's say I have n_prompts=10 and set prompt_logprobs=0. Ideally, I'd get the logprobs for all 10 prompts from a single call to model.generate(prompts=prompts, sampling_params=sampling_params).

zifeitong commented 3 months ago

Something like this: model = LLM(..., enable_chunked_prefill=True, max_num_batched_tokens=512, gpu_memory_utilization=0.9)

Try smaller values of gpu_memory_utilization and/or max_num_batched_tokens if you still see OOM.
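Following zifeitong's suggestion, here is a sketch of scoring a whole batch of prompts with a single generate() call. It assumes, without confirmation in this thread, that each element of RequestOutput.prompt_logprobs is None for the first prompt position and otherwise a dict keyed by token id that includes the actual prompt token, with values exposing a .logprob attribute. Whether chunked prefill and prompt_logprobs combine cleanly may depend on the vLLM version. The helper names `total_prompt_logprob` and `score_prompts` are hypothetical.

```python
from typing import Dict, List, Optional, Sequence


def total_prompt_logprob(
    prompt_token_ids: Sequence[int],
    prompt_logprobs: List[Optional[Dict[int, object]]],
) -> float:
    """Sum log p(token_i | tokens_<i) over a prompt.

    The first position has no context (None), so it is skipped. Values may be
    objects with a .logprob attribute (as vLLM returns) or plain floats.
    """
    total = 0.0
    for token_id, candidates in zip(prompt_token_ids, prompt_logprobs):
        if candidates is None:
            continue
        entry = candidates[token_id]
        total += getattr(entry, "logprob", entry)
    return total


def score_prompts(model_name: str, prompts: List[str]) -> List[float]:
    """Score every prompt with one generate() call (needs a GPU and vllm)."""
    # Imported lazily so the pure helper above can be used without vllm installed.
    from vllm import LLM, SamplingParams

    llm = LLM(
        model=model_name,
        enable_chunked_prefill=True,
        max_num_batched_tokens=512,  # lower this or gpu_memory_utilization if OOM persists
        gpu_memory_utilization=0.9,
    )
    # max_tokens=1: only the prompt scores are needed, not a real generation.
    params = SamplingParams(temperature=0, max_tokens=1, prompt_logprobs=1)
    outputs = llm.generate(prompts, params)  # one call for the whole batch
    return [total_prompt_logprob(o.prompt_token_ids, o.prompt_logprobs) for o in outputs]
```

On a GPU machine, `score_prompts("<model name>", prompts)` with the 10 prompts from the question would return 10 floats in prompt order.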