vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0
29.12k stars 4.35k forks

[Feature]: Support distributed serving with KubeRay's autoscaler #3522

Open TrafalgarZZZ opened 7 months ago

TrafalgarZZZ commented 7 months ago

🚀 The feature, motivation and pitch

Hi, I'm deploying vLLM distributed serving in a Kubernetes environment. To make it work, I installed KubeRay to manage the Ray cluster in Kubernetes. vLLM works well when the Ray cluster already has enough GPU resources. For example, if ray status reports 2 available GPUs, vLLM launches successfully with the following command:

python -m vllm.entrypoints.openai.api_server  --trust-remote-code --model /root/vllm-models/ --gpu-memory-utilization 0.95 --tensor-parallel-size 2

I also noticed that KubeRay supports autoscaling, so I would like to use it to save money on GPU instances.

What I expect is that when no GPUs are available in the KubeRay cluster, launching vLLM should trigger a scale-out of Ray worker pods that have GPUs, and then wait for the Ray cluster to schedule its RayLLMWorker actors. Instead, it fails with the following message:

$ python -m vllm.entrypoints.openai.api_server  --trust-remote-code --model /root/vllm-models/ --gpu-memory-utilization 0.95 --tensor-parallel-size 2
INFO 03-20 03:04:56 api_server.py:727] args: Namespace(host=None, port=8000, allow_credentials=False, allowed_origins=['*'], allowed_methods=['*'], allowed_headers=['*'], served_model_name=None, chat_template=None, response_role='assistant', ssl_keyfile=None, ssl_certfile=None, model='/root/vllm-models/', tokenizer=None, revision=None, tokenizer_revision=None, tokenizer_mode='auto', trust_remote_code=True, download_dir=None, load_format='auto', dtype='auto', max_model_len=None, worker_use_ray=False, pipeline_parallel_size=1, tensor_parallel_size=2, max_parallel_loading_workers=None, block_size=16, seed=0, swap_space=4, gpu_memory_utilization=0.95, max_num_batched_tokens=None, max_num_seqs=256, max_paddings=256, disable_log_stats=False, quantization=None, enforce_eager=False, max_context_len_to_capture=8192, engine_use_ray=False, disable_log_requests=False, max_log_len=None)
2024-03-20 03:04:56,828 INFO worker.py:1405 -- Using address 127.0.0.1:6379 set in the environment variable RAY_ADDRESS
2024-03-20 03:04:56,828 INFO worker.py:1540 -- Connecting to existing Ray cluster at address: 10.32.0.175:6379...
2024-03-20 03:04:56,832 INFO worker.py:1724 -- Connected to Ray cluster.
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/local/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/usr/local/lib/python3.10/site-packages/vllm/entrypoints/openai/api_server.py", line 737, in <module>
    engine = AsyncLLMEngine.from_engine_args(engine_args)
  File "/usr/local/lib/python3.10/site-packages/vllm/engine/async_llm_engine.py", line 497, in from_engine_args
    placement_group = initialize_cluster(parallel_config,
  File "/usr/local/lib/python3.10/site-packages/vllm/engine/ray_utils.py", line 112, in initialize_cluster
    raise ValueError(
ValueError: The number of required GPUs exceeds the total number of available GPUs in the cluster.

It looks like vLLM eagerly checks the GPUs currently present in the Ray cluster at start-up and fails fast, which makes it impossible to take advantage of KubeRay's autoscaling.
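To make the failure mode concrete, here is a minimal sketch of that start-up check reduced to a pure function, so it runs without a Ray cluster. The name eager_gpu_check and the resource numbers are hypothetical; the body mirrors the check commented out in the patch below, with a plain dict standing in for ray.cluster_resources(), which only counts GPUs on nodes that already exist.

```python
from typing import Dict


def eager_gpu_check(world_size: int, cluster_resources: Dict[str, float]) -> None:
    """Hypothetical stand-in for vLLM's start-up GPU check.

    cluster_resources plays the role of ray.cluster_resources(): it only
    reflects nodes that exist right now, not nodes the autoscaler could add.
    """
    num_gpus_in_cluster = cluster_resources.get("GPU", 0)
    if world_size > num_gpus_in_cluster:
        raise ValueError(
            "The number of required GPUs exceeds the total number of "
            "available GPUs in the cluster.")


# Scaled-down KubeRay cluster (example numbers): only the head pod, no GPU workers.
try:
    eager_gpu_check(world_size=2, cluster_resources={"CPU": 4.0})
except ValueError as exc:
    print(f"fails fast: {exc}")

# Once the autoscaler has added a 2-GPU worker pod, the same check passes.
eager_gpu_check(world_size=2, cluster_resources={"CPU": 12.0, "GPU": 2.0})
print("passes once the GPUs exist")
```

The check never gives the autoscaler a chance to act, because it raises before any resource request (such as a placement group) is made.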

Alternatives

I tried a few things. One simple solution is to remove the eager check on the Ray cluster's current GPU resources. My modified code looks like this:

from typing import TYPE_CHECKING, Optional

from vllm.config import ParallelConfig
from vllm.utils import is_hip

try:
    import ray
except ImportError:
    ray = None

if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup


def initialize_cluster(
    parallel_config: ParallelConfig,
    engine_use_ray: bool = False,
    ray_address: Optional[str] = None,
) -> Optional["PlacementGroup"]:
    """Initialize the distributed cluster, possibly with Ray.

    Args:
        parallel_config: The configurations for parallel execution.
        engine_use_ray: Whether to use Ray for async engine.
        ray_address: The address of the Ray cluster. If None, uses
            the default Ray cluster address.

    Returns:
        The placement group that includes the specification of the
        resources for each distributed worker, or None if Ray workers
        are not used.
    """
    if parallel_config.worker_use_ray or engine_use_ray:
        if ray is None:
            raise ImportError(
                "Ray is not installed. Please install Ray to use distributed "
                "serving.")
        # Connect to a ray cluster.
        if is_hip():
            ray.init(address=ray_address,
                     ignore_reinit_error=True,
                     num_gpus=parallel_config.world_size)
        else:
            ray.init(address=ray_address, ignore_reinit_error=True)

    if not parallel_config.worker_use_ray:
        assert parallel_config.world_size == 1, (
            "Ray is required if parallel_config.world_size > 1.")
        return None

    # Create placement group for worker processes
    current_placement_group = ray.util.get_current_placement_group()
    if current_placement_group:
        # We are in a placement group
        bundles = current_placement_group.bundle_specs
        # Verify that we can use the placement group.
        gpu_bundles = 0
        for bundle in bundles:
            bundle_gpus = bundle.get("GPU", 0)
            if bundle_gpus > 1:
                raise ValueError(
                    "Placement group bundle cannot have more than 1 GPU.")
            if bundle_gpus:
                gpu_bundles += 1
        if parallel_config.world_size > gpu_bundles:
            raise ValueError(
                "The number of required GPUs exceeds the total number of "
                "available GPUs in the placement group.")
    else:
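        # Removed: this eager check compares world_size against the GPUs
        # that exist right now, so it fails before the KubeRay autoscaler
        # gets a chance to add GPU worker pods.
        # num_gpus_in_cluster = ray.cluster_resources().get("GPU", 0)
        # if parallel_config.world_size > num_gpus_in_cluster:
        #     raise ValueError(
        #         "The number of required GPUs exceeds the total number of "
        #         "available GPUs in the cluster.")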
        # Create a new placement group
        placement_group_specs = ([{"GPU": 1}] * parallel_config.world_size)
        current_placement_group = ray.util.placement_group(
            placement_group_specs)
        # Wait until PG is ready - this will block until all
        # requested resources are available, and will timeout
        # if they cannot be provisioned.
        ray.get(current_placement_group.ready(), timeout=1800)

    return current_placement_group

With this change, it works well. But I think it's just a quick workaround, not a proper solution.

Additional context

No response

simon-mo commented 7 months ago

I think this might just be the right solution for KubeRay autoscaling. But let's cross-check with @pcmoritz and @Yard1.

baughmann commented 3 months ago

@simon-mo @pcmoritz @Yard1 What's up with this? It's really not fun to have to tear down and rebuild the engine every time the cluster needs to be resized. Even Ray Serve requires the GPU count to be known at startup.

Maybe there should be an engine arg to disable eager GPU count checks? Maybe also functions to add/remove GPUs?
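The engine-arg idea above could look roughly like the sketch below. GpuCheckPolicy, skip_eager_gpu_check and plan_startup are hypothetical names, not existing vLLM options; the sketch only models the decision. By default it keeps today's fail-fast behavior; with the flag set, it goes ahead and creates the placement group, which stays pending and is treated by the Ray autoscaler as resource demand.

```python
from dataclasses import dataclass


@dataclass
class GpuCheckPolicy:
    # Hypothetical engine option suggested in this thread; not a real vLLM arg.
    skip_eager_gpu_check: bool = False


def plan_startup(world_size: int, cluster_gpus: float, policy: GpuCheckPolicy) -> str:
    """Decide how start-up proceeds given the GPUs Ray currently reports."""
    if world_size <= cluster_gpus:
        # Enough GPUs already exist, so the placement group can be reserved now.
        return "reserve"
    if policy.skip_eager_gpu_check:
        # Create the placement group anyway and block on pg.ready(); while it
        # is pending, its bundles are demand the autoscaler can satisfy.
        return "wait"
    # Default: keep the current fail-fast behavior.
    raise ValueError(
        "The number of required GPUs exceeds the total number of "
        "available GPUs in the cluster.")


print(plan_startup(2, 0, GpuCheckPolicy(skip_eager_gpu_check=True)))  # wait
```

Keeping the default as fail-fast would preserve the current behavior for static clusters, where waiting could otherwise hang until the placement group timeout.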

simon-mo commented 3 months ago

@rkooo567

edoakes commented 3 months ago

The recommended pattern for this is to use Ray Serve with its replica placement group options. You can see it in the code sample in the link above:

    return VLLMDeployment.options(
        placement_group_bundles=pg_resources, placement_group_strategy="STRICT_PACK"
    ).bind(
        engine_args,
        parsed_args.response_role,
        parsed_args.lora_modules,
        parsed_args.chat_template,
    )

Here, we are telling Ray Serve what resources are required to schedule the replica, including any actors it may start (for vLLM, the GPU workers). Ray Serve first reserves the specified placement group and only then schedules the replica actor (for vLLM, the code running the engine). Therefore vLLM won't start until the required GPUs are available, and it won't run into the issue above.
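For reference, here is a sketch of how pg_resources might be built. The layout is an assumption modeled on the Ray Serve vLLM example: one CPU-only bundle for the replica that runs the engine, followed by one bundle per tensor-parallel rank for the GPU workers. build_pg_resources is a hypothetical helper name.

```python
from typing import Dict, List


def build_pg_resources(tensor_parallel_size: int, accelerator: str = "GPU") -> List[Dict[str, float]]:
    """Hypothetical helper producing bundles for placement_group_bundles.

    Assumed layout: bundle 0 hosts the Serve replica (the engine process),
    and each following bundle hosts one vLLM GPU worker actor.
    """
    bundles: List[Dict[str, float]] = [{"CPU": 1}]  # the replica actor itself
    for _ in range(tensor_parallel_size):
        bundles.append({"CPU": 1, accelerator: 1})  # one GPU worker per TP rank
    return bundles


pg_resources = build_pg_resources(tensor_parallel_size=2)
print(pg_resources)  # [{'CPU': 1}, {'CPU': 1, 'GPU': 1}, {'CPU': 1, 'GPU': 1}]
```

With placement_group_strategy="STRICT_PACK", all bundles must land on a single node, so for tensor_parallel_size=2 the autoscaler has to provide a node with at least 2 GPUs and 3 CPUs before the replica is scheduled.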

sunmac commented 2 months ago


Is it possible to merge this into vLLM? I'm hoping for official support for running vLLM with Ray Serve.