jradikk opened 2 months ago
Try setting nvidia.com/gpu: 1 instead of nvidia.com/gpu.shared: 1?
Unfortunately, I keep getting the same error even though the Ray cluster shows available GPUs.
I noticed that if I set ray_actor_options.num_gpus: 0.95, vLLM will show that I have 0.05 GPU:
ValueError: Current node has no GPU available. current_node_resource={'memory': 36000000000.0, 'GPU': 0.05....}.
Meaning, Ray for some reason subtracts the ray_actor_options.num_gpus value from the actual number of GPUs and passes the remainder to vLLM. Which in turn means that if I set ray_actor_options.num_gpus to 1, which I did, it will report 0 GPUs to vLLM and vLLM won't see a single GPU. Which is what seems to be happening. If all that is true, why is this happening?
Setting ray_actor_options.num_gpus to 2 makes the whole application just hang and do nothing, while setting it to 0 returns an error that CUDA_VISIBLE_DEVICES isn't set and therefore there are no GPUs.
Is it impossible to run distributed inference on a set of nodes with each of them having just 1 GPU?
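For reference, the accounting described above can be reproduced with a minimal sketch (illustrative, not taken from this issue, assuming a node with exactly 1 GPU): once an actor reserves num_gpus, Ray subtracts that amount from what it reports as available, which is roughly what vLLM's startup check sees.

# Minimal sketch: reserving num_gpus=0.95 on a 1-GPU node leaves 0.05 GPU
# "available", matching the 'GPU': 0.05 in the ValueError above.
import ray

ray.init()
print(ray.cluster_resources().get("GPU"))    # 1.0 -- total GPUs on the node

@ray.remote(num_gpus=0.95)
class Holder:
    def ready(self):
        return True

holder = Holder.remote()
ray.get(holder.ready.remote())
print(ray.available_resources().get("GPU"))  # ~0.05 -- what is left for other actors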
@jradikk we have an existing example with vLLM that I've tested in the past: https://github.com/ray-project/kuberay/blob/master/ray-operator/config/samples/vllm/ray-service.vllm.yaml
I ran into issues with ray_actor_options.num_gpus as well. It seems like setting tensor-parallel-size gets things working as expected. I think vLLM automatically allocates GPUs based on this flag. See https://github.com/ray-project/kuberay/blob/master/ray-operator/config/samples/vllm/serve.py and https://github.com/ray-project/kuberay/blob/master/ray-operator/config/samples/vllm/ray-service.vllm.yaml#L22 for examples.
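Roughly, the pattern in the linked sample looks like the sketch below (paraphrased, not the verbatim serve.py; the distributed_executor_backend line is an assumption): the parallelism degrees come in as env vars from serveConfigV2 and go straight into AsyncEngineArgs, and vLLM's Ray backend then requests the GPU workers it needs.

# Paraphrased sketch of the sample's configuration flow (request handling omitted).
import os

from ray import serve
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine


@serve.deployment(name="VLLMDeployment")
class VLLMDeployment:
    def __init__(self):
        args = AsyncEngineArgs(
            model=os.getenv("MODEL_ID", "meta-llama/Meta-Llama-3-8B-Instruct"),
            # Parallelism degrees come from the env_vars block in serveConfigV2.
            tensor_parallel_size=int(os.getenv("TENSOR_PARALLELISM", "1")),
            pipeline_parallel_size=int(os.getenv("PIPELINE_PARALLELISM", "1")),
            distributed_executor_backend="ray",  # assumption: let Ray place the workers
        )
        self.engine = AsyncLLMEngine.from_engine_args(args)


model = VLLMDeployment.bind()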
@andrewsykim If I don't specify num_gpus in the ray_actor_options, I get a CUDA_VISIBLE_DEVICES is set to empty string error. If I set it to 1 and set TENSOR_PARALLELISM=1, PIPELINE_PARALLEL_SIZE=2, I get the error I started with: vLLM engine cannot start without GPU. If I set TENSOR_PARALLELISM=2, I get an error about requesting more resources than the placement group has.
To recap, I have 2 nodes with 1 GPU each. Each of those GPUs is too small to load the model by itself, hence all I want is to run 2 replicas of vLLM over Ray to split the weights (the PIPELINE_PARALLEL_SIZE config) between the two. But I keep getting said errors. Am I misunderstanding the configs?
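As a quick cross-check of the recap above (illustrative arithmetic only, not from the issue): vLLM's distributed world size is the product of the two parallelism degrees, so with two 1-GPU nodes only TENSOR_PARALLELISM=1 with PIPELINE_PARALLEL_SIZE=2 adds up.

# Illustrative arithmetic: world size = tensor_parallel_size * pipeline_parallel_size,
# and Ray has to be able to place that many GPU workers.
nodes, gpus_per_node = 2, 1

def fits(tensor_parallel_size: int, pipeline_parallel_size: int) -> bool:
    world_size = tensor_parallel_size * pipeline_parallel_size
    # Rough check (assumption): total GPUs must cover the world size, and a
    # tensor-parallel group is easiest to place when it fits on one node.
    return (world_size <= nodes * gpus_per_node
            and tensor_parallel_size <= gpus_per_node)

print(fits(1, 2))  # True  -- 2 pipeline stages, one per 1-GPU node
print(fits(2, 2))  # False -- would need 4 GPUs, only 2 exist
print(fits(2, 1))  # False -- needs 2 GPUs on a single node here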
Ah thanks for clarifying, the example I linked uses a tensor parallel size of 2 but on a single GPU node. It doesn't use pipeline parallelism to do distributed inference across multiple nodes.
The guide for distributed serving (https://docs.vllm.ai/en/latest/serving/distributed_serving.html) mentions that pipeline parallelism is a beta feature, it might be better to open an issue in vLLM instead?
The docs also reference a sanity test script you can run:
@jradikk if you're able to share the full RayService YAML and source code you used, I'm happy to try to test this myself.
It would be great to add an example of vLLM using pipeline parallelism in the kuberay repo.
it might be better to open an issue in vLLM instead?
It was sort of a dilemma for me where to open an issue :) I ended up here because if I simply run python3 -m vllm.entrypoints.openai.api_server --port 8080 --served-model-name llama3.1:70b --model hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4 --max-model-len 4096 --tokenizer hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4 --dtype half -q marlin --tensor-parallel-size 1 --pipeline-parallel-size 2
in a Ray (non-KubeRay) cluster, it works. It works as a Job rather than as a Serve application, but it works: that command splits the model and loads it onto 2 different GPUs on 2 nodes.
But using the same config with KubeRay, I'm unable to run it. I can theoretically stick with a self-provisioned Ray cluster, but using KubeRay is definitely the preferred method.
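For context, the "works as a Job" path above can be driven through Ray's job submission API; a hypothetical sketch (the head address and runtime_env are placeholders, not taken from this setup):

# Hypothetical sketch of submitting the same command as a Ray job.
from ray.job_submission import JobSubmissionClient

client = JobSubmissionClient("http://<ray-head>:8265")  # placeholder address
job_id = client.submit_job(
    entrypoint=(
        "python3 -m vllm.entrypoints.openai.api_server --port 8080 "
        "--served-model-name llama3.1:70b "
        "--model hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4 "
        "--max-model-len 4096 "
        "--tokenizer hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4 "
        "--dtype half -q marlin --tensor-parallel-size 1 --pipeline-parallel-size 2"
    ),
    runtime_env={"pip": ["vllm"]},  # assumption; vLLM may already ship in the image
)
print(job_id)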
if you're able to share the full RayService YAML and source code you used
I did in the original post in this issue. It contains both the application and the RayServe config. Although, if you need me to post it in some other format, just let me know. I'm definitely out of ideas for how to make it work by myself :)
Can you share the full vllm_serve.py? Or is that included in the vllm-openai image already? Also, what GPU type did you use in your testing?
@andrewsykim oh, right, I forgot I truncated it. Here you go:
import json
from typing import AsyncGenerator
from fastapi import BackgroundTasks
from starlette.requests import Request
from starlette.responses import StreamingResponse, Response, JSONResponse
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid
from ray import serve
import os
import logging
from huggingface_hub import login

# Environment and configuration setup
logger = logging.getLogger("ray.serve")


@serve.deployment(name="llama3.1-deployment")
class VLLMDeployment:
    def __init__(self, **kwargs):
        hf_token = os.getenv("HUGGING_FACE_HUB_TOKEN")
        logger.info(f"token: {hf_token=}")
        if not hf_token:
            raise ValueError("HUGGING_FACE_HUB_TOKEN environment variable is not set")
        login(token=hf_token)
        logger.info("Successfully logged in to Hugging Face Hub")

        args = AsyncEngineArgs(
            model=os.getenv("MODEL_ID", "hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4"),  # Model identifier from Hugging Face Hub or local path.
            dtype="half",  # Use float16 for model weights and computations.
            gpu_memory_utilization=float(os.getenv("GPU_MEMORY_UTILIZATION", "0.9")),  # Fraction of GPU memory to utilize, reserving some for overhead.
            max_model_len=int(os.getenv("MAX_MODEL_LEN", "4096")),  # Maximum sequence length (in tokens) the model can handle, including both input and output tokens.
            max_num_seqs=int(os.getenv("MAX_NUM_SEQ", "512")),  # Maximum number of sequences (requests) to process in parallel.
            trust_remote_code=True,  # Allow execution of untrusted code from the model repository (use with caution).
            enable_chunked_prefill=False,  # Disable chunked prefill to avoid compatibility issues with prefix caching.
            tokenizer="hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4",
            max_parallel_loading_workers=1,  # Number of parallel workers to load the model concurrently.
            pipeline_parallel_size=int(os.getenv("PIPELINE_PARALLEL_SIZE", "2")),  # Number of pipeline parallelism stages (one per GPU node here).
            tensor_parallel_size=int(os.getenv("TENSOR_PARALLEL_SIZE", "1")),  # Number of tensor parallelism shards per stage.
            enable_prefix_caching=False,  # Disable prefix caching.
            enforce_eager=True,
            disable_log_requests=False,
            quantization='awq_marlin',
        )
        self.engine = AsyncLLMEngine.from_engine_args(args)
        self.max_model_len = args.max_model_len
        logger.info(f"VLLM Engine initialized with max_model_len: {self.max_model_len}")

    async def stream_results(self, results_generator) -> AsyncGenerator[bytes, None]:
        num_returned = 0
        async for request_output in results_generator:
            text_outputs = [output.text for output in request_output.outputs]
            assert len(text_outputs) == 1
            text_output = text_outputs[0][num_returned:]
            ret = {"text": text_output}
            yield (json.dumps(ret) + "\n").encode("utf-8")
            num_returned += len(text_output)

    async def may_abort_request(self, request_id) -> None:
        await self.engine.abort(request_id)

    async def __call__(self, request: Request) -> Response:
        try:
            request_dict = await request.json()
        except json.JSONDecodeError:
            return JSONResponse(status_code=400, content={"error": "Invalid JSON in request body"})

        context_length = request_dict.pop("context_length", 4096)  # Default to 4k
        # Ensure context length is either 4k or 32k
        if context_length not in [4096, 32768]:
            context_length = 4096  # Default to 4k if invalid

        prompt = request_dict.pop("prompt")
        stream = request_dict.pop("stream", False)

        # Get model config and tokenizer
        model_config = await self.engine.get_model_config()
        tokenizer = await self.engine.get_tokenizer()
        input_token_ids = tokenizer.encode(prompt)
        input_tokens = len(input_token_ids)
        max_possible_new_tokens = min(context_length, model_config.max_model_len) - input_tokens
        max_new_tokens = min(request_dict.get("max_tokens", 4096), max_possible_new_tokens)

        sampling_params = SamplingParams(
            max_tokens=max_new_tokens,
            temperature=request_dict.get("temperature", 0.7),
            top_p=request_dict.get("top_p", 0.9),
            top_k=request_dict.get("top_k", 50),
            stop=request_dict.get("stop", None),
        )

        request_id = random_uuid()
        logger.info(f"Processing request {request_id} with {input_tokens} input tokens")
        results_generator = self.engine.generate(prompt, sampling_params, request_id)

        if stream:
            background_tasks = BackgroundTasks()
            # Using background_tasks to abort the request
            # if the client disconnects.
            background_tasks.add_task(self.may_abort_request, request_id)
            return StreamingResponse(
                self.stream_results(results_generator), background=background_tasks
            )

        # Non-streaming case
        final_output = None
        async for request_output in results_generator:
            if await request.is_disconnected():
                # Abort the request if the client disconnects.
                await self.engine.abort(request_id)
                logger.warning(f"Client disconnected for request {request_id}")
                return Response(status_code=499)
            final_output = request_output

        assert final_output is not None
        prompt = final_output.prompt
        text_outputs = [prompt + output.text for output in final_output.outputs]
        ret = {"text": text_outputs}
        logger.info(f"Completed request {request_id}")
        return Response(content=json.dumps(ret))


deployment = VLLMDeployment.bind()
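A quick way to exercise this deployment once it is serving (a hypothetical smoke test; the host, port, and route prefix are assumptions, adjust them to your RayService setup):

# Hypothetical smoke test against the deployment above, assuming the Serve app
# is reachable at localhost:8000 with route_prefix "/".
import json
import requests

resp = requests.post(
    "http://localhost:8000/",
    json={"prompt": "What is pipeline parallelism?", "stream": False, "max_tokens": 128},
    timeout=300,
)
print(json.loads(resp.content)["text"][0])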
@jradikk I was able to get pipeline parallelism working using Llama 3 8B and 2 L4 GPUs across two different nodes, here's the YAML I used:
apiVersion: ray.io/v1
kind: RayService
metadata:
  name: llama-3-8b
spec:
  serveConfigV2: |
    applications:
    - name: llm
      route_prefix: /
      import_path: ray-operator.config.samples.vllm.serve:model
      deployments:
      - name: VLLMDeployment
        num_replicas: 1
        ray_actor_options:
          num_cpus: 8
          # NOTE: num_gpus is set automatically based on TENSOR_PARALLELISM
      runtime_env:
        working_dir: "https://github.com/andrewsykim/kuberay/archive/vllm-pipeline-parallelism.zip"
        pip: ["vllm==0.5.4"]
        env_vars:
          MODEL_ID: "meta-llama/Meta-Llama-3-8B-Instruct"
          TENSOR_PARALLELISM: "1"
          PIPELINE_PARALLELISM: "2"
  rayClusterConfig:
    headGroupSpec:
      rayStartParams:
        dashboard-host: '0.0.0.0'
      template:
        spec:
          containers:
          - name: ray-head
            image: rayproject/ray-ml:2.33.0.914af0-py311
            resources:
              limits:
                cpu: "2"
                memory: "8Gi"
              requests:
                cpu: "2"
                memory: "8Gi"
            ports:
            - containerPort: 6379
              name: gcs-server
            - containerPort: 8265
              name: dashboard
            - containerPort: 10001
              name: client
            - containerPort: 8000
              name: serve
            env:
            - name: HUGGING_FACE_HUB_TOKEN
              valueFrom:
                secretKeyRef:
                  name: hf-secret
                  key: hf_api_token
    workerGroupSpecs:
    - replicas: 2
      minReplicas: 0
      maxReplicas: 4
      groupName: gpu-group
      rayStartParams: {}
      template:
        spec:
          containers:
          - name: llm
            image: rayproject/ray-ml:2.33.0.914af0-py311
            env:
            - name: HUGGING_FACE_HUB_TOKEN
              valueFrom:
                secretKeyRef:
                  name: hf-secret
                  key: hf_api_token
            resources:
              limits:
                cpu: "8"
                memory: "20Gi"
                nvidia.com/gpu: "1"
              requests:
                cpu: "8"
                memory: "20Gi"
                nvidia.com/gpu: "1"
          # Please add the following taints to the GPU node.
          tolerations:
          - key: "nvidia.com/gpu"
            operator: "Exists"
            effect: "NoSchedule"
Prompt test:
$ curl http://localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '{
"model": "meta-llama/Meta-Llama-3-8B-Instruct",
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What are the top 5 most popular programming languages? Please be brief."}
],
"temperature": 0.7
}'
Handling connection for 8000
{"id":"chat-bda4940274ff40248a758b84473a33d2","object":"chat.completion","created":1725988067,"model":"meta-llama/Meta-Llama-3-8B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"Based on the TIOBE Index, the top 5 most popular programming languages are:\n\n1. JavaScript\n2. Python\n3. C++\n4. C#\n5. Java\n\nNote: The popularity of programming languages can vary depending on the source and the criteria used to measure popularity.","tool_calls":[]},"logprobs":null,"finish_reason":"stop","stop_reason":null}],"usage":{"prompt_tokens":36,"total_tokens":96,"completion_tokens":60}}
Ray dashboard showing both GPUs being used:
I also opened https://github.com/ray-project/kuberay/pull/2370 to update the example source code to handle configuration of --pipeline-parallel-size
I will also test with Llama 3.1 70B across 2 A100s and the example you shared
@andrewsykim I was able to run Llama 3.1 8B using your example. However, once I changed the model to hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4, I got the following error:
The deployment failed to start 6 times in a row. This may be due to a problem with its constructor or initial health check failing. See controller logs for details. Retrying after 8.0 seconds. Error:
ray::ServeReplica:llm:VLLMDeployment.initialize_and_get_metadata() (pid=15354, ip=10.233.87.150, actor_id=02cc85f77291d589afbbf65b01000000, repr=<ray.serve._private.replica.ServeReplica:llm:VLLMDeployment object at 0x7fd5f4792410>)
  File "/home/ray/anaconda3/lib/python3.11/concurrent/futures/_base.py", line 456, in result
    return self.__get_result()
           ^^^^^^^^^^^^^^^^^^^
  File "/home/ray/anaconda3/lib/python3.11/concurrent/futures/_base.py", line 401, in __get_result
    raise self._exception
    ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ray/anaconda3/lib/python3.11/site-packages/ray/serve/_private/replica.py", line 631, in initialize_and_get_metadata
    raise RuntimeError(traceback.format_exc()) from None
RuntimeError: Traceback (most recent call last):
  File "/home/ray/anaconda3/lib/python3.11/site-packages/ray/serve/_private/replica.py", line 609, in initialize_and_get_metadata
    await self._user_callable_wrapper.initialize_callable()
  File "/home/ray/anaconda3/lib/python3.11/site-packages/ray/serve/_private/replica.py", line 901, in initialize_callable
    await self._call_func_or_gen(
  File "/home/ray/anaconda3/lib/python3.11/site-packages/ray/serve/_private/replica.py", line 867, in _call_func_or_gen
    result = callable(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ray/anaconda3/lib/python3.11/site-packages/ray/serve/api.py", line 219, in __init__
    cls.__init__(self, *args, **kwargs)
  File "/tmp/ray/session_2024-09-11_05-57-04_783764_1/runtime_resources/working_dir_files/https_redacted/ray-operator/config/samples/vllm/serve.py", line 37, in __init__
    self.engine = AsyncLLMEngine.from_engine_args(engine_args)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ray/session_2024-09-11_05-57-04_783764_1/runtime_resources/pip/ee1b45ff5b20a2bd0131f4bec0625b8316794091/virtualenv/lib/python3.11/site-packages/vllm/engine/async_llm_engine.py", line 726, in from_engine_args
    engine_config = engine_args.create_engine_config()
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ray/session_2024-09-11_05-57-04_783764_1/runtime_resources/pip/ee1b45ff5b20a2bd0131f4bec0625b8316794091/virtualenv/lib/python3.11/site-packages/vllm/engine/arg_utils.py", line 792, in create_engine_config
    model_config = ModelConfig(
                   ^^^^^^^^^^^^
  File "/tmp/ray/session_2024-09-11_05-57-04_783764_1/runtime_resources/pip/ee1b45ff5b20a2bd0131f4bec0625b8316794091/virtualenv/lib/python3.11/site-packages/vllm/config.py", line 238, in __init__
    self._verify_quantization()
  File "/tmp/ray/session_2024-09-11_05-57-04_783764_1/runtime_resources/pip/ee1b45ff5b20a2bd0131f4bec0625b8316794091/virtualenv/lib/python3.11/site-packages/vllm/config.py", line 297, in _verify_quantization
    quantization_override = method.override_quantization_method(
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ray/session_2024-09-11_05-57-04_783764_1/runtime_resources/pip/ee1b45ff5b20a2bd0131f4bec0625b8316794091/virtualenv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/awq_marlin.py", line 82, in override_quantization_method
    can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ray/session_2024-09-11_05-57-04_783764_1/runtime_resources/pip/ee1b45ff5b20a2bd0131f4bec0625b8316794091/virtualenv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/awq_marlin.py", line 127, in is_awq_marlin_compatible
    return check_marlin_supported(quant_type=cls.TYPE_MAP[num_bits],
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ray/session_2024-09-11_05-57-04_783764_1/runtime_resources/pip/ee1b45ff5b20a2bd0131f4bec0625b8316794091/virtualenv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/utils/marlin_utils.py", line 78, in check_marlin_supported
    cond, _ = _check_marlin_supported(quant_type, group_size, has_zp,
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ray/session_2024-09-11_05-57-04_783764_1/runtime_resources/pip/ee1b45ff5b20a2bd0131f4bec0625b8316794091/virtualenv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/utils/marlin_utils.py", line 55, in _check_marlin_supported
    major, minor = current_platform.get_device_capability()
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ray/session_2024-09-11_05-57-04_783764_1/runtime_resources/pip/ee1b45ff5b20a2bd0131f4bec0625b8316794091/virtualenv/lib/python3.11/site-packages/vllm/platforms/cuda.py", line 101, in get_device_capability
    physical_device_id = device_id_to_physical_device_id(device_id)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ray/session_2024-09-11_05-57-04_783764_1/runtime_resources/pip/ee1b45ff5b20a2bd0131f4bec0625b8316794091/virtualenv/lib/python3.11/site-packages/vllm/platforms/cuda.py", line 88, in device_id_to_physical_device_id
    raise RuntimeError("CUDA_VISIBLE_DEVICES is set to empty string,"
RuntimeError: CUDA_VISIBLE_DEVICES is set to empty string, which means GPU support is disabled.
I have a feeling it is related to the quantization method being marlin.
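That reading is consistent with the traceback above: the awq_marlin compatibility check asks for the GPU's compute capability while the Serve replica process has no GPU assigned yet, so CUDA_VISIBLE_DEVICES is empty and the check raises, which is plausibly why only quantized models hit this. A rough paraphrase of the function that raises (vllm/platforms/cuda.py, device_id_to_physical_device_id, as named in the traceback; not the exact source):

# Rough paraphrase of the failing check, for illustration only.
import os

def device_id_to_physical_device_id(device_id: int) -> int:
    visible = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible == "":
        # Ray sets CUDA_VISIBLE_DEVICES="" for actors that reserve no GPU, so a
        # replica without num_gpus hits this branch as soon as the awq_marlin
        # compatibility check asks for the device's compute capability.
        raise RuntimeError("CUDA_VISIBLE_DEVICES is set to empty string,"
                           " which means GPU support is disabled.")
    if visible is not None:
        return int(visible.split(",")[device_id])
    return device_id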
@jradikk I tried with hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4 as well and saw the same issue. I'm trying with Llama 3.1 70B now, but I'm running into CUDA out of memory issues with 2 A100s.
I've tried hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4 and hugging-quants/Meta-Llama-3.1-70B-Instruct-GPTQ-INT4 with the same results. So it doesn't seem directly connected to the size or the quantization type of the model, but it might be related to the fact that the model is quantized at all.
Not sure if it's related, but why do you use -q marlin when not using KubeRay, while in your serve code you set quantization='awq_marlin'? Is there a notable difference between these two that would matter?
It's an artefact of me tweaking every possible config :) It doesn't really seem to matter. I tried awq_marlin, marlin, gptq (for a differently quantized model) and the result was the same
@jradikk I did some more testing with Llama 3.1 70B and was able to run it with --tensor-parallel-size=2 and --pipeline-parallel-size=4 using 8 A100 GPUs. So I think it's safe to say the issue is with quantization, as you mentioned already. I'm not sure why it would work as a Job but not as a Serve application with KubeRay. The python3 -m vllm.entrypoints.openai.api_server command and the serve code / config seem identical to me, but I'm not familiar enough with vLLM to identify any other notable differences.
I suggest opening an issue in the vLLM project to see if anyone there can identify notable differences in the two ways of deploying your model.
Search before asking
KubeRay Component
ray-operator, apiserver
What happened + What you expected to happen
I'm trying to launch a model with distributed inference using 2 worker pods with 1 GPU each.
I have successfully launched vLLM with Ray, using this guide from vLLM. This is the command I used to launch quantized llama3.1:70b: python3 -m vllm.entrypoints.openai.api_server --port 8080 --served-model-name llama3.1:70b --model hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4 --max-model-len 4096 --tokenizer hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4 --dtype half -q marlin --tensor-parallel-size 1 --pipeline-parallel-size 2. It successfully runs the model over 3 pods (ray head + 2 ray workers), where each worker has access to its own GPU (46GB).
However, I'm unable to do the same thing with KubeRay, launching the model via RayService. If I use the same parameters, I keep getting the following error, although I can see the GPU using nvidia-smi or even in the Cluster tab of the dashboard:
ValueError: Current node has no GPU available. current_node_resource={'node:10.233.87.77': 1.0, 'object_store_memory': 10762244505.0, 'accelerator_type:A40': 1.0, 'memory': 36000000000.0}. vLLM engine cannot start without GPU. Make sure you have at least 1 GPU available in a node current_node_id='2debcbc74912711d8a69aa26e5b7292968d13919f1a2002de2558007' current_ip='10.233.87.77'.
Reproduction script
Anything else
No response
Are you willing to submit a PR?