vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Bug]: multi-GPU inference (tensor_parallel_size=2) fails on Intel GPUs #6701

Open: raffenet opened this issue 4 months ago

raffenet commented 4 months ago

Your current environment

Collecting environment information...
WARNING 07-23 19:11:42 _custom_ops.py:14] Failed to import from vllm._C with ModuleNotFoundError("No module named 'vllm._C'")
PyTorch version: 2.1.0.post1+cxx11.abi
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: openSUSE Leap 15.4 (x86_64)
GCC version: (Spack GCC) 11.4.0
Clang version: Could not collect
CMake version: version 3.30.1
Libc version: glibc-2.31

Python version: 3.10.14 (main, May  6 2024, 19:42:50) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.14.21-150400.24.100-default-x86_64-with-glibc2.31
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      52 bits physical, 57 bits virtual
Byte Order:                         Little Endian
CPU(s):                             208
On-line CPU(s) list:                0-207
Vendor ID:                          GenuineIntel
Model name:                         Intel (R) Xeon (R) CPU Max 9470C
CPU family:                         6
Model:                              143
Thread(s) per core:                 2
Core(s) per socket:                 52
Socket(s):                          2
Stepping:                           8
Frequency boost:                    enabled
CPU max MHz:                        2001.0000
CPU min MHz:                        800.0000
BogoMIPS:                           4000.00
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cat_l2 cdp_l3 invpcid_single cdp_l2 ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect avx_vnni avx512_bf16 wbnoinvd dtherm ida arat pln pts hwp hwp_act_window hwp_epp hwp_pkg_req avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b enqcmd fsrm md_clear serialize tsxldtrk pconfig arch_lbr avx512_fp16 amx_tile flush_l1d arch_capabilities
Virtualization:                     VT-x
L1d cache:                          4.9 MiB (104 instances)
L1i cache:                          3.3 MiB (104 instances)
L2 cache:                           208 MiB (104 instances)
L3 cache:                           210 MiB (2 instances)
NUMA node(s):                       4
NUMA node0 CPU(s):                  0-51,104-155
NUMA node1 CPU(s):                  52-103,156-207
NUMA node2 CPU(s):
NUMA node3 CPU(s):
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Enhanced / Automatic IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected

Versions of relevant libraries:
[pip3] intel-extension-for-pytorch==2.1.30a0
[pip3] numpy==1.26.4
[pip3] torch==2.1.0.post1+cxx11.abi
[pip3] transformers==4.43.1
[pip3] triton==2.1.0
[conda] intel-extension-for-pytorch 2.1.30a0                 pypi_0    pypi
[conda] numpy                     1.26.4                   pypi_0    pypi
[conda] torch                     2.1.0.post1+cxx11.abi          pypi_0    pypi
[conda] transformers              4.43.1                   pypi_0    pypi
[conda] triton                    2.1.0                    pypi_0    pypi
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.5.3.post1
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
Could not collect

Intel GPU info from sycl-ls

[ext_oneapi_level_zero:gpu:0] Intel(R) Level-Zero, Intel(R) Data Center GPU Max 1550 1.3 [1.3.28202]
[ext_oneapi_level_zero:gpu:1] Intel(R) Level-Zero, Intel(R) Data Center GPU Max 1550 1.3 [1.3.28202]
[ext_oneapi_level_zero:gpu:2] Intel(R) Level-Zero, Intel(R) Data Center GPU Max 1550 1.3 [1.3.28202]
[ext_oneapi_level_zero:gpu:3] Intel(R) Level-Zero, Intel(R) Data Center GPU Max 1550 1.3 [1.3.28202]

🐛 Describe the bug

The offline_inference.py example crashes with tensor_parallel_size=2:

from vllm import LLM, SamplingParams

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Create an LLM.
llm = LLM(model="facebook/opt-125m", tensor_parallel_size=2)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

Example output:

WARNING 07-23 19:14:08 _custom_ops.py:14] Failed to import from vllm._C with ModuleNotFoundError("No module named 'vllm._C'")
INFO 07-23 19:14:11 config.py:715] Defaulting to use ray for distributed inference
2024-07-23 19:14:13,353 INFO worker.py:1788 -- Started a local Ray instance.
INFO 07-23 19:14:16 llm_engine.py:176] Initializing an LLM engine (v0.5.3.post1) with config: model='facebook/opt-125m', speculative_config=None, tokenizer='facebook/opt-125m', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=2048, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=2, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=xpu, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None), seed=0, served_model_name=facebook/opt-125m, use_v2_block_manager=False, enable_prefix_caching=False)
(pid=100697) WARNING 07-23 19:14:19 _custom_ops.py:14] Failed to import from vllm._C with ModuleNotFoundError("No module named 'vllm._C'")
Traceback (most recent call last):
  File "/home/raffenet/proj/vllm/examples/offline_inference.py", line 14, in <module>
    llm = LLM(model="facebook/opt-125m", tensor_parallel_size=2)
  File "/home/raffenet/.conda/envs/vllm/lib/python3.10/site-packages/vllm-0.5.3.post1+xpu-py3.10.egg/vllm/entrypoints/llm.py", line 155, in __init__
    self.llm_engine = LLMEngine.from_engine_args(
  File "/home/raffenet/.conda/envs/vllm/lib/python3.10/site-packages/vllm-0.5.3.post1+xpu-py3.10.egg/vllm/engine/llm_engine.py", line 441, in from_engine_args
    engine = cls(
  File "/home/raffenet/.conda/envs/vllm/lib/python3.10/site-packages/vllm-0.5.3.post1+xpu-py3.10.egg/vllm/engine/llm_engine.py", line 251, in __init__
    self.model_executor = executor_class(
  File "/home/raffenet/.conda/envs/vllm/lib/python3.10/site-packages/vllm-0.5.3.post1+xpu-py3.10.egg/vllm/executor/ray_xpu_executor.py", line 75, in __init__
    self._init_workers_ray(placement_group)
  File "/home/raffenet/.conda/envs/vllm/lib/python3.10/site-packages/vllm-0.5.3.post1+xpu-py3.10.egg/vllm/executor/ray_xpu_executor.py", line 168, in _init_workers_ray
    worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
  File "/home/raffenet/.conda/envs/vllm/lib/python3.10/site-packages/vllm-0.5.3.post1+xpu-py3.10.egg/vllm/executor/ray_xpu_executor.py", line 323, in _run_workers
    return driver_worker_output + ray_worker_outputs
TypeError: can only concatenate tuple (not "list") to tuple
(pid=100897) WARNING 07-23 19:14:23 _custom_ops.py:14] Failed to import from vllm._C with ModuleNotFoundError("No module named 'vllm._C'")

Manually printing the values being concatenated when the error occurs:

driver_worker_output: ('1e74c71e6f5268f65d01ec726894b8dc9910b8b191bad2cdbb2f6e15', [0])
ray_worker_outputs: [('1e74c71e6f5268f65d01ec726894b8dc9910b8b191bad2cdbb2f6e15', [1])]
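The printed values show that the driver's result is a single (node_id, gpu_ids) tuple, while the Ray workers' results arrive as a list, so `+` cannot join them. Below is a minimal sketch of the likely fix, which wraps the driver's result in a list before concatenating. The shortened node ID "node-a" is a placeholder for the hash above, and the function names are hypothetical, not vLLM's own:

```python
from typing import Any, List, Tuple

# Shapes taken from the printed values above; "node-a" is a placeholder node ID.
driver_worker_output: Tuple[str, List[int]] = ("node-a", [0])
ray_worker_outputs: List[Tuple[str, List[int]]] = [("node-a", [1])]


def combine_outputs_buggy(driver_out: Any, ray_outs: List[Any]) -> Any:
    # tuple + list raises:
    # TypeError: can only concatenate tuple (not "list") to tuple
    return driver_out + ray_outs


def combine_outputs(driver_out: Any, ray_outs: List[Any]) -> List[Any]:
    # Wrap the driver's single result so every worker contributes exactly one
    # list element, with the driver's result first.
    return [driver_out] + ray_outs


print(combine_outputs(driver_worker_output, ray_worker_outputs))
# [('node-a', [0]), ('node-a', [1])]
```

Putting the driver first keeps the list ordered by rank, which is what a caller that zips results with worker indices would expect.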
raffenet commented 4 months ago

Also, if I patch the return value into the form I think the caller expects, I hit this traceback later in execution:

Traceback (most recent call last):
  File "/home/raffenet/proj/vllm/examples/offline_inference.py", line 17, in <module>
    outputs = llm.generate(prompts, sampling_params)
  File "/home/raffenet/.conda/envs/vllm/lib/python3.10/site-packages/vllm-0.5.3+xpu-py3.10.egg/vllm/utils.py", line 838, in inner
    return fn(*args, **kwargs)
  File "/home/raffenet/.conda/envs/vllm/lib/python3.10/site-packages/vllm-0.5.3+xpu-py3.10.egg/vllm/entrypoints/llm.py", line 316, in generate
    outputs = self._run_engine(use_tqdm=use_tqdm)
  File "/home/raffenet/.conda/envs/vllm/lib/python3.10/site-packages/vllm-0.5.3+xpu-py3.10.egg/vllm/entrypoints/llm.py", line 569, in _run_engine
    step_outputs = self.llm_engine.step()
  File "/home/raffenet/.conda/envs/vllm/lib/python3.10/site-packages/vllm-0.5.3+xpu-py3.10.egg/vllm/engine/llm_engine.py", line 911, in step
    output = self.model_executor.execute_model(
  File "/home/raffenet/.conda/envs/vllm/lib/python3.10/site-packages/vllm-0.5.3+xpu-py3.10.egg/vllm/executor/distributed_gpu_executor.py", line 70, in execute_model
    self.parallel_worker_tasks = self._run_workers(
  File "/home/raffenet/.conda/envs/vllm/lib/python3.10/site-packages/vllm-0.5.3+xpu-py3.10.egg/vllm/executor/ray_xpu_executor.py", line 312, in _run_workers
    driver_worker_output = self.driver_worker.execute_method(
  File "/home/raffenet/.conda/envs/vllm/lib/python3.10/site-packages/vllm-0.5.3+xpu-py3.10.egg/vllm/worker/worker_base.py", line 383, in execute_method
    raise e
  File "/home/raffenet/.conda/envs/vllm/lib/python3.10/site-packages/vllm-0.5.3+xpu-py3.10.egg/vllm/worker/worker_base.py", line 374, in execute_method
    return executor(*args, **kwargs)
  File "/home/raffenet/.conda/envs/vllm/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
TypeError: WorkerBase.start_worker_execution_loop() got an unexpected keyword argument 'async_run_tensor_parallel_workers_only'
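This error indicates that an executor-level flag, `async_run_tensor_parallel_workers_only`, was forwarded through `_run_workers` into the worker method itself. The following hypothetical sketch shows that pattern and one way to avoid it: declare the flag as an explicit parameter of the dispatcher so it is consumed there. `Worker`, `execute_method`, and both `run_workers*` functions are stand-ins, not vLLM's actual classes:

```python
from typing import Any


class Worker:
    """Hypothetical stand-in for a worker whose loop method takes no flags."""

    def start_worker_execution_loop(self) -> str:
        return "loop started"


def execute_method(worker: Any, method: str, *args: Any, **kwargs: Any) -> Any:
    # Look the method up by name and call it with whatever was forwarded.
    return getattr(worker, method)(*args, **kwargs)


def run_workers_buggy(worker: Any, method: str, *args: Any, **kwargs: Any) -> Any:
    # Executor-level flags stay inside **kwargs and leak into the worker call.
    return execute_method(worker, method, *args, **kwargs)


def run_workers(worker: Any, method: str, *args: Any,
                async_run_tensor_parallel_workers_only: bool = False,
                **kwargs: Any) -> Any:
    # Naming the flag as a parameter consumes it here, so it never reaches
    # the worker. A real executor would use it to decide which workers to
    # start asynchronously; this sketch only strips it.
    return execute_method(worker, method, *args, **kwargs)


w = Worker()
print(run_workers(w, "start_worker_execution_loop",
                  async_run_tensor_parallel_workers_only=True))
# loop started
```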
raffenet commented 4 months ago

@jikunshang are these issues addressed in https://github.com/vllm-project/vllm/pull/5685?

jikunshang commented 4 months ago

Yes, I have fixed the tensor-parallel support issue; please try this PR.

raffenet commented 3 months ago

I have tested it on my system and it does indeed work with tp>1. Thanks! I hope it can be merged and made available in a future release.

raffenet commented 3 months ago

@jikunshang another bit of info: running llama-2-7b with tensor parallel sizes 2 and 4 works on my system, but on the same system, running llama-3-8b with tp=2 results in the error below. Is there anything I should try?

Traceback (most recent call last):
  File "/home/raffenet/proj/ipex-vllm/benchmark-scripts/offline_inference.py", line 87, in <module>
    llm = LLM(model=args.model,
  File "/home/raffenet/.conda/envs/vllm/lib/python3.10/site-packages/vllm-0.5.2+xpu-py3.10.egg/vllm/entrypoints/llm.py", line 156, in __init__
    self.llm_engine = LLMEngine.from_engine_args(
  File "/home/raffenet/.conda/envs/vllm/lib/python3.10/site-packages/vllm-0.5.2+xpu-py3.10.egg/vllm/engine/llm_engine.py", line 444, in from_engine_args
    engine = cls(
  File "/home/raffenet/.conda/envs/vllm/lib/python3.10/site-packages/vllm-0.5.2+xpu-py3.10.egg/vllm/engine/llm_engine.py", line 264, in __init__
    self._initialize_kv_caches()
  File "/home/raffenet/.conda/envs/vllm/lib/python3.10/site-packages/vllm-0.5.2+xpu-py3.10.egg/vllm/engine/llm_engine.py", line 363, in _initialize_kv_caches
    self.model_executor.determine_num_available_blocks())
  File "/home/raffenet/.conda/envs/vllm/lib/python3.10/site-packages/vllm-0.5.2+xpu-py3.10.egg/vllm/executor/distributed_gpu_executor.py", line 38, in determine_num_available_blocks
    num_blocks = self._run_workers("determine_num_available_blocks", )
  File "/home/raffenet/.conda/envs/vllm/lib/python3.10/site-packages/vllm-0.5.2+xpu-py3.10.egg/vllm/executor/ray_gpu_executor.py", line 371, in _run_workers
    self.driver_worker.execute_method(method, *driver_args,
  File "/home/raffenet/.conda/envs/vllm/lib/python3.10/site-packages/vllm-0.5.2+xpu-py3.10.egg/vllm/worker/worker_base.py", line 382, in execute_method
    raise e
  File "/home/raffenet/.conda/envs/vllm/lib/python3.10/site-packages/vllm-0.5.2+xpu-py3.10.egg/vllm/worker/worker_base.py", line 373, in execute_method
    return executor(*args, **kwargs)
  File "/home/raffenet/.conda/envs/vllm/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/raffenet/.conda/envs/vllm/lib/python3.10/site-packages/vllm-0.5.2+xpu-py3.10.egg/vllm/worker/xpu_worker.py", line 129, in determine_num_available_blocks
    self.model_runner.profile_run()
  File "/home/raffenet/.conda/envs/vllm/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/raffenet/.conda/envs/vllm/lib/python3.10/site-packages/vllm-0.5.2+xpu-py3.10.egg/vllm/worker/xpu_model_runner.py", line 223, in profile_run
    self.execute_model(model_input, kv_caches, intermediate_tensors)
  File "/home/raffenet/.conda/envs/vllm/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/raffenet/.conda/envs/vllm/lib/python3.10/site-packages/vllm-0.5.2+xpu-py3.10.egg/vllm/worker/xpu_model_runner.py", line 375, in execute_model
    hidden_or_intermediate_states = model_executable(
  File "/home/raffenet/.conda/envs/vllm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/raffenet/.conda/envs/vllm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/raffenet/.conda/envs/vllm/lib/python3.10/site-packages/vllm-0.5.2+xpu-py3.10.egg/vllm/model_executor/models/llama.py", line 420, in forward
    model_output = self.model(input_ids, positions, kv_caches,
  File "/home/raffenet/.conda/envs/vllm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/raffenet/.conda/envs/vllm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/raffenet/.conda/envs/vllm/lib/python3.10/site-packages/vllm-0.5.2+xpu-py3.10.egg/vllm/model_executor/models/llama.py", line 320, in forward
    hidden_states, residual = layer(
  File "/home/raffenet/.conda/envs/vllm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/raffenet/.conda/envs/vllm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/raffenet/.conda/envs/vllm/lib/python3.10/site-packages/vllm-0.5.2+xpu-py3.10.egg/vllm/model_executor/models/llama.py", line 243, in forward
    hidden_states = self.self_attn(
  File "/home/raffenet/.conda/envs/vllm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/raffenet/.conda/envs/vllm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/raffenet/.conda/envs/vllm/lib/python3.10/site-packages/vllm-0.5.2+xpu-py3.10.egg/vllm/model_executor/models/llama.py", line 172, in forward
    q, k = self.rotary_emb(positions, q, k)
  File "/home/raffenet/.conda/envs/vllm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/raffenet/.conda/envs/vllm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/raffenet/.conda/envs/vllm/lib/python3.10/site-packages/vllm-0.5.2+xpu-py3.10.egg/vllm/model_executor/custom_op.py", line 13, in forward
    return self._forward_method(*args, **kwargs)
  File "/home/raffenet/.conda/envs/vllm/lib/python3.10/site-packages/vllm-0.5.2+xpu-py3.10.egg/vllm/model_executor/layers/rotary_embedding.py", line 243, in forward_xpu
    ops.rotary_embedding(positions, query, key, self.head_size,
  File "/home/raffenet/.conda/envs/vllm/lib/python3.10/site-packages/vllm-0.5.2+xpu-py3.10.egg/vllm/_ipex_ops.py", line 158, in rotary_embedding
    ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos,
  File "/home/raffenet/.conda/envs/vllm/lib/python3.10/site-packages/intel_extension_for_pytorch/llm/functional/fusions.py", line 47, in rotary_embedding
    return RotaryEmbedding.apply_function(
  File "/home/raffenet/.conda/envs/vllm/lib/python3.10/site-packages/intel_extension_for_pytorch/llm/modules/mha_fusion.py", line 119, in apply_function
    query, key = runtime_module.rotary_embedding(
  File "/home/raffenet/.conda/envs/vllm/lib/python3.10/site-packages/intel_extension_for_pytorch/transformers/models/xpu/fusions/mha_fusion.py", line 79, in rotary_embedding
    torch.ops.torch_ipex.apply_rotary_embedding_half_qk(
  File "/home/raffenet/.conda/envs/vllm/lib/python3.10/site-packages/torch/_ops.py", line 692, in __call__
    return self._op(*args, **kwargs or {})
RuntimeError: The size of tensor a (2) must match the size of tensor b (8) at non-singleton dimension 2
jikunshang commented 3 months ago

@raffenet Thanks for your evaluation. The current IPEX kernel does not support GQA (grouped-query attention) yet, which is widely used in recent models such as llama3-8b and llama-2-70b (I am not certain about llama-2-13b). We have verified GQA functionality with an internal IPEX build. The next IPEX release will fix this issue and is expected at the end of this month.
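Under GQA, a layer has fewer key/value heads than query heads. After tensor-parallel sharding, the per-rank q and k tensors therefore differ in their head dimension, and a fused kernel that assumes they match fails with a size mismatch like the RuntimeErrors in this thread. Llama-3-8B, for instance, has 32 query heads and 8 KV heads. In the sketch below, the 28-query/4-KV configuration is an assumption chosen because at tp=2 it yields 14 and 2 heads per rank, consistent with the "2 vs 14" mismatch in the Qwen2 report further down; the thread does not confirm that pairing. `rotary` is a plain-Python NeoX-style ("rotate half") RoPE for illustration, not IPEX's kernel:

```python
import math
from typing import List, Tuple

Vector = List[float]


def heads_per_rank(num_q_heads: int, num_kv_heads: int,
                   tp_size: int) -> Tuple[int, int]:
    # Tensor parallelism shards attention heads evenly across ranks.
    # This sketch assumes both head counts divide evenly by tp_size.
    if num_q_heads % tp_size or num_kv_heads % tp_size:
        raise ValueError("head counts must be divisible by tp_size")
    return num_q_heads // tp_size, num_kv_heads // tp_size


def rotary(vec: Vector, pos: int, base: float = 10000.0) -> Vector:
    # NeoX-style ("rotate half") RoPE for one head: element i is paired with
    # element i + d/2 and the pair is rotated by pos * base**(-2i/d).
    d = len(vec)
    half = d // 2
    out = list(vec)
    for i in range(half):
        theta = pos * base ** (-2 * i / d)
        c, s = math.cos(theta), math.sin(theta)
        x1, x2 = vec[i], vec[i + half]
        out[i] = x1 * c - x2 * s
        out[i + half] = x2 * c + x1 * s
    return out


def apply_rotary_qk(q_heads: List[Vector], k_heads: List[Vector],
                    pos: int) -> Tuple[List[Vector], List[Vector]]:
    # Rotate q and k head by head, so nothing requires the two lists to
    # have the same length; under GQA they usually do not.
    return ([rotary(h, pos) for h in q_heads],
            [rotary(h, pos) for h in k_heads])


n_q, n_kv = heads_per_rank(28, 4, tp_size=2)
q = [[1.0] * 8 for _ in range(n_q)]
k = [[1.0] * 8 for _ in range(n_kv)]
q_rot, k_rot = apply_rotary_qk(q, k, pos=3)
print(len(q_rot), len(k_rot))  # 14 2
```

Because the rotation is applied independently per head, the q and k head counts never have to agree, which is the property a GQA-capable kernel needs.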

zhouyuan commented 3 months ago

CC @jgong5 @rogerxfeng8

rogerxfeng8 commented 3 months ago

GQA support for vLLM will be available in the upcoming IPEX 2.3.110 release.

raffenet commented 3 months ago

Thanks! I look forward to trying it.

liuxingbin commented 3 months ago

Hi, when I run vllm-xpu with Qwen2, I hit the same GQA error:

  File "/workspace/vllm/vllm/_ipex_ops.py", line 158, in rotary_embedding
    ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos,
  File "/usr/local/lib/python3.10/dist-packages/intel_extension_for_pytorch/llm/functional/fusions.py", line 47, in rotary_embedding
    return RotaryEmbedding.apply_function(
  File "/usr/local/lib/python3.10/dist-packages/intel_extension_for_pytorch/llm/modules/mha_fusion.py", line 119, in apply_function
    query, key = runtime_module.rotary_embedding(
  File "/usr/local/lib/python3.10/dist-packages/intel_extension_for_pytorch/transformers/models/xpu/fusions/mha_fusion.py", line 79, in rotary_embedding
    torch.ops.torch_ipex.apply_rotary_embedding_half_qk(
  File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 692, in __call__
    return self._op(*args, **kwargs or {})
RuntimeError: The size of tensor a (2) must match the size of tensor b (14) at non-singleton dimension 2

I am wondering when IPEX 2.3.110 will be released.

jikunshang commented 2 months ago

IPEX 2.3 has been released and the related code changes are merged into the vLLM main branch, so the issues in this thread should be resolved. Please give it a try, thanks!
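To confirm that an environment has the GQA-capable IPEX before retrying, you can compare the installed version against 2.3.110, the release named in this thread. Below is a minimal sketch that uses the standard-library `importlib.metadata`. The distribution name is the one shown by `pip freeze` later in this thread, so adjust it if your install reports it differently. The version parsing is deliberately simplified: a pre-release such as "2.3.110a0" would be treated as satisfying the check.

```python
import re
from importlib import metadata
from typing import Optional, Tuple

# Per this thread, GQA support for vLLM lands in IPEX 2.3.110.
MIN_GQA_IPEX: Tuple[int, ...] = (2, 3, 110)


def release_tuple(version: str) -> Tuple[int, ...]:
    # Keep only the leading major.minor.patch digits, dropping suffixes
    # such as "a0" (pre-release) or "+xpu" (local build tag).
    m = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if m is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in m.groups())


def ipex_has_gqa(version: str) -> bool:
    return release_tuple(version) >= MIN_GQA_IPEX


def installed_ipex_version() -> Optional[str]:
    # Returns None when IPEX is not installed in the current environment.
    try:
        return metadata.version("intel_extension_for_pytorch")
    except metadata.PackageNotFoundError:
        return None


print(ipex_has_gqa("2.1.30a0"))     # False (version in the original report)
print(ipex_has_gqa("2.3.110+xpu"))  # True (version in the Docker image below)
```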

laserkelvin commented 1 month ago

I don't want to hijack this issue, but I'm hitting the same problem as in the title. I'm running tag v0.6.2 but building with the current main-branch Dockerfile.xpu (to make sure IPEX 2.3 is used), and it still fails on my end. The build works, but the following command fails under docker compose:

services:
  vllm-server:
    build:
      context: ./vllm
      dockerfile: Dockerfile.xpu
    container_name: vllm-server
    ports:
      - "8000:8000"
    restart: unless-stopped
    devices:
      - /dev/dri:/dev/dri
    volumes:
      - /dev/dri/by-path:/dev/dri/by-path
    network_mode: host
    ipc: host
    shm_size: 17179869184  # 16 GB
    command: --model Qwen/Qwen2.5-72B-Instruct --device xpu --tensor-parallel-size 2 

The error I'm getting:

INFO 10-02 23:12:43 api_server.py:164] Multiprocessing frontend to use ipc:///tmp/e1bb2060-ef4c-401c-b46f-661022722fe5 for IPC Path.
INFO 10-02 23:12:43 api_server.py:177] Started engine process with PID 76
INFO 10-02 23:12:44 config.py:899] Defaulting to use mp for distributed inference
WARNING 10-02 23:12:44 config.py:376] Async output processing is only supported for CUDA or TPU. Disabling it for other platforms.
WARNING 10-02 23:12:44 _custom_ops.py:18] Failed to import from vllm._C with ModuleNotFoundError("No module named 'vllm._C'")
Process SpawnProcess-1:
INFO 10-02 23:12:45 config.py:899] Defaulting to use mp for distributed inference
WARNING 10-02 23:12:45 config.py:376] Async output processing is only supported for CUDA or TPU. Disabling it for other platforms.
ERROR 10-02 23:12:45 llm_engine.py:530] Both start methods (spawn and fork) have issue on XPU if you use mp backend, Please try ray instead.
Traceback (most recent call last):
  File "/usr/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/usr/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/workspace/vllm/vllm/engine/multiprocessing/engine.py", line 388, in run_mp_engine
    engine = MQLLMEngine.from_engine_args(engine_args=engine_args,
  File "/workspace/vllm/vllm/engine/multiprocessing/engine.py", line 136, in from_engine_args
    executor_class = LLMEngine._get_executor_cls(engine_config)
  File "/workspace/vllm/vllm/engine/llm_engine.py", line 550, in _get_executor_cls
    return executor_class
UnboundLocalError: local variable 'executor_class' referenced before assignment

Snippet of pip freeze from within the image:

pip3 freeze | grep intel
intel-cmplr-lib-rt==2024.2.1
intel-cmplr-lib-ur==2024.2.1
intel-cmplr-lic-rt==2024.2.1
intel-opencl-rt==2024.2.1
intel-openmp==2024.2.1
intel-sycl-rt==2024.2.1
intel_extension_for_pytorch==2.3.110+xpu
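The log shows the engine detecting the unsupported XPU + mp combination and logging an error, then crashing on `return executor_class`. This suggests that the error branch logs without assigning or raising. The hypothetical sketch below reproduces that control-flow pattern and contrasts it with a fail-fast alternative. Executor names are plain strings standing in for classes, and neither function is vLLM's actual `_get_executor_cls`:

```python
from typing import Optional


def get_executor_cls_buggy(device: str, backend: Optional[str]) -> str:
    if device == "xpu":
        if backend == "ray":
            executor_class = "RayXPUExecutor"
        elif backend == "mp":
            # Only logs the problem; nothing is assigned on this path.
            print("ERROR: the mp backend has issues on XPU, please try ray instead.")
        else:
            executor_class = "XPUExecutor"
    else:
        executor_class = "GPUExecutor"
    return executor_class  # UnboundLocalError when device="xpu", backend="mp"


def get_executor_cls(device: str, backend: Optional[str]) -> str:
    if device == "xpu":
        if backend == "ray":
            return "RayXPUExecutor"
        if backend == "mp":
            # Fail with an actionable message instead of falling through.
            raise ValueError("The mp backend is not supported on XPU; "
                             "set distributed_executor_backend='ray'.")
        return "XPUExecutor"
    return "GPUExecutor"
```

Raising at the point of detection turns a confusing `UnboundLocalError` into an error that names the workaround.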
jikunshang commented 1 month ago

Thanks for reporting this. Please try this PR: https://github.com/vllm-project/vllm/pull/8884. Alternatively, set distributed_executor_backend to ray; the default value, mp, is not supported on XPU.
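In the docker compose setup above, the ray workaround amounts to appending `--distributed-executor-backend ray` (the command-line form of vLLM's `distributed_executor_backend` engine argument) to the `command:` line. In offline Python code, the same setting would be passed as `distributed_executor_backend="ray"` to `LLM(...)`. The small sketch below assembles the server argument list. `build_server_args` is a hypothetical helper, and the model and device values are the ones from the report:

```python
import shlex
from typing import List


def build_server_args(model: str, tp_size: int, device: str = "xpu",
                      backend: str = "ray") -> List[str]:
    # Hypothetical helper: the arguments from the compose file's `command:`
    # line, plus an explicit distributed executor backend.
    return [
        "--model", model,
        "--device", device,
        "--tensor-parallel-size", str(tp_size),
        "--distributed-executor-backend", backend,
    ]


args = build_server_args("Qwen/Qwen2.5-72B-Instruct", 2)
print(shlex.join(args))
# --model Qwen/Qwen2.5-72B-Instruct --device xpu --tensor-parallel-size 2 --distributed-executor-backend ray
```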