vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Bug]: Runtime Error: GET was unable to find an engine to execute this computation for LLaVa-NEXT #5465

Closed by XkunW 2 months ago

XkunW commented 2 months ago

Your current environment

PyTorch version: 2.3.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.6 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: version 3.29.5
Libc version: glibc-2.27

Python version: 3.10.12 (main, Jul 19 2023, 10:44:52) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-4.15.0-213-generic-x86_64-with-glibc2.27
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:        x86_64
CPU op-mode(s):      32-bit, 64-bit
Byte Order:          Little Endian
CPU(s):              12
On-line CPU(s) list: 0-11
Thread(s) per core:  1
Core(s) per socket:  12
Socket(s):           1
NUMA node(s):        1
Vendor ID:           GenuineIntel
CPU family:          6
Model:               61
Model name:          Intel Core Processor (Broadwell)
Stepping:            2
CPU MHz:             2095.076
BogoMIPS:            4190.15
Hypervisor vendor:   KVM
Virtualization type: full
L1d cache:           32K
L1i cache:           32K
L2 cache:            4096K
L3 cache:            16384K
NUMA node0 CPU(s):   0-11
Flags:               fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx rdtscp lm constant_tsc rep_good nopl xtopology cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault invpcid_single pti fsgsbase bmi1 hle avx2 smep bmi2 erms invpcid rtm rdseed adx smap xsaveopt arat

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-nccl-cu12==2.20.5
[pip3] torch==2.3.0
[pip3] transformers==4.41.2
[pip3] triton==2.3.0
[pip3] vllm_nccl_cu12==2.18.1.0.4.0
[conda] blas                      1.0                         mkl  
[conda] mkl                       2019.4                      243  
[conda] mkl-service               2.0.2            py36h7b6447c_0  
[conda] mkl_fft                   1.0.12           py36ha843d7b_0  
[conda] mkl_random                1.0.2            py36hd81dba3_0  
[conda] numpy                     1.16.4           py36h7e9f1db_0  
[conda] numpy-base                1.16.4           py36hde5b4d6_0  
[conda] numpydoc                  0.9.1                      py_0  
[conda] torch-scatter             2.0.7                    pypi_0    pypi
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.5.0
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
Could not collect

🐛 Describe the bug

I'm trying to host llava-hf/llava-v1.6-mistral-7b-hf on a single A40 GPU using the OpenAI-compatible server with vLLM 0.5.0. Here is the command I used to launch the server:

python3 -m vllm.entrypoints.openai.api_server \
    --model /llava-v1.6-mistral-7b \
    --host "0.0.0.0" \
    --port 8080 \
    --tensor-parallel-size 1 \
    --dtype auto \
    --load-format safetensors \
    --image-input-type pixel_values \
    --image-token-id 32000 \
    --image-input-shape 1,3,336,336 \
    --image-feature-size 576 \
    --chat-template template_llava.jinja
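
(For reference, once the server is up, a request against it would look roughly like the sketch below. This assumes the OpenAI Python client and GPT-4V-style image_url content parts, whose support in vLLM's OpenAI-compatible server was still settling around this version; the image URL is a placeholder.)

from openai import OpenAI

# Sketch of a client request against the server launched above.
# base_url and port match the launch command; vLLM ignores the API key.
client = OpenAI(base_url="http://localhost:8080/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="/llava-v1.6-mistral-7b",  # must match the --model value
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "What is shown in this image?"},
            # Placeholder URL; replace with a real image.
            {"type": "image_url",
             "image_url": {"url": "https://example.com/sample.jpg"}},
        ],
    }],
)
print(response.choices[0].message.content)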

The server never came up, though; this is the error I got after the model weights finished loading:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/pkgs/python-3.10.12/lib/python3.10/runpy.py", line 196, in _run_module_as_main
[rank0]:     return _run_code(code, main_globals, None,
[rank0]:   File "/pkgs/python-3.10.12/lib/python3.10/runpy.py", line 86, in _run_code
[rank0]:     exec(code, run_globals)
[rank0]:   File "/fs01/projects/aieng/public/mixtral_vllm_env/lib/python3.10/site-packages/vllm/entrypoints/openai/api_server.py", line 196, in <module>
[rank0]:     engine = AsyncLLMEngine.from_engine_args(
[rank0]:   File "/fs01/projects/aieng/public/mixtral_vllm_env/lib/python3.10/site-packages/vllm/engine/async_llm_engine.py", line 395, in from_engine_args
[rank0]:     engine = cls(
[rank0]:   File "/fs01/projects/aieng/public/mixtral_vllm_env/lib/python3.10/site-packages/vllm/engine/async_llm_engine.py", line 349, in __init__
[rank0]:     self.engine = self._init_engine(*args, **kwargs)
[rank0]:   File "/fs01/projects/aieng/public/mixtral_vllm_env/lib/python3.10/site-packages/vllm/engine/async_llm_engine.py", line 470, in _init_engine
[rank0]:     return engine_class(*args, **kwargs)
[rank0]:   File "/fs01/projects/aieng/public/mixtral_vllm_env/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 236, in __init__
[rank0]:     self._initialize_kv_caches()
[rank0]:   File "/fs01/projects/aieng/public/mixtral_vllm_env/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 313, in _initialize_kv_caches
[rank0]:     self.model_executor.determine_num_available_blocks())
[rank0]:   File "/fs01/projects/aieng/public/mixtral_vllm_env/lib/python3.10/site-packages/vllm/executor/gpu_executor.py", line 75, in determine_num_available_blocks
[rank0]:     return self.driver_worker.determine_num_available_blocks()
[rank0]:   File "/fs01/projects/aieng/public/mixtral_vllm_env/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/fs01/projects/aieng/public/mixtral_vllm_env/lib/python3.10/site-packages/vllm/worker/worker.py", line 154, in determine_num_available_blocks
[rank0]:     self.model_runner.profile_run()
[rank0]:   File "/fs01/projects/aieng/public/mixtral_vllm_env/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/fs01/projects/aieng/public/mixtral_vllm_env/lib/python3.10/site-packages/vllm/worker/model_runner.py", line 833, in profile_run
[rank0]:     self.execute_model(seqs, kv_caches)
[rank0]:   File "/fs01/projects/aieng/public/mixtral_vllm_env/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/fs01/projects/aieng/public/mixtral_vllm_env/lib/python3.10/site-packages/vllm/worker/model_runner.py", line 738, in execute_model
[rank0]:     hidden_states = model_executable(
[rank0]:   File "/fs01/projects/aieng/public/mixtral_vllm_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/fs01/projects/aieng/public/mixtral_vllm_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/fs01/projects/aieng/public/mixtral_vllm_env/lib/python3.10/site-packages/vllm/model_executor/models/llava_next.py", line 398, in forward
[rank0]:     vision_embeddings = self._process_image_input(image_input)
[rank0]:   File "/fs01/projects/aieng/public/mixtral_vllm_env/lib/python3.10/site-packages/vllm/model_executor/models/llava_next.py", line 327, in _process_image_input
[rank0]:     image_features = self._process_image_pixels(image_input)
[rank0]:   File "/fs01/projects/aieng/public/mixtral_vllm_env/lib/python3.10/site-packages/vllm/model_executor/models/llava_next.py", line 317, in _process_image_pixels
[rank0]:     stacked_image_features = self._image_pixels_to_features(
[rank0]:   File "/fs01/projects/aieng/public/mixtral_vllm_env/lib/python3.10/site-packages/vllm/model_executor/models/llava_next.py", line 232, in _image_pixels_to_features
[rank0]:     image_outputs = vision_tower(pixel_values.to(vision_tower.device),
[rank0]:   File "/fs01/projects/aieng/public/mixtral_vllm_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/fs01/projects/aieng/public/mixtral_vllm_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/fs01/projects/aieng/public/mixtral_vllm_env/lib/python3.10/site-packages/transformers/models/clip/modeling_clip.py", line 926, in forward
[rank0]:     return self.vision_model(
[rank0]:   File "/fs01/projects/aieng/public/mixtral_vllm_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/fs01/projects/aieng/public/mixtral_vllm_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/fs01/projects/aieng/public/mixtral_vllm_env/lib/python3.10/site-packages/transformers/models/clip/modeling_clip.py", line 850, in forward
[rank0]:     hidden_states = self.embeddings(pixel_values)
[rank0]:   File "/fs01/projects/aieng/public/mixtral_vllm_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/fs01/projects/aieng/public/mixtral_vllm_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/fs01/projects/aieng/public/mixtral_vllm_env/lib/python3.10/site-packages/transformers/models/clip/modeling_clip.py", line 185, in forward
[rank0]:     patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
[rank0]:   File "/fs01/projects/aieng/public/mixtral_vllm_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/fs01/projects/aieng/public/mixtral_vllm_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/fs01/projects/aieng/public/mixtral_vllm_env/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 460, in forward
[rank0]:     return self._conv_forward(input, self.weight, self.bias)
[rank0]:   File "/fs01/projects/aieng/public/mixtral_vllm_env/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 456, in _conv_forward
[rank0]:     return F.conv2d(input, weight, bias, self.stride,
[rank0]: RuntimeError: GET was unable to find an engine to execute this computation
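
For what it's worth, the failing call is the F.conv2d in the CLIP vision tower's patch embedding, and this cuDNN "unable to find an engine" class of error is often dtype-related. A minimal sketch (assuming transformers is installed and the local model path from the launch command) to see which dtype --dtype auto resolves to:

from transformers import AutoConfig

# --dtype auto follows the checkpoint's declared torch_dtype, so the
# vision tower's conv runs in that precision during the profiling pass.
config = AutoConfig.from_pretrained("/llava-v1.6-mistral-7b-hf")
print(config.torch_dtype)  # e.g. torch.float16 or torch.bfloat16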
DarkLight1337 commented 2 months ago

Are you using a local model?

python3 -m vllm.entrypoints.openai.api_server \
    --model /llava-v1.6-mistral-7b \

I suspect that you have a typo in the model name.

XkunW commented 2 months ago

Are you using a local model?

python3 -m vllm.entrypoints.openai.api_server \
    --model /llava-v1.6-mistral-7b \

I suspect that you have a typo in the model name.

Sorry, I made a typo in my description; it should be /llava-v1.6-mistral-7b-hf, which is what I have in my actual launch command. And yes, it's a local model.

DarkLight1337 commented 2 months ago

Are you running vLLM on CPU (you may have to build from source)? I haven't tested LLaVA on CPU yet, so there may be some incompatibilities. Try setting --dtype float?

XkunW commented 2 months ago

Are you running vLLM on CPU (you may have to build from source)? I haven't tested LLaVA on CPU yet, so there may be some incompatibilities. Try setting --dtype float?

I'm using a single A40 GPU, and setting --dtype float did get me past this error. Now I just need to figure out the right image feature size. Thanks!
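
For reference, a sketch of the launch command with the corrected model path and --dtype float applied (otherwise identical to the original; the --image-feature-size value may still need tuning for LLaVA-NeXT's variable-resolution tiling):

python3 -m vllm.entrypoints.openai.api_server \
    --model /llava-v1.6-mistral-7b-hf \
    --host "0.0.0.0" \
    --port 8080 \
    --tensor-parallel-size 1 \
    --dtype float \
    --load-format safetensors \
    --image-input-type pixel_values \
    --image-token-id 32000 \
    --image-input-shape 1,3,336,336 \
    --image-feature-size 576 \
    --chat-template template_llava.jinja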