movchan74 opened this issue 14 hours ago
I encountered the same problem and cannot use v0.6.4 or v0.6.4.post1 on nodes that host multiple models.
I could set `gpu_memory_utilization` so that my models have to start in a specific order (for example, 0.3 for the first, 0.7 for the second, and 1.0 for the third if I have three models), but that would break on restart if one of them is stopped or crashes and comes back up in a different order...
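The ordering problem can be illustrated with a toy model. This is a simplified assumption, not vLLM's exact accounting: under the 0.6.4 global interpretation, a model that starts later may grow until total device usage reaches its own `gpu_memory_utilization` fraction, so its share is that fraction minus whatever the already-running models hold. The function `start_models` and the model names are hypothetical.

```python
def start_models(order, utilization, running=None):
    """Start models one by one under a global (device-wide) memory limit.

    Simplified model (assumption): each model claims everything between the
    memory already in use and its own gpu_memory_utilization fraction.
    Returns a dict mapping model name -> fraction of the GPU it ends up using.
    """
    running = dict(running or {})
    for name in order:
        in_use = sum(running.values())
        share = round(utilization[name] - in_use, 4)  # round away float noise
        if share <= 0:
            raise RuntimeError(
                f"{name} cannot start: {in_use:.2f} of the GPU already in use"
            )
        running[name] = share
    return running


# Hypothetical cumulative settings for three models "a", "b", "c".
utilization = {"a": 0.3, "b": 0.7, "c": 1.0}

# Intended start order: the shares come out as 0.3 / 0.4 / 0.3.
print(start_models(["a", "b", "c"], utilization))

# "b" crashes and restarts while "a" and "c" are still running:
# it now only gets 0.7 - 0.6 = 0.1 of the GPU instead of 0.4.
print(start_models(["b"], utilization, running={"a": 0.3, "c": 0.3}))
```

Restarting "a" while "b" and "c" are running raises, because the 0.7 already in use exceeds its 0.3 limit.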
It's an intentional change. Please see https://github.com/vllm-project/vllm/pull/9352#discussion_r1801771548 in particular.
cc @joerunde
It looks quite strange to me that `allocated_bytes.all.peak` includes memory allocated by another process. Is it a bug in PyTorch? Naturally, I would assume this is the peak memory of the current process.
The memory allocation from other processes comes from:

`total_allocated_bytes = torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]`

not from `allocated_bytes.all.peak`. As I understand it, `total_allocated_bytes` covers all memory allocated on the GPU, regardless of which process allocated it.
@joerunde is OOO for a couple weeks.
I worked with him on some of the changes in https://github.com/vllm-project/vllm/pull/9352, and I don't think the fix needed to change the interpretation of `gpu_memory_utilization`. One of the fixes in the PR records how much memory is allocated outside of PyTorch during profiling (e.g. by NCCL) so that the remaining free memory is accounted for more accurately; that is why `non_torch_allocations` is computed. The `total_allocated_bytes` used to determine `non_torch_allocations` could be computed against the baseline GPU memory from before the model is loaded, which would restore the previous meaning of `gpu_memory_utilization`, e.g.

`total_allocated_bytes = self.init_gpu_memory - torch.cuda.mem_get_info()[0]`
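To make the difference between the two baselines concrete, here is a small arithmetic sketch with made-up numbers (hypothetical: a 48 GiB GPU, another process holding 20 GiB, and this process holding 10 GiB through PyTorch plus 1 GiB outside PyTorch, e.g. NCCL). The function names are illustrative, not vLLM's; the inputs stand in for `torch.cuda.mem_get_info()` (free, total) and `torch.cuda.memory_stats()["allocated_bytes.all.current"]`.

```python
GiB = 1024 ** 3


def non_torch_device_wide(free, total, torch_current):
    # total - free counts every process on the device, so memory held by
    # other processes is attributed to this one as "non-torch" memory.
    return (total - free) - torch_current


def non_torch_from_baseline(init_free, free, torch_current):
    # Measuring against the free memory recorded before the model was loaded
    # only counts what changed on the device since then.
    return (init_free - free) - torch_current


total = 48 * GiB
other_process = 20 * GiB
torch_current = 10 * GiB
nccl = 1 * GiB

init_free = total - other_process            # free memory before loading
free = init_free - torch_current - nccl      # free memory after profiling

print(non_torch_device_wide(free, total, torch_current) / GiB)        # 21.0: other process + NCCL
print(non_torch_from_baseline(init_free, free, torch_current) / GiB)  # 1.0: NCCL only
```

The baseline variant still assumes that other processes do not change their usage between the two measurements, which is the same assumption the older code made.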
> The `total_allocated_bytes` used to determine `non_torch_allocations` could be computed against the baseline GPU memory from before the model is loaded to restore the previous meaning of `gpu_memory_utilization`

@tjohnson31415 I think restoring the previous meaning of `gpu_memory_utilization` makes sense.
Your current environment
The output of `python collect_env.py`
```text
Collecting environment information...
PyTorch version: 2.5.1+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.35

Python version: 3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-69-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.3.107
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA RTX A6000
Nvidia driver version: 525.105.17
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.0.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 64
On-line CPU(s) list: 0-63
Vendor ID: AuthenticAMD
Model name: AMD Ryzen Threadripper PRO 5975WX 32-Cores
CPU family: 25
Model: 8
Thread(s) per core: 2
Core(s) per socket: 32
Socket(s): 1
Stepping: 2
Frequency boost: enabled
CPU max MHz: 7006.6401
CPU min MHz: 1800.0000
BogoMIPS: 7186.96
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca fsrm
Virtualization: AMD-V
L1d cache: 1 MiB (32 instances)
L1i cache: 1 MiB (32 instances)
L2 cache: 16 MiB (32 instances)
L3 cache: 128 MiB (4 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-63
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] mypy==1.13.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-ml-py==12.560.30
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] onnx==1.17.0
[pip3] onnxruntime==1.20.0
[pip3] pytorch-lightning==2.4.0
[pip3] pytorch-metric-learning==2.7.0
[pip3] pyzmq==26.2.0
[pip3] sentence-transformers==3.3.1
[pip3] torch==2.5.1
[pip3] torch-audiomentations==0.11.1
[pip3] torch_pitch_shift==1.2.5
[pip3] torchaudio==2.5.1
[pip3] torchmetrics==1.6.0
[pip3] torchvision==0.20.1
[pip3] transformers==4.46.3
[pip3] triton==3.1.0
[conda] Could not collect
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.6.4.post1
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0    CPU Affinity    NUMA Affinity
GPU0     X      0-63            N/A

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

LD_LIBRARY_PATH=/root/.cache/pypoetry/virtualenvs/aana-episodic-retrieval-ZhswV3un-py3.10/lib/python3.10/site-packages/cv2/../../lib64:
CUDA_MODULE_LOADING=LAZY
```

Model Input Dumps
No response
🐛 Describe the bug
We encountered significant changes in the behavior of the `gpu_memory_utilization` parameter between vLLM 0.6.3 and vLLM 0.6.4. Because of how memory usage is now calculated and allocated, the change makes it difficult to deploy multiple models on a shared GPU.

In vLLM 0.6.3, `gpu_memory_utilization` acted as a limit on the memory used by the model itself, which made it easy to allocate resources for multiple models running on the same GPU. In vLLM 0.6.4, however, the parameter was changed to act as a global GPU memory utilization limit. This fundamentally affects deployments with multiple models, because the limit now accounts for all GPU memory usage, including memory allocated by other processes.

The way `peak_memory` is calculated differs significantly between the two versions, which leads to discrepancies in GPU memory management.

Peak Memory Calculations
vLLM 0.6.3
In vLLM 0.6.3, `peak_memory` is calculated as `peak_memory = init_gpu_memory - free_gpu_memory`, where:

- `init_gpu_memory`: free GPU memory before vLLM starts.
- `free_gpu_memory`: free GPU memory after the model is loaded.

Code Reference: worker.py#L231
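The 0.6.3 accounting can be sketched as plain arithmetic. This is a simplified, hypothetical reconstruction, not vLLM's code: it assumes the KV cache budget is `total * gpu_memory_utilization - peak_memory`, and the GPU size, the 2 MiB block size, and the clamping to zero are illustrative assumptions. The second call shows the concurrent-start case, where another process allocating memory during profiling inflates `peak_memory`.

```python
GiB = 1024 ** 3


def num_gpu_blocks_v063(total, utilization, init_free, free_after, block_bytes):
    """Simplified 0.6.3-style KV cache sizing (illustrative only).

    peak_memory is whatever disappeared from the device's free memory between
    startup and the end of profiling; all of it is attributed to this model.
    """
    peak_memory = init_free - free_after
    budget = total * utilization - peak_memory
    return max(int(budget // block_bytes), 0)  # clamp: no negative block counts


total = 48 * GiB
block = 2 * 1024 ** 2  # hypothetical 2 MiB per KV cache block

# Model alone on the GPU, using 12 GiB for weights/activations, limit 0.5.
print(num_gpu_blocks_v063(total, 0.5, init_free=48 * GiB, free_after=36 * GiB, block_bytes=block))  # 6144

# Same model, but a second model grabbed 20 GiB while this one was profiling.
print(num_gpu_blocks_v063(total, 0.5, init_free=48 * GiB, free_after=16 * GiB, block_bytes=block))  # 0
```

The second call returns 0 even though this model's own footprint did not change, which matches the `# GPU blocks: 0` symptom.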
Limitation: If other processes change their GPU memory usage during startup (e.g., when multiple models are started concurrently), the calculated `peak_memory` becomes inaccurate, leading to errors such as `# GPU blocks: 0`.

vLLM 0.6.4
In vLLM 0.6.4, the `peak_memory` calculation was modified to include all GPU memory allocations:

Code Reference: worker.py#L201

This effectively turns `gpu_memory_utilization` into a global GPU memory limit that includes memory used by external processes. It looks like the change was introduced in this PR.
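A simplified sketch of the 0.6.4 accounting, built from the variables quoted in the comments on this issue (`allocated_bytes.all.peak`, `total_allocated_bytes`, `non_torch_allocations`). The function name, the numbers, and the KV cache budget formula (`total * gpu_memory_utilization - peak_memory`) are assumptions for illustration. It shows the same model getting a much smaller budget when another process already occupies part of the GPU.

```python
GiB = 1024 ** 3


def kv_budget_v064(total, utilization, free, torch_peak, torch_current):
    """Simplified 0.6.4-style KV cache budget in bytes (illustrative only)."""
    peak_memory = torch_peak
    # Device-wide usage minus this process's PyTorch allocations: this also
    # picks up memory held by other processes, not just NCCL buffers.
    total_allocated_bytes = total - free
    non_torch_allocations = total_allocated_bytes - torch_current
    if non_torch_allocations > 0:
        peak_memory += non_torch_allocations
    return total * utilization - peak_memory


total = 48 * GiB
torch_peak = 12 * GiB     # this model's PyTorch peak during profiling
torch_current = 10 * GiB  # PyTorch allocations still held after profiling

# Alone on the GPU: only this model's 10 GiB is in use.
alone = kv_budget_v064(total, 0.5, free=38 * GiB,
                       torch_peak=torch_peak, torch_current=torch_current)

# Another process already holds 20 GiB of the same GPU.
shared = kv_budget_v064(total, 0.5, free=18 * GiB,
                        torch_peak=torch_peak, torch_current=torch_current)

print(alone / GiB, shared / GiB)  # 12.0 -8.0
```

A negative budget leaves no room for the KV cache at all. To get the same 12 GiB budget in the shared case, `gpu_memory_utilization` would have to grow by the other process's share of the GPU (20/48).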
Use Case
We run multiple vLLM models on the same GPU, each with its own memory limit set through the `gpu_memory_utilization` parameter. This setup worked well in vLLM 0.6.3, letting us control memory allocation per model while using GPU resources efficiently.

With the behavior change in vLLM 0.6.4, this approach no longer works as intended. The new global memory limit forces us to reconsider our resource management strategy and makes it harder to ensure that models can coexist on the same GPU.
Expected Behavior (vLLM 0.6.3)

- `gpu_memory_utilization` limits memory usage for the model itself.

Actual Behavior (vLLM 0.6.4)

- `gpu_memory_utilization` now acts as a global limit on total GPU memory usage, including memory used by other processes.
- `gpu_memory_utilization` can no longer serve as a per-model memory limit.

Proposed Solutions
… `gpu_memory_utilization` between versions. Users should be explicitly informed of this breaking change. I saw that the documentation now mentions the global memory limit, but it would be helpful to highlight that the behavior has changed compared to previous versions.

Steps to Reproduce

… `gpu_memory_utilization` in vLLM 0.6.3. Observe that each model adheres to its memory limit.

Additional Context
We understand that the change in memory calculation aims to address potential edge cases with non-Torch memory allocations. However, this has created a significant breaking change for users relying on the old behavior. A minor version update that introduces such a fundamental change feels unexpected and warrants further discussion.