vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

Add worker registry service for hosting multiple vLLM models through a single API gateway #1753

Open tjtanaa opened 7 months ago

tjtanaa commented 7 months ago

I have been using the vLLM integration from FastChat to host multiple vLLM models. However, it does not offer the full capability of vLLM; e.g., it does not support beam search.

I would like to propose adding a system like the one in FastChat to serve vLLM workers so that all of vLLM's capabilities are exposed, e.g. greedy search, beam search, returning log_probs, etc.

I have finished modifying the chat completion endpoint. However, it is written specifically for vLLM workers, not the Transformers model_workers.

The preliminary code can be found at https://github.com/tjtanaa/fastchat-serve/tree/vllm-api-expose-tj

Custom OpenAI API server with best_of and use_beam_search support: https://github.com/tjtanaa/fastchat-serve/blob/vllm-api-expose-tj/src/fastchat_serve/openai_api_server_vllm.py

Modified vLLM worker that supports greedy search and beam search: https://github.com/tjtanaa/fastchat-serve/blob/vllm-api-expose-tj/src/fastchat_serve/vllm_worker.py
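
For a rough sense of what this would expose through a single gateway, a request to such a vLLM-backed OpenAI-compatible endpoint might look like the sketch below. The endpoint path, model name, and the handling of the extra fields (best_of, use_beam_search, logprobs) are assumptions based on vLLM's SamplingParams; the exact payload accepted by the linked custom server may differ.

```python
# Illustrative only: extra sampling fields beyond the standard OpenAI schema.
import requests  # assumes the requests package is installed

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",  # assumed gateway address
    json={
        "model": "my-vllm-worker",  # hypothetical worker/model name
        "messages": [{"role": "user", "content": "Summarize vLLM in one sentence."}],
        "n": 1,
        "best_of": 4,               # sample 4 candidates, return the best
        "use_beam_search": True,    # vLLM SamplingParams extension
        "temperature": 0.0,         # vLLM requires temperature 0 with beam search
        "logprobs": True,           # field name/type may differ per server
    },
)
print(resp.json())
```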

Would this be helpful to have in the vLLM repository? I have also opened an issue at https://github.com/lm-sys/FastChat/issues/2709, since I am not sure which repository is the better home for this feature.

hmellor commented 6 months ago

A potential solution (for multi-GPU machines) would be for an api_server to have multiple engines. Then, when a new request is received, the api_server sends the incoming prompt to the engine with the shortest waiting request queue.

This solution would:

Does this sound like a reasonable solution @WoosukKwon?
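
For illustration, here is a minimal sketch of that routing idea, assuming one AsyncLLMEngine per GPU living inside a single api_server process. The in-flight counter used as a proxy for the waiting-queue length and the EngineSlot helper are hypothetical, not existing vLLM APIs, and whether several engines can coexist cleanly in one process is exactly the open question here.

```python
# Sketch: route each new request to the least-loaded engine (assumed setup).
from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.utils import random_uuid

class EngineSlot:
    """One engine plus a counter approximating its waiting request queue."""
    def __init__(self, engine: AsyncLLMEngine):
        self.engine = engine
        self.in_flight = 0

# Hypothetical construction, e.g. one engine per model/GPU:
# slots = [EngineSlot(AsyncLLMEngine.from_engine_args(AsyncEngineArgs(model=m)))
#          for m in ("model-a", "model-b")]

async def generate(slots: list[EngineSlot], prompt: str) -> str:
    slot = min(slots, key=lambda s: s.in_flight)  # shortest queue wins
    slot.in_flight += 1
    try:
        final = None
        async for out in slot.engine.generate(
                prompt, SamplingParams(max_tokens=64), random_uuid()):
            final = out  # keep the last (finished) RequestOutput
        return final.outputs[0].text
    finally:
        slot.in_flight -= 1
```

An api_server handler would then call generate() for each incoming request, and the counter would keep traffic roughly balanced across engines without a separate worker registry.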

iamhappytoo commented 3 months ago

Hello,

I'm using Python multiprocessing to set up multiple vLLM engines for loading different models in a single api_server, with each engine residing in its own subprocess. The code snippets I used are:

```python
from torch.multiprocessing import Process, Queue, Event
from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.parallel_utils import cupy_utils

def completion_with_vllm(model, prompt, params, tensor_parallel_size):
    ......
    global_processes[model] = []
    global_task_queues[model] = Queue()
    global_result_queues[model] = Queue()
    for local_rank in range(1):  # each engine has a single subprocess
        engine_args = {'model': model_path,
                       'tensor_parallel_size': TP,  # TP: tensor_parallel_size
                       'gpu_memory_utilization': gpu_mem_util}
        p = Process(target=self._run_on_gpu_with_vllm,
                    args=(stop_event, global_task_queues[model],
                          global_result_queues[model], model, ENGINE_ARGS,
                          engine_args, local_rank, gpu_id))
        p.start()
        global_processes[model].append(p)

    new_task = {
        "model": model,
        "prompt": prompt,
        "params_dict": params
    }
    global_task_queues[model].put(new_task)

    all_completions = []
    item = global_result_queues[model].get(timeout=500)
    all_completions.append(item)
    ......

def _run_on_gpu_with_vllm(self, stop_event, task_queue, result_queue, model,
                          engine_args, local_rank, gpu_id):
    model_store[model] = LLM(engine_args)
    .......
```

When I load a model with TP > 1 (TP = 2, 4, etc.), I receive the following error:

```
Process Process-2:
Traceback (most recent call last):
  File "/xxx//llm_svr.py", line 333, in _run_on_gpu_with_vllm
    model_store[model] = LLM(engine_args)
  File "/scratch/zban/my-venv39-cuda12/lib/python3.9/site-packages/vllm/entrypoints/llm.py", line 109, in __init__
    self.llm_engine = LLMEngine.from_engine_args(engine_args)
  File "/scratch/zban/my-venv39-cuda12/lib/python3.9/site-packages/vllm/engine/llm_engine.py", line 391, in from_engine_args
    engine = cls(*engine_configs,
  File "/scratch/zban/my-venv39-cuda12/lib/python3.9/site-packages/vllm/engine/llm_engine.py", line 126, in __init__
    self._init_workers_ray(placement_group)
  File "/scratch/zban/my-venv39-cuda12/lib/python3.9/site-packages/vllm/engine/llm_engine.py", line 304, in _init_workers_ray
    self._run_workers("init_model",
  File "/scratch/zban/my-venv39-cuda12/lib/python3.9/site-packages/vllm/engine/llm_engine.py", line 1041, in _run_workers
    driver_worker_output = getattr(self.driver_worker,
  File "/scratch/zban/my-venv39-cuda12/lib/python3.9/site-packages/vllm/worker/worker.py", line 94, in init_model
    init_distributed_environment(self.parallel_config, self.rank,
  File "/scratch/zban/my-venv39-cuda12/lib/python3.9/site-packages/vllm/worker/worker.py", line 275, in init_distributed_environment
    cupy_utils.init_process_group(
  File "/scratch/zban/my-venv39-cuda12/lib/python3.9/site-packages/vllm/model_executor/parallel_utils/cupy_utils.py", line 90, in init_process_group
    _NCCL_BACKEND = NCCLBackendWithBFloat16(world_size, rank, host, port)
  File "/scratch/zban/my-venv39-cuda12/lib/python3.9/site-packages/cupyx/distributed/_nccl_comm.py", line 70, in __init__
    self._init_with_tcp_store(n_devices, rank, host, port)
  File "/scratch/zban/my-venv39-cuda12/lib/python3.9/site-packages/cupyx/distributed/_nccl_comm.py", line 88, in _init_with_tcp_store
    self._store.run(host, port)
  File "/scratch/zban/my-venv39-cuda12/lib/python3.9/site-packages/cupyx/distributed/_store.py", line 100, in run
    p.start()
  File "/usr/lib64/python3.9/multiprocessing/process.py", line 121, in start
    self._popen = self._Popen(self)
  File "/usr/lib64/python3.9/multiprocessing/context.py", line 224, in _Popen
    return _default_context.get_context().Process._Popen(process_obj)
  File "/usr/lib64/python3.9/multiprocessing/context.py", line 284, in _Popen
    return Popen(process_obj)
  File "/usr/lib64/python3.9/multiprocessing/popen_spawn_posix.py", line 32, in __init__
    super().__init__(process_obj)
  File "/usr/lib64/python3.9/multiprocessing/popen_fork.py", line 19, in __init__
    self._launch(process_obj)
  File "/usr/lib64/python3.9/multiprocessing/popen_spawn_posix.py", line 47, in _launch
    reduction.dump(process_obj, fp)
  File "/usr/lib64/python3.9/multiprocessing/reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
TypeError: cannot pickle '_thread.lock' object
```

I tried running cupy_utils.init_process_group directly within the subprocess before calling LLM(**args), and the same error was reproduced. In contrast, if I run torch.distributed.init_process_group with the same world_size and other parameters, it passes without errors.

So it looks like the use of cupy_utils.init_process_group is what causes the error above.
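
For reference, a minimal sketch of that comparison, assuming a spawned child process, a free local port, and the gloo backend for the torch.distributed path; per the report above, swapping the init call for vLLM's cupy_utils.init_process_group at the same point is where the pickling error appears:

```python
# Hypothetical repro sketch: initialize a process group from inside a spawned
# child. The torch.distributed call below succeeds in this setup; the
# cupy_utils.init_process_group path reportedly fails with the pickling error.
import torch.distributed as dist
from multiprocessing import get_context

def init_in_child(rank: int, world_size: int):
    dist.init_process_group(
        backend="gloo",                       # CPU backend is enough to test init
        init_method="tcp://127.0.0.1:29501",  # assumed free port
        world_size=world_size,
        rank=rank,
    )
    dist.destroy_process_group()

if __name__ == "__main__":
    ctx = get_context("spawn")
    world_size = 2
    procs = [ctx.Process(target=init_in_child, args=(r, world_size))
             for r in range(world_size)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
```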

I have been stuck for quite a while. May I get some guidance on whether it is technically possible with vLLM to serve multiple models (engines) from a single api_server, and whether Python multiprocessing is a sound approach (or whether other approaches are better)? Thank you so much in advance! (I'm aware of the discussions recommending a single model per server, but that approach is not very suitable for our architecture. Currently I'm just testing with the simplest offline LLM call; if it works, I will switch to the async engine later. Your answer to this issue will be much appreciated!)
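
For reference, a self-contained sketch of the setup being attempted here, assuming tensor_parallel_size = 1 per engine, the LLM built entirely inside a spawned child, and only plain strings/dicts crossing the queues. The run_engine helper and the queue protocol are simplified placeholders, and this does not address the TP > 1 cupy failure above:

```python
# Sketch: one vLLM engine per spawned subprocess, driven through queues.
import os
from multiprocessing import get_context

def run_engine(model_path, gpu_ids, task_q, result_q):
    # Pin this engine to its GPUs before vLLM initializes CUDA.
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu_ids
    from vllm import LLM, SamplingParams  # import inside the child process

    llm = LLM(model=model_path, tensor_parallel_size=len(gpu_ids.split(",")))
    while True:
        task = task_q.get()
        if task is None:  # shutdown sentinel
            break
        params = SamplingParams(**task.get("params", {}))
        outputs = llm.generate([task["prompt"]], params)
        result_q.put(outputs[0].outputs[0].text)

if __name__ == "__main__":
    ctx = get_context("spawn")  # avoid forking a CUDA-initialized parent
    task_q, result_q = ctx.Queue(), ctx.Queue()
    p = ctx.Process(target=run_engine,
                    args=("facebook/opt-125m", "0", task_q, result_q))
    p.start()
    task_q.put({"prompt": "Hello, my name is", "params": {"max_tokens": 32}})
    print(result_q.get(timeout=500))
    task_q.put(None)
    p.join()
```

With this shape, an api_server could keep one task/result queue pair per model and fan requests out, switching to the async engine later without changing the process layout.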