vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Bug]: LLM initialization time increases significantly with larger tensor parallel size and Ray #10283

Open piood opened 1 week ago

piood commented 1 week ago

Your current environment

vllm 0.5.2

The output of `python collect_env.py`

```text
Collecting environment information...
PyTorch version: 2.3.1+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: version 3.24.1
Libc version: glibc-2.31

Python version: 3.8.10 (default, Mar 13 2023, 10:26:41) [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.10.134-008.7.kangaroo.al8.x86_64-x86_64-with-glibc2.29
Is CUDA available: True
CUDA runtime version: 12.1.66
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA L20Z
GPU 1: NVIDIA L20Z
GPU 2: NVIDIA L20Z
GPU 3: NVIDIA L20Z
GPU 4: NVIDIA L20Z
GPU 5: NVIDIA L20Z
GPU 6: NVIDIA L20Z
GPU 7: NVIDIA L20Z

Nvidia driver version: 535.161.08
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 52 bits physical, 57 bits virtual
CPU(s): 100
On-line CPU(s) list: 0-99
Thread(s) per core: 1
Core(s) per socket: 100
Socket(s): 1
NUMA node(s): 1
Vendor ID: GenuineIntel
CPU family: 6
Model: 143
Model name: Intel(R) Xeon(R) Processor
Stepping: 8
CPU MHz: 2000.000
BogoMIPS: 4000.00
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 4.7 MiB
L1i cache: 3.1 MiB
L2 cache: 200 MiB
L3 cache: 105 MiB
NUMA node0 CPU(s): 0-99
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1: Vulnerable: __user pointer sanitization and usercopy barriers only; no swapgs barriers
Vulnerability Spectre v2: Vulnerable, IBPB: disabled, STIBP: disabled, PBRSB-eIBRS: Vulnerable
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx_vnni avx512_bf16 wbnoinvd avx512vbmi umip pku waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid cldemote movdiri movdir64b fsrm md_clear serialize tsxldtrk avx512_fp16 arch_capabilities

Versions of relevant libraries:
[pip3] flake8==6.1.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.22.2
[pip3] onnx==1.13.1
[pip3] pytorch-quantization==2.1.2
[pip3] torch==2.3.1
[pip3] torch-tensorrt==1.4.0.dev0
[pip3] torchaudio==2.3.1
[pip3] torchtext==0.13.0a0+fae8e8c
[pip3] torchtyping==0.1.4
[pip3] torchvision==0.18.1
[pip3] triton==2.3.1
[conda] Could not collect
```

Model Input Dumps

N/A; this only measures the vLLM initialization time.

🐛 Describe the bug

Issue Description

We observed a significant and unexpected increase in vLLM initialization time when scaling tensor parallelism (TP), especially with Ray enabled.

Observed Behavior

Expected Behavior

Initialization time should remain roughly constant, or increase only modestly, when scaling tensor parallelism and when using Ray.

Environment

Additional Context

The initialization time increase appears disproportionate to the tensor parallel size, suggesting a potential bottleneck in the initialization process, particularly when Ray is involved.

Reproducible Steps

  1. Run vLLM with TP=1
  2. Run vLLM with TP=4
  3. Run vLLM with TP=4 and Ray enabled
  4. Compare the initialization times
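The comparison in step 4 can be scripted so that each configuration runs in a fresh Python process, which keeps one engine's GPU memory or distributed state from affecting the next measurement. This is a minimal sketch, not part of vLLM: `time_in_fresh_process` and `compare` are hypothetical helper names, and only the statement under test is timed (interpreter start-up and the `setup` imports are excluded, matching the `time.perf_counter()` window in the scripts below).

```python
import subprocess
import sys
import textwrap
from typing import Dict

# Child-process template: run the setup, then time only the statement under
# test and print the elapsed seconds on a tagged line.
_CHILD = textwrap.dedent(
    """
    import time
    {setup}
    _start = time.perf_counter()
    {stmt}
    print(f"INIT_SECONDS={{time.perf_counter() - _start}}")
    """
)


def time_in_fresh_process(setup: str, stmt: str) -> float:
    """Run `setup`, then time `stmt`, in a new Python interpreter."""
    code = _CHILD.format(setup=setup, stmt=stmt)
    out = subprocess.run(
        [sys.executable, "-c", code], check=True, capture_output=True, text=True
    ).stdout
    # Other output may be printed too; keep only the last tagged timing line.
    tagged = [line for line in out.splitlines() if line.startswith("INIT_SECONDS=")]
    return float(tagged[-1].split("=", 1)[1])


def compare(configs: Dict[str, str], setup: str) -> Dict[str, float]:
    """Time each statement in its own process; returns seconds per config name."""
    return {name: time_in_fresh_process(setup, stmt) for name, stmt in configs.items()}
```

For example, `compare({"tp4_ray": 'LLM(model=MODEL, tensor_parallel_size=4, distributed_executor_backend="ray")'}, setup='from vllm import LLM; MODEL = "<model-path>"')` would time the TP=4 + Ray case, where `MODEL` is a placeholder for the model path.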

vLLM startup time (without Ray)

```python
import time
from typing import List, Optional, Tuple

from vllm.engine.arg_utils import EngineArgs


def run_vllm(
    requests: List[Tuple[str, int, int]],
    model: str,
    tokenizer: str,
    quantization: Optional[str],
    tensor_parallel_size: int,
    seed: int,
    n: int,
    use_beam_search: bool,
    trust_remote_code: bool,
    dtype: str,
    max_model_len: Optional[int],
    enforce_eager: bool,
    kv_cache_dtype: str,
    quantization_param_path: Optional[str],
    device: str,
    enable_prefix_caching: bool,
    enable_chunked_prefill: bool,
    max_num_batched_tokens: int,
    distributed_executor_backend: Optional[str],
    gpu_memory_utilization: float = 0.9,
    num_scheduler_steps: int = 1,
    use_v2_block_manager: bool = False,
    download_dir: Optional[str] = None,
    load_format: str = EngineArgs.load_format,
    disable_async_output_proc: bool = False,
) -> float:
    # Import the required libraries
    from vllm import LLM

    print(f"Start initializing LLM at {time.strftime('%Y-%m-%d %H:%M:%S')}")
    start = time.perf_counter()
    llm = LLM(
        model=model,
        tokenizer=tokenizer,
        quantization=quantization,
        tensor_parallel_size=tensor_parallel_size,
        seed=seed,
        trust_remote_code=trust_remote_code,
        dtype=dtype,
        max_model_len=max_model_len,
        gpu_memory_utilization=gpu_memory_utilization,
        enforce_eager=enforce_eager,
        kv_cache_dtype=kv_cache_dtype,
        quantization_param_path=quantization_param_path,
        device=device,
        enable_prefix_caching=enable_prefix_caching,
        download_dir=download_dir,
        enable_chunked_prefill=enable_chunked_prefill,
        max_num_batched_tokens=max_num_batched_tokens,
        distributed_executor_backend=distributed_executor_backend,
        load_format=load_format,
        # num_scheduler_steps=num_scheduler_steps,
        # use_v2_block_manager=use_v2_block_manager,
        # disable_async_output_proc=disable_async_output_proc,
    )
    end = time.perf_counter()
    print(f"Finish initializing LLM at {time.strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"vllm init time: {end - start}")
    # Return the measured initialization time, as the signature promises.
    return end - start
```

vLLM startup time (with Ray)

```python
import time
from typing import List, Optional, Tuple

from vllm.engine.arg_utils import EngineArgs


def run_ray_vllm(
    requests: List[Tuple[str, int, int]],
    model: str,
    tokenizer: str,
    quantization: Optional[str],
    tensor_parallel_size: int,
    seed: int,
    n: int,
    use_beam_search: bool,
    trust_remote_code: bool,
    dtype: str,
    max_model_len: Optional[int],
    enforce_eager: bool,
    kv_cache_dtype: str,
    quantization_param_path: Optional[str],
    device: str,
    enable_prefix_caching: bool,
    enable_chunked_prefill: bool,
    max_num_batched_tokens: int,
    distributed_executor_backend: Optional[str],
    gpu_memory_utilization: float = 0.9,
    num_scheduler_steps: int = 1,
    use_v2_block_manager: bool = False,
    download_dir: Optional[str] = None,
    load_format: str = EngineArgs.load_format,
    disable_async_output_proc: bool = False,
) -> float:
    # Import the required libraries
    import ray
    from vllm import SamplingParams

    @ray.remote
    class LLMWorker:
        def __init__(
            self, model, tokenizer, quantization, tensor_parallel_size, seed,
            trust_remote_code, dtype, max_model_len, gpu_memory_utilization,
            enforce_eager, kv_cache_dtype, quantization_param_path, device,
            enable_prefix_caching, download_dir, enable_chunked_prefill,
            max_num_batched_tokens, distributed_executor_backend, load_format,
            num_scheduler_steps, use_v2_block_manager, disable_async_output_proc,
        ):
            from vllm import LLM

            # The engine is built inside the actor, so the init time is
            # measured and printed from the actor process.
            start = time.perf_counter()
            self.llm = LLM(
                model=model,
                tokenizer=tokenizer,
                quantization=quantization,
                tensor_parallel_size=tensor_parallel_size,
                seed=seed,
                trust_remote_code=trust_remote_code,
                dtype=dtype,
                max_model_len=max_model_len,
                gpu_memory_utilization=gpu_memory_utilization,
                enforce_eager=enforce_eager,
                kv_cache_dtype=kv_cache_dtype,
                quantization_param_path=quantization_param_path,
                device=device,
                enable_prefix_caching=enable_prefix_caching,
                download_dir=download_dir,
                enable_chunked_prefill=enable_chunked_prefill,
                max_num_batched_tokens=max_num_batched_tokens,
                distributed_executor_backend=distributed_executor_backend,
                load_format=load_format,
                # num_scheduler_steps=num_scheduler_steps,
                # use_v2_block_manager=use_v2_block_manager,
                # disable_async_output_proc=disable_async_output_proc,
            )
            end = time.perf_counter()
            print(f"Finish initializing LLM at {time.strftime('%Y-%m-%d %H:%M:%S')}")
            print(f"vllm init time: {end - start}")

        def generate(self, prompts, sampling_params):
            return self.llm.generate(prompts, sampling_params, use_tqdm=True)

    # Create the LLM worker; this returns immediately and __init__ runs in the actor
    worker = LLMWorker.remote(
        model=model,
        tokenizer=tokenizer,
        quantization=quantization,
        tensor_parallel_size=tensor_parallel_size,
        seed=seed,
        trust_remote_code=trust_remote_code,
        dtype=dtype,
        max_model_len=max_model_len,
        gpu_memory_utilization=gpu_memory_utilization,
        enforce_eager=enforce_eager,
        kv_cache_dtype=kv_cache_dtype,
        quantization_param_path=quantization_param_path,
        device=device,
        enable_prefix_caching=enable_prefix_caching,
        download_dir=download_dir,
        enable_chunked_prefill=enable_chunked_prefill,
        max_num_batched_tokens=max_num_batched_tokens,
        distributed_executor_backend=distributed_executor_backend,
        load_format=load_format,
        num_scheduler_steps=num_scheduler_steps,
        use_v2_block_manager=use_v2_block_manager,
        disable_async_output_proc=disable_async_output_proc,
    )

    # Add the requests to the engine.
    prompts: List[str] = []
    sampling_params: List[SamplingParams] = []
    for prompt, _, output_len in requests:
        prompts.append(prompt)
        sampling_params.append(
            SamplingParams(
                n=n,
                temperature=0.0 if use_beam_search else 1.0,
                top_p=1.0,
                use_beam_search=use_beam_search,
                ignore_eos=True,
                max_tokens=output_len,
            )
        )

    # Actor methods run only after __init__ finishes, so this measured window
    # can also include any initialization still in progress.
    start = time.perf_counter()
    ray.get(worker.generate.remote(prompts, sampling_params))
    end = time.perf_counter()
    return end - start
```

AaronWang04 commented 1 week ago

Someone correct me if I'm wrong, but the workers are initialized sequentially on the main process, as can be seen in the function linked below:

https://github.com/vllm-project/vllm/blob/bbd3e86926f15e59e4c62246b4b3185e71fe7ff2/vllm/executor/ray_gpu_executor.py#L109

Ray adds further overhead because the whole worker configuration has to be sent through Ray, which is a slower process.
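The cost of sequential initialization can be illustrated with a toy model. This is not vLLM's code: `init_worker` is a hypothetical stand-in that sleeps for a fixed time to represent independent per-worker setup, and the thread pool is just one way to overlap such work. Under that assumption, sequential start-up grows roughly linearly with the number of workers, while concurrent start-up stays close to the cost of a single worker.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def init_worker(rank: int, setup_seconds: float) -> int:
    """Hypothetical stand-in for one worker's independent setup work."""
    time.sleep(setup_seconds)  # blocking work that needs no other worker
    return rank


def init_sequential(n: int, setup_seconds: float) -> float:
    """Initialize n workers one after another; returns elapsed seconds."""
    start = time.perf_counter()
    for rank in range(n):
        init_worker(rank, setup_seconds)  # each call waits for the previous one
    return time.perf_counter() - start


def init_parallel(n: int, setup_seconds: float) -> float:
    """Start all n initializations at once and wait for every one to finish."""
    start = time.perf_counter()
    with ThreadPoolExecutor(max_workers=n) as pool:
        # Consuming the iterator waits until every call has completed.
        list(pool.map(lambda rank: init_worker(rank, setup_seconds), range(n)))
    return time.perf_counter() - start
```

With Ray actors, the analogous pattern is to issue all the `.remote()` calls first and then wait on the whole list with `ray.get`, rather than waiting on each call in turn. The toy model only covers work that can overlap; finding where vLLM's start-up time actually goes would still require profiling.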

piood commented 1 week ago


Thank you for your answer! However, I still have some concerns about the initialization overhead:

For a 7B model:

The overhead seems disproportionately large considering:

  1. The baseline initialization is only 7 seconds
  2. Moving from TP=1 to TP=4 doubles the initialization time
  3. Adding Ray introduces an additional 10s overhead, which is even larger than the TP scaling overhead

Is this level of overhead expected? It seems excessive for a 7B model, especially since:

Could there be potential optimization opportunities to reduce these initialization costs?

AaronWang04 commented 1 week ago

I don't find that overhead too strange. There is definitely room for optimization (e.g., parallelizing the process), but engine startup time is not really a metric people worry about. (Model reloading is probably what more people are interested in, and I believe it is not currently implemented.) Is there a reason you're looking for faster initialization?

Jack47 commented 1 week ago


Many thanks for your reply!

We want to improve the startup speed. IMHO, 34 s is still too long to wait, especially when we are developing new features and want to run some tests to verify them.