vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Bug]: RuntimeError: shape mismatch: value tensor of shape [3328, 7168] cannot be broadcast to indexing result of shape [3328] for OpenGVLab/InternVL2-40B #8275

Closed: Manikandan-Thangaraj-ZS0321 closed this issue 2 months ago

Manikandan-Thangaraj-ZS0321 commented 2 months ago

Your current environment

The output of `python collect_env.py`

```text
Collecting environment information...
PyTorch version: 2.4.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.31

Python version: 3.12.5 (main, Aug 17 2024, 16:46:07) [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-113-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3090
Nvidia driver version: 535.183.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 46 bits physical, 48 bits virtual
CPU(s): 12
On-line CPU(s) list: 0-11
Thread(s) per core: 2
Core(s) per socket: 6
Socket(s): 1
NUMA node(s): 1
Vendor ID: GenuineIntel
CPU family: 6
Model: 85
Model name: Intel(R) Xeon(R) W-2133 CPU @ 3.60GHz
Stepping: 4
CPU MHz: 3600.000
CPU max MHz: 3900.0000
CPU min MHz: 1200.0000
BogoMIPS: 7200.00
Virtualization: VT-x
L1d cache: 192 KiB
L1i cache: 192 KiB
L2 cache: 6 MiB
L3 cache: 8.3 MiB
NUMA node0 CPU(s): 0-11
Vulnerability Gather data sampling: Mitigation; Microcode
Vulnerability Itlb multihit: KVM: Mitigation: VMX disabled
Vulnerability L1tf: Mitigation; PTE Inversion; VMX conditional cache flushes, SMT vulnerable
Vulnerability Mds: Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Meltdown: Mitigation; PTI
Vulnerability Mmio stale data: Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Retbleed: Mitigation; IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; IBRS; IBPB conditional; STIBP conditional; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Mitigation; Clear CPU buffers; SMT vulnerable
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cdp_l3 invpcid_single pti intel_ppin ssbd mba ibrs ibpb stibp tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm cqm mpx rdt_a avx512f avx512dq rdseed adx smap clflushopt clwb intel_pt avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts hwp hwp_act_window hwp_epp hwp_pkg_req md_clear flush_l1d arch_capabilities

Versions of relevant libraries:
[pip3] flashinfer==0.1.6+cu121torch2.4
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.1.3.1
[pip3] nvidia-cuda-cupti-cu12==12.1.105
[pip3] nvidia-cuda-nvrtc-cu12==12.1.105
[pip3] nvidia-cuda-runtime-cu12==12.1.105
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.0.2.54
[pip3] nvidia-curand-cu12==10.3.2.106
[pip3] nvidia-cusolver-cu12==11.4.5.107
[pip3] nvidia-cusparse-cu12==12.1.0.106
[pip3] nvidia-ml-py==12.560.30
[pip3] nvidia-nccl-cu12==2.20.5
[pip3] nvidia-nvjitlink-cu12==12.6.68
[pip3] nvidia-nvtx-cu12==12.1.105
[pip3] pyzmq==26.2.0
[pip3] torch==2.4.0
[pip3] torchvision==0.19.0
[pip3] transformers==4.44.2
[pip3] triton==3.0.0
[conda] Could not collect
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.6.0@
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0    CPU Affinity    NUMA Affinity    GPU NUMA ID
GPU0     X      0-11            0                N/A

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks
```

🐛 Describe the bug

While serving OpenGVLab/InternVL2-40B with multi-node, multi-GPU inference (tensor parallel size 1, pipeline parallel size 4), engine startup fails with `RuntimeError: shape mismatch: value tensor of shape [3328, 7168] cannot be broadcast to indexing result of shape [3328]`.

I do not hit this error when serving OpenGVLab/InternVL2-8B or OpenGVLab/InternVL2-26B.

Command:

```bash
vllm serve OpenGVLab/InternVL2-40B --tensor-parallel-size 1 --pipeline-parallel-size 4 --dtype bfloat16 --gpu-memory-utilization 0.9 --max-model-len 6000 --enforce-eager --trust-remote-code --tokenizer-mode "auto"
```

Log:

```text
INFO 09-08 10:17:29 api_server.py:495] vLLM API server version 0.6.0
INFO 09-08 10:17:29 api_server.py:496] args: Namespace(model_tag='OpenGVLab/InternVL2-40B', config='', host=None, port=8000, uvicorn_log_level='info', allow_credentials=False, allowed_origins=['*'], allowed_methods=['*'], allowed_headers=['*'], api_key=None, lora_modules=None, prompt_adapters=None, chat_template=None, response_role='assistant', ssl_keyfile=None, ssl_certfile=None, ssl_ca_certs=None, ssl_cert_reqs=0, root_path=None, middleware=[], return_tokens_as_token_ids=False, disable_frontend_multiprocessing=False, enable_auto_tool_choice=False, tool_call_parser=None, model='OpenGVLab/InternVL2-40B', tokenizer=None, skip_tokenizer_init=False, revision=None, code_revision=None, tokenizer_revision=None, tokenizer_mode='auto', trust_remote_code=True, download_dir=None, load_format='auto', config_format='auto', dtype='bfloat16', kv_cache_dtype='auto', quantization_param_path=None, max_model_len=6000, guided_decoding_backend='outlines', distributed_executor_backend=None, worker_use_ray=False, pipeline_parallel_size=4, tensor_parallel_size=1, max_parallel_loading_workers=None, ray_workers_use_nsight=False, block_size=16, enable_prefix_caching=False, disable_sliding_window=False, use_v2_block_manager=False, num_lookahead_slots=0, seed=0, swap_space=4, cpu_offload_gb=10.0, gpu_memory_utilization=0.9, num_gpu_blocks_override=None, max_num_batched_tokens=None, max_num_seqs=256, max_logprobs=20, disable_log_stats=False, quantization=None, rope_scaling=None, rope_theta=None, enforce_eager=True, max_context_len_to_capture=None, max_seq_len_to_capture=8192, disable_custom_all_reduce=False, tokenizer_pool_size=0, tokenizer_pool_type='ray', tokenizer_pool_extra_config=None, limit_mm_per_prompt=None, enable_lora=False, max_loras=1, max_lora_rank=16, lora_extra_vocab_size=256, lora_dtype='auto', long_lora_scaling_factors=None, max_cpu_loras=None, fully_sharded_loras=False, enable_prompt_adapter=False, max_prompt_adapters=1, max_prompt_adapter_token=0, device='auto', num_scheduler_steps=1, scheduler_delay_factor=0.0, enable_chunked_prefill=None, speculative_model=None, speculative_model_quantization=None, num_speculative_tokens=None, speculative_draft_tensor_parallel_size=None, speculative_max_model_len=None, speculative_disable_by_batch_size=None, ngram_prompt_lookup_max=None, ngram_prompt_lookup_min=None, spec_decoding_acceptance_method='rejection_sampler', typical_acceptance_sampler_posterior_threshold=None, typical_acceptance_sampler_posterior_alpha=None, disable_logprobs_during_spec_decoding=None, model_loader_extra_config=None, ignore_patterns=[], preemption_mode=None, served_model_name=None, qlora_adapter_name_or_path=None, otlp_traces_endpoint=None, collect_detailed_traces=None, disable_async_output_proc=False, override_neuron_config=None, engine_use_ray=False, disable_log_requests=False, max_log_len=None, dispatch_function=<function serve at 0x7f9a44236200>)
INFO 09-08 10:17:31 api_server.py:162] Multiprocessing frontend to use ipc:///tmp/80d2a87d-aa22-4d9d-9e31-26a457cc7256 for RPC Path.
INFO 09-08 10:17:31 api_server.py:178] Started engine process with PID 5081
INFO 09-08 10:17:37 config.py:896] Defaulting to use ray for distributed inference
WARNING 09-08 10:17:37 config.py:364] Async output processing can not be enabled with pipeline parallel
2024-09-08 10:17:37,156 INFO worker.py:1598 -- Connecting to existing Ray cluster at address: 172.18.10.139:6380...
2024-09-08 10:17:37,193 INFO worker.py:1783 -- Connected to Ray cluster.
INFO 09-08 10:17:37 llm_engine.py:213] Initializing an LLM engine (v0.6.0) with config: model='OpenGVLab/InternVL2-40B', speculative_config=None, tokenizer='OpenGVLab/InternVL2-40B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=6000, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=4, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=OpenGVLab/InternVL2-40B, use_v2_block_manager=False, num_scheduler_steps=1, enable_prefix_caching=False, use_async_output_proc=False)
INFO 09-08 10:17:39 ray_gpu_executor.py:134] use_ray_spmd_worker: False
INFO 09-08 10:17:58 utils.py:977] Found nccl from library libnccl.so.2
INFO 09-08 10:17:58 pynccl.py:63] vLLM is using nccl==2.20.5
(RayWorkerWrapper pid=2043, ip=172.18.10.140) INFO 09-08 10:17:58 utils.py:977] Found nccl from library libnccl.so.2
(RayWorkerWrapper pid=2043, ip=172.18.10.140) INFO 09-08 10:17:58 pynccl.py:63] vLLM is using nccl==2.20.5
INFO 09-08 10:17:59 model_runner.py:915] Starting to load model OpenGVLab/InternVL2-40B...
/usr/local/lib/python3.12/dist-packages/xformers/ops/fmha/flash.py:211: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
  @torch.library.impl_abstract("xformers_flash::flash_fwd")
(RayWorkerWrapper pid=17129, ip=172.18.10.141) INFO 09-08 10:17:59 model_runner.py:915] Starting to load model OpenGVLab/InternVL2-40B...
/usr/local/lib/python3.12/dist-packages/xformers/ops/fmha/flash.py:344: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
  @torch.library.impl_abstract("xformers_flash::flash_bwd")
(RayWorkerWrapper pid=17129, ip=172.18.10.141) /usr/local/lib/python3.12/dist-packages/xformers/ops/fmha/flash.py:211: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
(RayWorkerWrapper pid=17129, ip=172.18.10.141)   @torch.library.impl_abstract("xformers_flash::flash_fwd")
(RayWorkerWrapper pid=17203, ip=172.18.10.141)   @torch.library.impl_abstract("xformers_flash::flash_bwd")
INFO 09-08 10:18:07 weight_utils.py:235] Using model weights format ['*.safetensors']
(RayWorkerWrapper pid=2043, ip=172.18.10.140) INFO 09-08 10:18:08 weight_utils.py:235] Using model weights format ['*.safetensors']
(RayWorkerWrapper pid=17203, ip=172.18.10.141) INFO 09-08 10:17:58 utils.py:977] Found nccl from library libnccl.so.2 [repeated 2x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)
(RayWorkerWrapper pid=17203, ip=172.18.10.141) INFO 09-08 10:17:58 pynccl.py:63] vLLM is using nccl==2.20.5 [repeated 2x across cluster]
(RayWorkerWrapper pid=2043, ip=172.18.10.140) INFO 09-08 10:17:59 model_runner.py:915] Starting to load model OpenGVLab/InternVL2-40B... [repeated 2x across cluster]
Loading safetensors checkpoint shards:   0% Completed | 0/17 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  24% Completed | 4/17 [00:00<00:02,  4.34it/s]
Loading safetensors checkpoint shards:  47% Completed | 8/17 [00:01<00:02,  4.30it/s]
Loading safetensors checkpoint shards:  82% Completed | 14/17 [00:02<00:00,  8.11it/s]
Loading safetensors checkpoint shards: 100% Completed | 17/17 [00:02<00:00,  8.16it/s]

INFO 09-08 10:18:15 model_runner.py:926] Loading model weights took 16.9311 GB
(RayWorkerWrapper pid=2043, ip=172.18.10.140) INFO 09-08 10:18:17 model_runner.py:926] Loading model weights took 16.0756 GB
(RayWorkerWrapper pid=17203, ip=172.18.10.141) INFO 09-08 10:18:11 weight_utils.py:235] Using model weights format ['*.safetensors'] [repeated 2x across cluster]
(RayWorkerWrapper pid=17203, ip=172.18.10.141) INFO 09-08 10:18:24 model_runner.py:926] Loading model weights took 16.9311 GB
(RayWorkerWrapper pid=17129, ip=172.18.10.141) INFO 09-08 10:18:24 model_runner.py:926] Loading model weights took 16.0756 GB
```
```text
(RayWorkerWrapper pid=2043, ip=172.18.10.140) ERROR 09-08 10:18:28 worker_base.py:464] Error executing method determine_num_available_blocks. This might cause deadlock in distributed execution.
(RayWorkerWrapper pid=2043, ip=172.18.10.140) ERROR 09-08 10:18:28 worker_base.py:464] Traceback (most recent call last):
(RayWorkerWrapper pid=2043, ip=172.18.10.140) ERROR 09-08 10:18:28 worker_base.py:464]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/worker_base.py", line 456, in execute_method
(RayWorkerWrapper pid=2043, ip=172.18.10.140) ERROR 09-08 10:18:28 worker_base.py:464]     return executor(*args, **kwargs)
(RayWorkerWrapper pid=2043, ip=172.18.10.140) ERROR 09-08 10:18:28 worker_base.py:464]            ^^^^^^^^^^^^^^^^^^^^^^^^^
(RayWorkerWrapper pid=2043, ip=172.18.10.140) ERROR 09-08 10:18:28 worker_base.py:464]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(RayWorkerWrapper pid=2043, ip=172.18.10.140) ERROR 09-08 10:18:28 worker_base.py:464]     return func(*args, **kwargs)
(RayWorkerWrapper pid=2043, ip=172.18.10.140) ERROR 09-08 10:18:28 worker_base.py:464]            ^^^^^^^^^^^^^^^^^^^^^
(RayWorkerWrapper pid=2043, ip=172.18.10.140) ERROR 09-08 10:18:28 worker_base.py:464]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/worker.py", line 222, in determine_num_available_blocks
(RayWorkerWrapper pid=2043, ip=172.18.10.140) ERROR 09-08 10:18:28 worker_base.py:464]     self.model_runner.profile_run()
(RayWorkerWrapper pid=2043, ip=172.18.10.140) ERROR 09-08 10:18:28 worker_base.py:464]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(RayWorkerWrapper pid=2043, ip=172.18.10.140) ERROR 09-08 10:18:28 worker_base.py:464]     return func(*args, **kwargs)
(RayWorkerWrapper pid=2043, ip=172.18.10.140) ERROR 09-08 10:18:28 worker_base.py:464]            ^^^^^^^^^^^^^^^^^^^^^
(RayWorkerWrapper pid=2043, ip=172.18.10.140) ERROR 09-08 10:18:28 worker_base.py:464]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1133, in profile_run
(RayWorkerWrapper pid=2043, ip=172.18.10.140) ERROR 09-08 10:18:28 worker_base.py:464]     self.execute_model(model_input, kv_caches, intermediate_tensors)
(RayWorkerWrapper pid=2043, ip=172.18.10.140) ERROR 09-08 10:18:28 worker_base.py:464]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(RayWorkerWrapper pid=2043, ip=172.18.10.140) ERROR 09-08 10:18:28 worker_base.py:464]     return func(*args, **kwargs)
(RayWorkerWrapper pid=2043, ip=172.18.10.140) ERROR 09-08 10:18:28 worker_base.py:464]            ^^^^^^^^^^^^^^^^^^^^^
(RayWorkerWrapper pid=2043, ip=172.18.10.140) ERROR 09-08 10:18:28 worker_base.py:464]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1450, in execute_model
(RayWorkerWrapper pid=2043, ip=172.18.10.140) ERROR 09-08 10:18:28 worker_base.py:464]     hidden_or_intermediate_states = model_executable(
(RayWorkerWrapper pid=2043, ip=172.18.10.140) ERROR 09-08 10:18:28 worker_base.py:464]                                     ^^^^^^^^^^^^^^^^^
(RayWorkerWrapper pid=2043, ip=172.18.10.140) ERROR 09-08 10:18:28 worker_base.py:464]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
(RayWorkerWrapper pid=2043, ip=172.18.10.140) ERROR 09-08 10:18:28 worker_base.py:464]     return self._call_impl(*args, **kwargs)
(RayWorkerWrapper pid=2043, ip=172.18.10.140) ERROR 09-08 10:18:28 worker_base.py:464]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(RayWorkerWrapper pid=2043, ip=172.18.10.140) ERROR 09-08 10:18:28 worker_base.py:464]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
(RayWorkerWrapper pid=2043, ip=172.18.10.140) ERROR 09-08 10:18:28 worker_base.py:464]     return forward_call(*args, **kwargs)
(RayWorkerWrapper pid=2043, ip=172.18.10.140) ERROR 09-08 10:18:28 worker_base.py:464]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(RayWorkerWrapper pid=2043, ip=172.18.10.140) ERROR 09-08 10:18:28 worker_base.py:464]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/internvl.py", line 487, in forward
(RayWorkerWrapper pid=2043, ip=172.18.10.140) ERROR 09-08 10:18:28 worker_base.py:464]     inputs_embeds = merge_multimodal_embeddings(
(RayWorkerWrapper pid=2043, ip=172.18.10.140) ERROR 09-08 10:18:28 worker_base.py:464]                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(RayWorkerWrapper pid=2043, ip=172.18.10.140) ERROR 09-08 10:18:28 worker_base.py:464]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/utils.py", line 146, in merge_multimodal_embeddings
(RayWorkerWrapper pid=2043, ip=172.18.10.140) ERROR 09-08 10:18:28 worker_base.py:464]     inputs_embeds[mask] = flattened
(RayWorkerWrapper pid=2043, ip=172.18.10.140) ERROR 09-08 10:18:28 worker_base.py:464]     ~~~~~~~~~~~~~^^^^^^
(RayWorkerWrapper pid=2043, ip=172.18.10.140) ERROR 09-08 10:18:28 worker_base.py:464] RuntimeError: shape mismatch: value tensor of shape [3328, 7168] cannot be broadcast to indexing result of shape [3328]
```
```text
Process SpawnProcess-1:
Traceback (most recent call last):
  File "/usr/lib/python3.12/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/usr/lib/python3.12/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/openai/rpc/server.py", line 236, in run_rpc_server
    server = AsyncEngineRPCServer(async_engine_args, usage_context, rpc_path)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/openai/rpc/server.py", line 34, in __init__
    self.engine = AsyncLLMEngine.from_engine_args(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/engine/async_llm_engine.py", line 735, in from_engine_args
    engine = cls(
             ^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/engine/async_llm_engine.py", line 615, in __init__
    self.engine = self._init_engine(*args, **kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/engine/async_llm_engine.py", line 835, in _init_engine
    return engine_class(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/engine/async_llm_engine.py", line 262, in __init__
    super().__init__(*args, **kwargs)
  File "/usr/local/lib/python3.12/dist-packages/vllm/engine/llm_engine.py", line 319, in __init__
    self._initialize_kv_caches()
  File "/usr/local/lib/python3.12/dist-packages/vllm/engine/llm_engine.py", line 448, in _initialize_kv_caches
    self.model_executor.determine_num_available_blocks())
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/executor/distributed_gpu_executor.py", line 39, in determine_num_available_blocks
    num_blocks = self._run_workers("determine_num_available_blocks", )
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/executor/ray_gpu_executor.py", line 421, in _run_workers
    ray_worker_outputs = ray.get(ray_worker_outputs)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/ray/_private/worker.py", line 2661, in get
    values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/ray/_private/worker.py", line 871, in get_objects
    raise value.as_instanceof_cause()
```

Debug info I collected:

```text
Debug: input_ids shape: torch.Size([6000])
Debug: inputs_embeds shape: torch.Size([6000, 7168])
Debug: mask shape: torch.Size([6000]), num True values: 3328
Debug: flattened shape: torch.Size([3328, 7168])
```
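The debug shapes and the error message can be reconciled with a toy reproduction. This is a minimal sketch in plain PyTorch with shrunken sizes (6 tokens, hidden size 4, 3 image tokens standing in for 6000, 7168 and 3328); the image-context token id `9` and the helper `try_merge` are hypothetical. It runs the same masked assignment as the failing line, `inputs_embeds[mask] = flattened`. Boolean-mask indexing keeps any trailing dimensions, so a 2-D `inputs_embeds` of shape `[seq_len, hidden]` (as in the debug output) yields `[num_image_tokens, hidden]` and the assignment succeeds, whereas a 1-D `inputs_embeds` of shape `[seq_len]` yields `[num_image_tokens]` and fails the same way the report does.

```python
import torch


def try_merge(inputs_embeds: torch.Tensor, mask: torch.Tensor, flattened: torch.Tensor):
    """Do the same masked assignment as the failing line; return the error text or None."""
    try:
        inputs_embeds[mask] = flattened
        return None
    except RuntimeError as exc:
        return str(exc)


IMG_TOKEN_ID = 9                                # hypothetical image-context token id
input_ids = torch.tensor([1, 9, 9, 9, 2, 3])    # stands in for the 6000 input ids
mask = input_ids == IMG_TOKEN_ID                # shape [6], 3 True values (3328 in the report)
flattened = torch.randn(3, 4)                   # vision embeddings, [num_image_tokens, hidden]

# Shapes as in the debug output: inputs_embeds is [seq_len, hidden], so
# inputs_embeds[mask] is [3, 4] and the assignment succeeds.
ok = try_merge(torch.zeros(6, 4), mask, flattened)

# Shapes implied by the error: inputs_embeds is 1-D [seq_len], so
# inputs_embeds[mask] is [3] and a [3, 4] value cannot be broadcast into it.
bad = try_merge(torch.zeros(6), mask, flattened)

print(ok)    # None
print(bad)   # a "shape mismatch: value tensor of shape ..." message
```

Assuming the mask is also 1-D on the rank that raised, an indexing result of shape `[3328]` means the tensor being indexed there was 1-D, so the debug prints above (which show a 2-D `inputs_embeds`) may have been captured on a different rank than pid=2043 on 172.18.10.140.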


youkaichao commented 2 months ago

cc @ywang96 @DarkLight1337 @Isotr0py

looks like an issue from the visual modality part
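One way to test the visual-modality hypothesis is to validate shapes just before `inputs_embeds[mask] = flattened` runs and report which side looks wrong. The helper below, `check_merge_shapes`, is hypothetical and not part of vLLM: a mismatch between the number of image-context tokens and the number of vision embedding rows points at the vision side, while an `inputs_embeds` that is not `[seq_len, hidden]` points at the text-embedding side. With the numbers from the debug output (3328 masked tokens, 3328 vision rows, hidden size 7168 on both), the vision-side checks would pass, which fits an error about tensor rank rather than token count.

```python
import torch


def check_merge_shapes(inputs_embeds: torch.Tensor,
                       mask: torch.Tensor,
                       flattened: torch.Tensor) -> None:
    """Hypothetical pre-merge check: raise ValueError naming the side that looks wrong."""
    if mask.dtype != torch.bool or mask.dim() != 1:
        raise ValueError(f"mask should be a 1-D bool tensor, got {tuple(mask.shape)} {mask.dtype}")
    # Text side: token embeddings should be [seq_len, hidden].
    if inputs_embeds.dim() != 2 or inputs_embeds.shape[0] != mask.shape[0]:
        raise ValueError(
            f"inputs_embeds should be [{mask.shape[0]}, hidden], got {tuple(inputs_embeds.shape)}")
    # Vision side: one embedding row per image-context token.
    n_image_tokens = int(mask.sum())
    if flattened.dim() != 2 or flattened.shape[0] != n_image_tokens:
        raise ValueError(
            f"{n_image_tokens} image-context tokens but vision embeddings of shape "
            f"{tuple(flattened.shape)}")
    # Both sides must agree on the hidden size.
    if flattened.shape[1] != inputs_embeds.shape[1]:
        raise ValueError(
            f"hidden size mismatch: text {inputs_embeds.shape[1]} vs vision {flattened.shape[1]}")


# Usage with shrunken versions of the reported shapes (6 tokens, hidden size 4, 3 image tokens).
mask = torch.tensor([False, True, True, True, False, False])
check_merge_shapes(torch.zeros(6, 4), mask, torch.zeros(3, 4))    # passes silently
try:
    check_merge_shapes(torch.zeros(6), mask, torch.zeros(3, 4))   # 1-D inputs_embeds
except ValueError as exc:
    print(exc)
```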

Isotr0py commented 2 months ago

I think this is an issue related to pipeline parallelism, because I can reproduce this on InternVL2-4B with pp_size=2, while pp_size=1 and tp_size=2 have no issue.
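The pipeline-parallel explanation can be illustrated with a toy model of the control flow. This sketch assumes (not confirmed in this thread) that only the first pipeline stage owns the real token-embedding table and that later stages keep an identity-style placeholder in its place, so "embedding" token ids on a later stage hands back the 1-D ids unchanged. Under that assumption, the multimodal merge should run only on the first stage, and later stages should consume the hidden states passed from the previous stage. `ToyStage`, `IMG_TOKEN_ID`, `VOCAB` and `HIDDEN` are hypothetical; this is not vLLM's code or its eventual fix.

```python
import torch
from torch import nn

VOCAB, HIDDEN = 16, 4
IMG_TOKEN_ID = 9  # hypothetical image-context token id


class ToyStage(nn.Module):
    """Hypothetical pipeline stage; only the first stage owns the embedding table."""

    def __init__(self, is_first_stage: bool):
        super().__init__()
        self.is_first_stage = is_first_stage
        # Assumption: later stages hold an identity placeholder instead of the embedding.
        self.embed = nn.Embedding(VOCAB, HIDDEN) if is_first_stage else nn.Identity()

    @torch.no_grad()
    def forward(self, input_ids, vision_embeds, intermediate_hidden=None):
        if not self.is_first_stage:
            # Hidden states arrive from the previous stage, so the embedding
            # lookup and the multimodal merge are skipped here.
            return intermediate_hidden
        inputs_embeds = self.embed(input_ids)        # [seq_len, HIDDEN]
        mask = input_ids == IMG_TOKEN_ID
        inputs_embeds[mask] = vision_embeds          # merge on the first stage only
        return inputs_embeds


input_ids = torch.tensor([1, 9, 9, 9, 2, 3])
vision_embeds = torch.randn(3, HIDDEN)

first, second = ToyStage(True), ToyStage(False)
hidden = first(input_ids, vision_embeds)                            # shape [6, 4]
out = second(input_ids, vision_embeds, intermediate_hidden=hidden)  # passed through

# What the placeholder returns if a later stage tried to embed and merge itself:
print(second.embed(input_ids).shape)                 # torch.Size([6]), i.e. 1-D
```

Running the merge on `second.embed(input_ids)` instead would index a 1-D tensor with the mask, which is the situation the shape-mismatch error describes.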