vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Bug]: profile_run Inaccurate estimation leads to gpu OutOfMemoryError #7256

Closed. izhuhaoran closed this issue 1 month ago.

izhuhaoran commented 3 months ago

Your current environment

PyTorch version: 2.3.0a0+ebedce2
Is debug build: False
CUDA used to build PyTorch: 12.3
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.30.0
Libc version: glibc-2.35

Python version: 3.10.12 (main, Mar 22 2024, 16:50:05) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-4.19.91-014-kangaroo.2.10.13.5c249cdaf.x86_64-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.3.107
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA A100-SXM4-80GB
GPU 1: NVIDIA A100-SXM4-80GB
GPU 2: NVIDIA A100-SXM4-80GB
GPU 3: NVIDIA A100-SXM4-80GB

Nvidia driver version: 535.54.03
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.0.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Address sizes:                   46 bits physical, 57 bits virtual
Byte Order:                      Little Endian
CPU(s):                          48
On-line CPU(s) list:             0-47
Vendor ID:                       GenuineIntel
Model name:                      Intel(R) Xeon(R) Processor @ 2.90GHz
CPU family:                      6
Model:                           106
Thread(s) per core:              1
Core(s) per socket:              48
Socket(s):                       1
Stepping:                        6
BogoMIPS:                        5800.00
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq vmx ssse3 fma cx16 pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves wbnoinvd avx512vbmi umip pku avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid fsrm md_clear arch_capabilities
Virtualization:                  VT-x
Hypervisor vendor:               KVM
Virtualization type:             full
L1d cache:                       2.3 MiB (48 instances)
L1i cache:                       1.5 MiB (48 instances)
L2 cache:                        60 MiB (48 instances)
L3 cache:                        48 MiB (1 instance)
NUMA node(s):                    1
NUMA node0 CPU(s):               0-47
Vulnerability Itlb multihit:     Not affected
Vulnerability L1tf:              Not affected
Vulnerability Mds:               Not affected
Vulnerability Meltdown:          Not affected
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1:        Vulnerable: __user pointer sanitization and usercopy barriers only; no swapgs barriers
Vulnerability Spectre v2:        Vulnerable, IBPB: disabled, STIBP: disabled
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Not affected

Versions of relevant libraries:
[pip3] mypy==1.9.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.24.4
[pip3] onnx==1.15.0rc2
[pip3] optree==0.10.0
[pip3] pytorch-quantization==2.1.2
[pip3] torch==2.3.0a0+ebedce2
[pip3] torch-tensorrt==2.3.0a0
[pip3] torchdata==0.7.1a0
[pip3] torchtext==0.17.0a0
[pip3] torchvision==0.18.0a0
[pip3] transformers==4.42.4
[pip3] transformers-stream-generator==0.0.5
[pip3] triton==3.0.0
[conda] Could not collect
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.5.3.post1
vLLM Build Flags:
CUDA Archs: 5.2 6.0 6.1 7.0 7.2 7.5 8.0 8.6 8.7 9.0+PTX; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0    GPU1    GPU2    GPU3    NIC0    NIC1    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      NV6     NV6     NV6     PHB     PHB     0-47            N/A             N/A
GPU1    NV6      X      NV6     NV6     PHB     PHB     0-47            N/A             N/A
GPU2    NV6     NV6      X      NV6     PHB     PHB     0-47            N/A             N/A
GPU3    NV6     NV6     NV6      X      PHB     PHB     0-47            N/A             N/A
NIC0    PHB     PHB     PHB     PHB      X      PHB
NIC1    PHB     PHB     PHB     PHB     PHB      X 

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

NIC Legend:

  NIC0: mlx5_0
  NIC1: mlx5_1

🐛 Describe the bug

I am running a test Python script, test_llm.py, whose code is as follows:

```python
import torch
from vllm import LLM, SamplingParams
import random
import argparse
import time

random.seed(0)  # Set the random seed for reproducibility

_MB = 1 << 20

dummy_prompt = "hello " * 2000
prompts = [dummy_prompt for _ in range(512)]


def test_llm(model: str, n, max_tokens, tp_size):
    prompts_choose = prompts[:n]
    # print(prompts_choose)

    # Create a sampling params object.
    sampling_params = SamplingParams(temperature=0.0,
                                     top_p=1.0,
                                     max_tokens=max_tokens,
                                     ignore_eos=True)

    # Create an LLM.
    llm = LLM(model=model,
              trust_remote_code=True,
              enforce_eager=True,
              disable_log_stats=False,
              max_num_seqs=n,
              tensor_parallel_size=tp_size,
              disable_custom_all_reduce=True,
              gpu_memory_utilization=1.0)

    # Generate texts from the prompts. The output is a list of RequestOutput objects
    # that contain the prompt, generated text, and other information.
    torch.cuda.synchronize()
    time1 = time.perf_counter()
    outputs = llm.generate(prompts_choose, sampling_params)
    torch.cuda.synchronize()
    time2 = time.perf_counter()

    free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
    print(
        f"use_gpu_memory: {(total_gpu_memory - free_gpu_memory)/_MB:.4f} MB, "
        f"free_gpu_memory: {free_gpu_memory/_MB:.4f} MB, "
        f"total_gpu_memory: {total_gpu_memory/_MB:.4f} MB"
    )
    print(f"\nllm.generate over. All Generate Time: {time2 - time1:.5f} s\n")

    # # Print the outputs.
    # for output in outputs:
    #     prompt = output.prompt
    #     generated_text = output.outputs[0].text
    #     # print(f"Prompt: {prompt!r},\n")
    #     print(f"Generated text: {generated_text!r}\n")


def test():
    parser = argparse.ArgumentParser(description='Test LLM')
    parser.add_argument('-n', type=int, default=256, help='Number of prompts')
    parser.add_argument('-max_tokens', type=int, default=128, help='Maximum number of tokens')
    parser.add_argument('-tp_size', type=int, default=1, help='Tensor Parallel Size')
    parser.add_argument('-model', type=str, help='Model path')
    args = parser.parse_args()

    n = args.n
    max_tokens = args.max_tokens
    tp_size = args.tp_size
    model = args.model
    test_llm(model, n, max_tokens, tp_size)


test()
```

The run command is as follows:

python test_llm.py -n 256 -max_tokens 128 -tp_size 1 -model {YOUR_PATH}

When I use the model qwen-7b-chat with gpu_memory_utilization=1.0, it crashes inexplicably with the error torch.cuda.OutOfMemoryError: CUDA out of memory.

The error output is:

INFO 08-07 07:11:24 llm_engine.py:178] Initializing an LLM engine (v0.5.3.post1) with config: model='./models/Qwen-7B-Chat', speculative_config=None, tokenizer='./models/Qwen-7B-Chat', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=32768, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=True, quantization=None, enforce_eager=True, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None), seed=0, served_model_name=./models/Qwen-7B-Chat, use_v2_block_manager=False, enable_prefix_caching=False)
WARNING 08-07 07:11:25 tokenizer.py:129] Using a slow tokenizer. This might cause a significant slowdown. Consider using a fast tokenizer instead.
INFO 08-07 07:11:25 model_runner.py:722] Starting to load model ./models/Qwen-7B-Chat...
INFO 08-07 07:11:28 model_runner.py:736] Loading model weights took 14737.2578 MB
INFO 08-07 07:11:31 gpu_executor.py:103] # GPU blocks: 7877, # CPU blocks: 512

Processed prompts:   0%|                                                                                          | 0/256 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s][rank0]: Traceback (most recent call last):
[rank0]:   File "/mnt/data/zhuhr/vllm_upstream/test.log/test_llm.py", line 92, in <module>
[rank0]:     test()
[rank0]:   File "/mnt/data/zhuhr/vllm_upstream/test.log/test_llm.py", line 87, in test
[rank0]:     test_llm(model, n, max_tokens, tp_size)
[rank0]:   File "/mnt/data/zhuhr/vllm_upstream/test.log/test_llm.py", line 51, in test_llm
[rank0]:     outputs = llm.generate(prompts_choose, sampling_params)
[rank0]:   File "/mnt/data/zhuhr/vllm_upstream/vllm/utils.py", line 828, in inner
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/mnt/data/zhuhr/vllm_upstream/vllm/entrypoints/llm.py", line 317, in generate
[rank0]:     outputs = self._run_engine(use_tqdm=use_tqdm)
[rank0]:   File "/mnt/data/zhuhr/vllm_upstream/vllm/entrypoints/llm.py", line 570, in _run_engine
[rank0]:     step_outputs = self.llm_engine.step()
[rank0]:   File "/mnt/data/zhuhr/vllm_upstream/vllm/engine/llm_engine.py", line 925, in step
[rank0]:     output = self.model_executor.execute_model(
[rank0]:   File "/mnt/data/zhuhr/vllm_upstream/vllm/executor/gpu_executor.py", line 111, in execute_model
[rank0]:     output = self.driver_worker.execute_model(execute_model_req)
[rank0]:   File "/mnt/data/zhuhr/vllm_upstream/vllm/worker/worker_base.py", line 272, in execute_model
[rank0]:     output = self.model_runner.execute_model(
[rank0]:   File "/mnt/data/tao/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/mnt/data/zhuhr/vllm_upstream/vllm/worker/model_runner.py", line 1463, in execute_model
[rank0]:     hidden_or_intermediate_states = model_executable(
[rank0]:   File "/mnt/data/tao/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/mnt/data/tao/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/mnt/data/zhuhr/vllm_upstream/vllm/model_executor/models/qwen.py", line 349, in forward
[rank0]:     hidden_states = self.transformer(input_ids, positions, kv_caches,
[rank0]:   File "/mnt/data/tao/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/mnt/data/tao/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/mnt/data/zhuhr/vllm_upstream/vllm/model_executor/models/qwen.py", line 303, in forward
[rank0]:     hidden_states, residual = layer(
[rank0]:   File "/mnt/data/tao/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/mnt/data/tao/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/mnt/data/zhuhr/vllm_upstream/vllm/model_executor/models/qwen.py", line 236, in forward
[rank0]:     hidden_states = self.mlp(
[rank0]:   File "/mnt/data/tao/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/mnt/data/tao/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/mnt/data/zhuhr/vllm_upstream/vllm/model_executor/models/qwen.py", line 70, in forward
[rank0]:     gate_up, _ = self.gate_up_proj(x)
[rank0]:   File "/mnt/data/tao/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/mnt/data/tao/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/mnt/data/zhuhr/vllm_upstream/vllm/model_executor/layers/linear.py", line 349, in forward
[rank0]:     output_parallel = self.quant_method.apply(self, input_, bias)
[rank0]:   File "/mnt/data/zhuhr/vllm_upstream/vllm/model_executor/layers/linear.py", line 125, in apply
[rank0]:     return F.linear(x, layer.weight, bias)
[rank0]: torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 1.31 GiB. GPU
Processed prompts:   0%|                                                                                          | 0/256 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

When I then looked into the code for the profile_run function and added some logging to the model forward pass, I found two things that might be questionable:

1. Overestimation of num_gpu_blocks in the determine_num_available_blocks function:

  peak_memory = self.init_gpu_memory - free_gpu_memory

  num_gpu_blocks = int(
      (total_gpu_memory * self.cache_config.gpu_memory_utilization -
       peak_memory) // cache_block_size)

In most cases there is a gap between init_gpu_memory and total_gpu_memory, so using total_gpu_memory to calculate num_gpu_blocks incorrectly inflates the space available for the KV cache, which causes an OOM error when gpu_memory_utilization=1.0.

So, I tried to modify the above code to:

  num_gpu_blocks = int(
      (self.init_gpu_memory * self.cache_config.gpu_memory_utilization -
       peak_memory) // cache_block_size)

After that, I reran the test and the output is:

...
INFO 08-07 07:36:04 gpu_executor.py:103] # GPU blocks: 7823, # CPU blocks: 512
...
[rank0]:   File "/mnt/data/zhuhr/vllm_upstream/vllm/model_executor/layers/linear.py", line 125, in apply
[rank0]:     return F.linear(x, layer.weight, bias)
[rank0]: torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 1.32 GiB. GPU
Processed prompts:  23%|█████████████████▋                                                            | 58/256 [00:20<01:10,  2.80it/s, est. speed input: 5700.92 toks/s, output: 364.68 toks/s]

OOM still occurs, but the number of GPU blocks drops from 7877 to 7823 compared to before the modification, and the run now progresses from 0% to 23% before failing (the sketch below illustrates where the 54-block difference can come from).
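As a rough illustration, here is a minimal sketch (my own, not vLLM code) of the two formulas side by side. All concrete numbers are assumptions chosen to be roughly consistent with the logs above; in particular, cache_block_size = 8 MB and peak_memory = 18030 MB are illustrative values for this Qwen-7B configuration, and init_gpu_memory is assumed to sit 432 MB below total_gpu_memory (CUDA context and anything else that was already resident at worker startup).

```python
# Minimal sketch, not vLLM code. All numbers are illustrative assumptions,
# chosen to be roughly consistent with the logs above.
_MB = 1 << 20

total_gpu_memory = 81050.625 * _MB   # from torch.cuda.mem_get_info()
init_gpu_memory = 80618.625 * _MB    # assumed: free memory at worker init,
                                     # i.e. 432 MB already held by the CUDA
                                     # context / other processes
peak_memory = 18030 * _MB            # assumed: init_gpu_memory minus free
                                     # memory right after profile_run
cache_block_size = 8 * _MB           # assumed for Qwen-7B: 2 (K/V) * 16 tokens
                                     # * 32 heads * 128 head_dim * 32 layers * 2 B
gpu_memory_utilization = 1.0

# Current formula: budget based on total_gpu_memory.
blocks_from_total = int(
    (total_gpu_memory * gpu_memory_utilization - peak_memory) // cache_block_size)

# Modified formula: budget based on init_gpu_memory.
blocks_from_init = int(
    (init_gpu_memory * gpu_memory_utilization - peak_memory) // cache_block_size)

print(blocks_from_total, blocks_from_init)  # 7877 7823 under these assumptions
```

Under these assumptions, the first formula hands the extra ~432 MB to the KV cache even though that memory was never free for this process, which matches the 7877 vs. 7823 block counts above.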

2. Strange GPU memory usage increase in _allocate_kv_cache

I added GPU memory usage prints before and after the gpu_cache and cpu_cache allocations:

# vllm/worker/cache_engine.py CacheEngine init
        print_memory_usage("before allocate gpu cache")
        # Initialize the cache.
        self.gpu_cache = self._allocate_kv_cache(
            self.num_gpu_blocks, self.device_config.device_type)
        print_memory_usage("after allocate gpu cache")

        print_memory_usage("before allocate cpu cache")
        self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks, "cpu")
        print_memory_usage("after allocate cpu cache")
        torch.cuda.empty_cache()

# print_memory_usage func
def print_memory_usage(info: str, sync: bool = True, empty_cache: bool = False, collect: bool = False):
    _MB = 1 << 20
    if sync:
        torch.cuda.synchronize()
    if empty_cache:
        torch.cuda.empty_cache()
    if collect:
        gc.collect()
    free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
    print(
        f"{info}: "
        f"use_gpu_memory: {(total_gpu_memory - free_gpu_memory)/_MB:.4f} MB, "
        f"free_gpu_memory: {free_gpu_memory/_MB:.4f} MB, "
        f"total_gpu_memory: {total_gpu_memory/_MB:.4f} MB"
    )

The GPU usage printout is:

INFO 08-07 07:36:04 gpu_executor.py:103] # GPU blocks: 7823, # CPU blocks: 512
before allocate gpu cache: use_gpu_memory: 15271.3750 MB, free_gpu_memory: 65779.2500 MB, total_gpu_memory: 81050.6250 MB
gpu kv_cache_shape: (2, 7823, 16, 32, 128)
after allocate gpu cache: use_gpu_memory: 77863.3750 MB, free_gpu_memory: 3187.2500 MB, total_gpu_memory: 81050.6250 MB
before allocate cpu cache: use_gpu_memory: 77863.3750 MB, free_gpu_memory: 3187.2500 MB, total_gpu_memory: 81050.6250 MB
cpu kv_cache_shape: (2, 512, 16, 32, 128)
after allocate cpu cache: use_gpu_memory: 77871.3750 MB, free_gpu_memory: 3179.2500 MB, total_gpu_memory: 81050.6250 MB

The GPU cache shape is (2, 7823, 16, 32, 128) per layer, with 32 layers and a bfloat16 element size of 2 bytes, so its GPU memory should be 32 × 2 × 7823 × 16 × 32 × 128 × 2 bytes = 62584 MB. But the printout above shows the used memory before and after the gpu_cache allocation is 15271.3750 MB and 77863.3750 MB, a difference of 62592 MB > 62584 MB. There is also a memory change before and after the cpu_cache allocation (77863.3750 MB to 77871.3750 MB), but that cache is on the CPU, so the GPU memory should theoretically be unchanged.

These strange GPU memory increases further reduce the space available for forward-pass activations, resulting in OOM.
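As a rough illustration of where such gaps can come from: torch.cuda.mem_get_info() measures memory at the driver level, which includes whatever the PyTorch caching allocator has reserved (rounded-up segments) plus CUDA context overhead, not just the bytes backing live tensors. The sketch below is my own illustration, not vLLM code; the per-layer shape is the one from the report above.

```python
import torch

def snapshot(tag: str) -> None:
    torch.cuda.synchronize()
    free, total = torch.cuda.mem_get_info()    # driver-level view
    allocated = torch.cuda.memory_allocated()  # bytes in live tensors
    reserved = torch.cuda.memory_reserved()    # bytes held by the caching allocator
    mb = 1 << 20
    print(f"{tag}: driver-used={(total - free) / mb:.1f} MB, "
          f"allocated={allocated / mb:.1f} MB, reserved={reserved / mb:.1f} MB")

snapshot("before")
# One KV-cache-layer-sized allocation (same per-layer shape as in the report).
x = torch.zeros((2, 7823, 16, 32, 128), dtype=torch.bfloat16, device="cuda")
snapshot("after")
# Typically: `allocated` is close to the tensor's theoretical size, while
# `reserved` and the driver-level number can be somewhat larger because of
# allocator rounding; the driver-level number also includes the CUDA context,
# which is present even with zero tensors allocated.
```

Comparing the three counters side by side makes it easier to tell allocator rounding and context overhead apart from genuinely missing memory.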

I also separately micro-tested the GPU memory usage of _allocate_kv_cache using the following test_alloc_mem.py script:

```python
import torch
from typing import List, Tuple
import torch.nn.functional as F
import gc


def print_memory_usage(info: str, sync: bool = True, empty_cache: bool = False, collect: bool = False):
    get_info = True
    sync = True
    empty_cache = True
    collect = False
    _MB = 1 << 20
    if sync:
        torch.cuda.synchronize()
    if empty_cache:
        torch.cuda.empty_cache()
    if collect:
        gc.collect()
    free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
    print(
        f"{info}: "
        f"use_gpu_memory: {(total_gpu_memory - free_gpu_memory)/_MB:.4f} MB, "
        f"free_gpu_memory: {free_gpu_memory/_MB:.4f} MB, "
        f"total_gpu_memory: {total_gpu_memory/_MB:.4f} MB"
    )
    return (total_gpu_memory - free_gpu_memory) / _MB


def test_alloc(device: str = "cuda"):
    kv_cache: List[torch.Tensor] = []
    # key/value, num_blocks, block_size, num_heads, head_dim
    kv_cache_shape = (2, 7823, 16, 32, 128)
    # kv_cache_shape = (2, 7819, 16, 32, 128)
    dtype = torch.bfloat16
    pin_memory = True if device == "cpu" else False
    # pin_memory = False
    for _ in range(32):
        # null block in CpuGpuBlockAllocator requires at least that
        # block to be zeroed-out.
        # We zero-out everything for simplicity.
        kv_cache.append(
            torch.zeros(kv_cache_shape,
                        dtype=dtype,
                        pin_memory=pin_memory,
                        device=device))
    return kv_cache


def test():
    print_memory_usage("before allocate cuda a")
    a = test_alloc("cuda")
    print_memory_usage("after allocate cuda a")

    print_memory_usage("before allocate cpu b")
    b = test_alloc("cpu")
    print_memory_usage("after allocate cpu b")


test()
```

The result is the same as described above: there is likewise a strange increase in GPU memory.

My thoughts: for the gpu_cache, the occupied size is larger than the theoretical value; my guess is that some alignment strategy in PyTorch's memory management causes this. I don't understand why the cpu_cache allocation increases GPU memory, but when I force pin_memory=False the GPU memory no longer grows. Moreover, when I reduce num_blocks from 7823 to 7819, the GPU memory usage stays the same as with 7823, which further suggests that some memory-alignment behavior in PyTorch squeezes the space left for activations, making OOM likely when gpu_memory_utilization is large.

We can reduce gpu_memory_utilization to avoid OOM, but that makes it hard to maximize GPU memory use. We may therefore need to modify determine_num_available_blocks or profile_run to take these factors into account, so that gpu_memory_utilization=1.0 can be set safely to fully utilize GPU resources.
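For completeness, the headroom workaround I mean looks like the following sketch; 0.95 is only an example value, not a recommendation:

```python
from vllm import LLM

# Leaving some headroom avoids the OOM at the cost of a slightly smaller KV
# cache; the 0.95 here is just an example value.
llm = LLM(model="./models/Qwen-7B-Chat",
          trust_remote_code=True,
          enforce_eager=True,
          gpu_memory_utilization=0.95)
```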

youkaichao commented 3 months ago

You may want to read some docs about the PyTorch CUDA caching allocator: https://zdevito.github.io/2022/08/04/cuda-caching-allocator.html

TL;DR: never use gpu_memory_utilization=1.0. There are lots of factors that can take unexpected memory. You cannot control every MB of GPU memory you have.

izhuhaoran commented 3 months ago

> You may want to read some docs about the PyTorch CUDA caching allocator: https://zdevito.github.io/2022/08/04/cuda-caching-allocator.html
>
> TL;DR: never use gpu_memory_utilization=1.0. There are lots of factors that can take unexpected memory. You cannot control every MB of GPU memory you have.

Thank you for your advice, I will look into the relevant content you have provided.