vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Bug]: Error when running on a single machine with multiple GPUs #9875

Open Wiselnn570 opened 3 weeks ago

Wiselnn570 commented 3 weeks ago

Your current environment

Name: vllm Version: 0.6.3.post2.dev171+g890ca360

Model Input Dumps

No response

🐛 Describe the bug

I used the LLM interface from this vLLM repository to load the model and ran the eval scripts from VLMEvalKit (https://github.com/open-compass/VLMEvalKit):

torchrun --nproc-per-node=8 run.py --data Video-MME --model Qwen2_VL-M-RoPE-80k

for evaluation, but I got the following error:

RuntimeError: world_size (8) is not equal to tensor_model_parallel_size (1) x pipeline_model_parallel_size (1). 

Could you please advise on how to resolve this? Here is the interface I'm using:

from vllm import LLM
llm = LLM("/mnt/hwfile/mllm/weixilin/cache/Qwen2-VL-7B-Instruct", 
            max_model_len=100000,
            limit_mm_per_prompt={"video": 10},
            )

https://github.com/vllm-project/vllm/blob/3ea2dc2ec49d1ddd7875045e2397ae76a8f50b38/vllm/distributed/parallel_state.py#L1025 It seems the error occurs at this assertion, so what change should I make to satisfy it? Thanks.


DarkLight1337 commented 3 weeks ago

vLLM has its own multiprocessing setup for TP/PP. You should avoid using torchrun with vLLM.
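For reference, a minimal sketch of the snippet above adapted to use all 8 GPUs through vLLM's own tensor parallelism instead of torchrun (the tensor_parallel_size value is just an example; adjust it to your GPU count):

from vllm import LLM

# Launch with plain `python`, not torchrun: vLLM spawns its own worker
# processes so that world_size = tensor_parallel_size * pipeline_parallel_size.
llm = LLM(
    "/mnt/hwfile/mllm/weixilin/cache/Qwen2-VL-7B-Instruct",
    tensor_parallel_size=8,          # shard the model across 8 GPUs
    max_model_len=100000,
    limit_mm_per_prompt={"video": 10},
)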

DarkLight1337 commented 3 weeks ago

cc @youkaichao

youkaichao commented 3 weeks ago

Yeah, we don't support torchrun, but it would be good to provide some scripts for running multiple vLLM instances behind a proxy server using litellm.

youkaichao commented 3 weeks ago

@Wiselnn570 contribution welcome!

youkaichao commented 3 weeks ago

see https://docs.litellm.ai/docs/simple_proxy#load-balancing---multiple-instances-of-1-model for how to use litellm
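As a rough illustration of that pattern, here is a sketch using litellm's Python Router, assuming two vLLM OpenAI-compatible servers are already running locally (the ports, alias, and API key are placeholders, not anything vLLM or litellm prescribes):

from litellm import Router

# Two vLLM instances serving the same model, load-balanced by litellm.
# Each entry points at a separately started vLLM OpenAI-compatible server,
# e.g. `vllm serve Qwen/Qwen2-VL-7B-Instruct --port 8001`.
router = Router(model_list=[
    {
        "model_name": "qwen2-vl",  # alias that callers use
        "litellm_params": {
            "model": "openai/Qwen/Qwen2-VL-7B-Instruct",
            "api_base": "http://localhost:8001/v1",
            "api_key": "EMPTY",
        },
    },
    {
        "model_name": "qwen2-vl",
        "litellm_params": {
            "model": "openai/Qwen/Qwen2-VL-7B-Instruct",
            "api_base": "http://localhost:8002/v1",
            "api_key": "EMPTY",
        },
    },
])

# Requests to the alias are spread across the two instances.
resp = router.completion(
    model="qwen2-vl",
    messages=[{"role": "user", "content": "Describe this benchmark run."}],
)
print(resp.choices[0].message.content)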

Wiselnn570 commented 3 weeks ago

@youkaichao @DarkLight1337 Sure, I'm glad to contribute to this community when I have time! One more question: I recently ran into an issue while modifying the positional encoding in the mrope_input_positions section of the Qwen2-VL code, and I've tried but haven't been able to resolve it. In short, I want to explore the model's performance when extrapolating to a 60k context on Qwen2-VL 7B, using video data for testing. I tried replacing this section (https://github.com/vllm-project/vllm/blob/3bb4befea7166850bdee3f72fe060c9c4044ba85/vllm/worker/model_runner.py#L672) with vanilla RoPE (that is, placing image, video, and text tokens all on the main diagonal of M-RoPE), which pushed the maximum value of mrope_input_positions up to approximately 59k, but it eventually led to this error:

../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [8320,0,0], thread: [64,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [8320,0,0], thread: [65,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [8320,0,0], thread: [66,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [8320,0,0], thread: [67,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [8320,0,0], thread: [68,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [8320,0,0], thread: [69,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
...
INFO 11-03 18:58:15 model_runner_base.py:120] Writing input of failed execution to /tmp/err_execute_model_input_20241103-185815.pkl...
WARNING 11-03 18:58:15 model_runner_base.py:143] Failed to pickle inputs of failed execution: CUDA error: device-side assert triggered
WARNING 11-03 18:58:15 model_runner_base.py:143] CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
WARNING 11-03 18:58:15 model_runner_base.py:143] For debugging consider passing CUDA_LAUNCH_BLOCKING=1
WARNING 11-03 18:58:15 model_runner_base.py:143] Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
WARNING 11-03 18:58:15 model_runner_base.py:143] 
RuntimeError: Error in model execution: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

I have already tested the original M-RoPE, which produces correct output with a 60k context, and there the maximum mrope_input_positions value is only around 300. So I'm wondering whether the position values are too large and end up exceeding the valid index range. How should I modify things to support vanilla RoPE (or some other 3D positional encoding whose position values are quite large) for evaluation? Thanks!
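One way to test that hypothesis before a full run is to compare the largest generated position against the model's configured position range. A minimal sketch follows; the helper, and the assumption that the rotary cos/sin cache is sized from max_position_embeddings, are mine and not vLLM code:

import torch
from transformers import AutoConfig

def check_positions_in_range(mrope_input_positions, model_path):
    # Hypothetical sanity check: vanilla-RoPE-style positions near 59k would
    # exceed max_position_embeddings (32768 for Qwen2-VL-7B-Instruct, per its
    # config), while the original M-RoPE positions (~300) stay far inside it.
    cfg = AutoConfig.from_pretrained(model_path)
    max_pos = cfg.max_position_embeddings
    biggest = int(torch.tensor(mrope_input_positions).max())
    if biggest >= max_pos:
        raise ValueError(
            f"max position {biggest} >= max_position_embeddings {max_pos}; "
            "an index built from it can run past the rotary cos/sin cache."
        )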

P.S. I noticed that this function (https://github.com/vllm-project/vllm/blob/3bb4befea7166850bdee3f72fe060c9c4044ba85/vllm/worker/model_runner.py#L637) is called several times before inference on my video test data, and I'm wondering if this might be related.