vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Bug]: llava model gets stuck with RuntimeError: Please increase the max_chunk_bytes parameter. #6376

Closed faileon closed 3 months ago

faileon commented 3 months ago

Your current environment

I have the following Docker Compose service running vLLM with llava-hf/llava-v1.6-mistral-7b-hf:

services:
  llava:
    image: vllm/vllm-openai:latest
    container_name: vllm-llava
    runtime: nvidia
    deploy:
      resources:
        reservations:
          devices:
            - capabilities: [gpu]
    volumes:
      - /data/.cache/huggingface:/root/.cache/huggingface
    env_file:
      - .env
    ports:
      - "8002:8000"
    ipc: host
    command: --model llava-hf/llava-v1.6-mistral-7b-hf --tensor-parallel-size 4 --enforce-eager --gpu-memory-utilization 0.35

I have a service that sends 5 parallel requests to the exposed /v1/chat/completions endpoint, which makes it seize up with the following error:

RuntimeError: len(serialized_obj)=14904693 larger than the allowed value 4194304, Please increase the max_chunk_bytes parameter.
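The limit in this message, 4194304 bytes, is exactly 4 MiB. The stack trace further down shows the error is raised while the driver worker broadcasts per-step metadata to the other tensor-parallel workers through a shared-memory buffer. The back-of-the-envelope sketch below shows why batching several image requests can cross a fixed limit that a single request stays under. It assumes, hypothetically, that each request adds one raw image buffer of fixed size to the pickled object. The patch count, dtype, and `batch_metadata_size` helper are illustrative only; the thread does not show the real contents of vLLM's broadcast metadata.

```python
import pickle

# Limit reported in the error message (4 * 1024 * 1024 bytes = 4 MiB).
MAX_CHUNK_BYTES = 4194304

# HYPOTHETICAL per-image payload: 3 patches of 3x336x336 values stored as
# 2-byte floats. Real LLaVA-1.6 preprocessing may produce a different
# number of patches and a different dtype.
BYTES_PER_IMAGE = 3 * 3 * 336 * 336 * 2


def batch_metadata_size(num_requests: int) -> int:
    """Pickled size of a dict holding one raw image buffer per request,
    used here as a stand-in for the per-step broadcast object."""
    metadata = {
        f"request-{i}": bytes(BYTES_PER_IMAGE) for i in range(num_requests)
    }
    return len(pickle.dumps(metadata))


for n in (1, 5):
    size = batch_metadata_size(n)
    # Prints the batch size, the pickled size, and whether it fits.
    print(n, size, size <= MAX_CHUNK_BYTES)
```

Under these assumptions, the pickled size grows roughly linearly with the number of requests scheduled in the same step. That fits the report that 5 concurrent requests fail while 1 at a time does not.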

After that, the container is stuck with 5 running requests and doesn't accept any new ones:

Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 5 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 3.9%, CPU KV cache usage: 0.0%.
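The stuck state shows up in the periodic stats line: requests are counted as running, but generation throughput stays at zero. Below is a minimal sketch of a check a watchdog could use before restarting the container. It assumes the log format is exactly as printed above, which may differ between vLLM versions. `looks_stuck` and its threshold are hypothetical and not part of vLLM.

```python
import re

# Matches the stats line format shown in this report; other vLLM versions
# may print a different format.
STATS_RE = re.compile(
    r"Avg generation throughput: (?P<gen>[\d.]+) tokens/s, "
    r"Running: (?P<running>\d+) reqs"
)


def looks_stuck(lines, min_consecutive=3):
    """Return True if the most recent `min_consecutive` stats lines all show
    running requests but zero generation throughput."""
    streak = 0
    for line in lines:
        m = STATS_RE.search(line)
        if m is None:
            continue  # ignore non-stats log lines
        if int(m["running"]) > 0 and float(m["gen"]) == 0.0:
            streak += 1
        else:
            streak = 0  # any healthy or idle line resets the streak
    return streak >= min_consecutive
```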

I have to tear down the container completely and start it again to get it unstuck. If I make my service gentler by sending just 1 request at a time, it seems to hold steady.
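Because one request at a time holds steady, a stopgap until the server is upgraded is to cap concurrency on the client side. The minimal asyncio sketch below does that. `send_one` is a hypothetical coroutine you would supply to POST one payload to `/v1/chat/completions` with the HTTP client of your choice. This is a workaround, not a fix. Whether a cap above 1 is safe depends on how large the batched payload gets.

```python
import asyncio

# The reporter observed that 1 request at a time is stable.
MAX_IN_FLIGHT = 1


async def send_all(payloads, send_one, max_in_flight=MAX_IN_FLIGHT):
    """Send every payload via the coroutine `send_one`, never running more
    than `max_in_flight` calls at once. Results keep the input order."""
    semaphore = asyncio.Semaphore(max_in_flight)

    async def guarded(payload):
        async with semaphore:  # wait here until a slot is free
            return await send_one(payload)

    return await asyncio.gather(*(guarded(p) for p in payloads))
```

`asyncio.gather` still creates all tasks up front, but the semaphore ensures only `max_in_flight` of them are inside `send_one` at any moment.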

This is an example request that I am sending:

curl --location 'http://domain:8002/v1/chat/completions' \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer token' \
--data '{
    "model": "llava-hf/llava-v1.6-mistral-7b-hf",
    "messages": [
      {
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQkJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCAAwADIDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD2+iiitCDD1/xPa+H5IUuIpHMoJG2jQPE1t4geZbeKRPKAJ3VyfxN/4+7H/cP86f8ADL/XX3+6v865fay9tydDk9tL2/J0PRaKKK6jrJB0FFA6CikMzdZmkt9Gu5omKyJGSrDsa8j/AOEv17/oIy/nXstxbx3VvJBKMxyDaw9qwf8AhBtB/wCfZv8AvusK1Oc2uV2OWvTqTa5HY8q1DVb7VGRr24eYoMKW7Uafq19pZc2Vw8Jf723vXQeOdFsdGuLVLKMoJFJbJz3p3gbRLHWZboXsZcRgFcHFcXJP2nLfU4fZz9py31Mv/hL9e/6CMv5167pMsk+k2ssrFneMFie5rJ/4QbQf+fZv++637eCO2t0giGEQbVHtXbRpzg3zO530KdSDfO7lgdBRQOgorc6SOiiimI4H4h6deX1zZm1t5JQqEEqM45p3w80+8sZbw3VvJEGUY3DGa7yisfYr2ntLmPsF7T2lwooorY2JB0FFA6CikM//2Q==" }},
            {"type" : "text", "text": "Describe this image in detail please."}
        ]
      }
    ]
  }'
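To script the same request, the sketch below base64-encodes raw JPEG bytes into a `data:` URL and builds the JSON body that the curl command sends. The helper name is hypothetical. Sending the body, with the `Authorization` header, is left to whatever HTTP client you use. Base64 turns every 3 input bytes into 4 characters, so the request body is about a third larger than the image itself.

```python
import base64
import json


def build_chat_request(image_bytes: bytes, prompt: str,
                       model: str = "llava-hf/llava-v1.6-mistral-7b-hf") -> str:
    """Return the JSON body of the curl request above for a given JPEG."""
    data_url = "data:image/jpeg;base64," + base64.b64encode(image_bytes).decode("ascii")
    body = {
        "model": model,
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": data_url}},
                    {"type": "text", "text": prompt},
                ],
            }
        ],
    }
    return json.dumps(body)
```

Usage, with a hypothetical file name: `body = build_chat_request(open("photo.jpg", "rb").read(), "Describe this image in detail please.")`.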

A bit more from the stack trace:

vllm-llava  |     async for res in result_generator:
vllm-llava  |   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 746, in generate
vllm-llava  |     async for output in self._process_request(
vllm-llava  |   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 859, in _process_request
vllm-llava  |     raise e
vllm-llava  |   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 855, in _process_request
vllm-llava  |     async for request_output in stream:
vllm-llava  |   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 90, in __anext__
vllm-llava  |     raise result
vllm-llava  |   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 43, in _log_task_completion
vllm-llava  |     return_value = task.result()
vllm-llava  |   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 595, in run_engine_loop
vllm-llava  |     result = task.result()
vllm-llava  |   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 540, in engine_step
vllm-llava  |     request_outputs = await self.engine.step_async(virtual_engine)
vllm-llava  |   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 241, in step_async
vllm-llava  |     output = await self.model_executor.execute_model_async(
vllm-llava  |   File "/usr/local/lib/python3.10/dist-packages/vllm/executor/distributed_gpu_executor.py", line 173, in execute_model_async
vllm-llava  |     return await self._driver_execute_model_async(execute_model_req)
vllm-llava  |   File "/usr/local/lib/python3.10/dist-packages/vllm/executor/multiproc_gpu_executor.py", line 160, in _driver_execute_model_async
vllm-llava  |     return await self.driver_exec_model(execute_model_req)
vllm-llava  |   File "/usr/lib/python3.10/concurrent/futures/thread.py", line 58, in run
vllm-llava  |     result = self.fn(*self.args, **self.kwargs)
vllm-llava  |   File "/usr/local/lib/python3.10/dist-packages/vllm/worker/worker_base.py", line 246, in execute_model
vllm-llava  |     broadcast_tensor_dict(broadcast_data, src=0)
vllm-llava  |   File "/usr/local/lib/python3.10/dist-packages/vllm/distributed/communication_op.py", line 32, in broadcast_tensor_dict
vllm-llava  |     return get_tp_group().broadcast_tensor_dict(tensor_dict, src)
vllm-llava  |   File "/usr/local/lib/python3.10/dist-packages/vllm/distributed/parallel_state.py", line 505, in broadcast_tensor_dict
vllm-llava  |     self.broadcast_object(metadata_list, src=src)
vllm-llava  |   File "/usr/local/lib/python3.10/dist-packages/vllm/distributed/parallel_state.py", line 382, in broadcast_object
vllm-llava  |     return self.shm_broadcaster.broadcast_object(obj)
vllm-llava  |   File "/usr/local/lib/python3.10/dist-packages/vllm/distributed/device_communicators/shm_broadcast.py", line 266, in broadcast_object
vllm-llava  |     self.enqueue(obj)
vllm-llava  |   File "/usr/local/lib/python3.10/dist-packages/vllm/distributed/device_communicators/shm_broadcast.py", line 248, in enqueue
vllm-llava  |     raise RuntimeError(
vllm-llava  | RuntimeError: len(serialized_obj)=18969348 larger than the allowed value 4194304,Please increase the max_chunk_bytes parameter.
vllm-llava  | 
vllm-llava  | The above exception was the direct cause of the following exception:
vllm-llava  | 
vllm-llava  | Traceback (most recent call last):
vllm-llava  |   File "/usr/local/lib/python3.10/dist-packages/uvicorn/protocols/http/httptools_impl.py", line 399, in run_asgi
vllm-llava  |     result = await app(  # type: ignore[func-returns-value]
vllm-llava  |   File "/usr/local/lib/python3.10/dist-packages/uvicorn/middleware/proxy_headers.py", line 70, in __call__
vllm-llava  |     return await self.app(scope, receive, send)
vllm-llava  |   File "/usr/local/lib/python3.10/dist-packages/fastapi/applications.py", line 1054, in __call__
vllm-llava  |     await super().__call__(scope, receive, send)
vllm-llava  |   File "/usr/local/lib/python3.10/dist-packages/starlette/applications.py", line 123, in __call__
vllm-llava  |     await self.middleware_stack(scope, receive, send)
vllm-llava  |   File "/usr/local/lib/python3.10/dist-packages/starlette/middleware/errors.py", line 186, in __call__
vllm-llava  |     raise exc
vllm-llava  |   File "/usr/local/lib/python3.10/dist-packages/starlette/middleware/errors.py", line 164, in __call__
vllm-llava  |     await self.app(scope, receive, _send)
vllm-llava  |   File "/usr/local/lib/python3.10/dist-packages/starlette/middleware/base.py", line 189, in __call__
vllm-llava  |     with collapse_excgroups():
vllm-llava  |   File "/usr/lib/python3.10/contextlib.py", line 153, in __exit__
vllm-llava  |     self.gen.throw(typ, value, traceback)
vllm-llava  |   File "/usr/local/lib/python3.10/dist-packages/starlette/_utils.py", line 93, in collapse_excgroups
vllm-llava  |     raise exc
vllm-llava  |   File "/usr/local/lib/python3.10/dist-packages/starlette/responses.py", line 261, in wrap
vllm-llava  |     await func()
vllm-llava  |   File "/usr/local/lib/python3.10/dist-packages/starlette/responses.py", line 250, in stream_response
vllm-llava  |     async for chunk in self.body_iterator:
vllm-llava  |   File "/usr/local/lib/python3.10/dist-packages/vllm/entrypoints/openai/serving_chat.py", line 334, in chat_completion_stream_generator
vllm-llava  |     async for res in result_generator:
vllm-llava  |   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 746, in generate
vllm-llava  |     async for output in self._process_request(
vllm-llava  |   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 845, in _process_request
vllm-llava  |     stream = await self.add_request(
vllm-llava  |   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 654, in add_request
vllm-llava  |     self.start_background_loop()
vllm-llava  |   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 476, in start_background_loop
vllm-llava  |     raise AsyncEngineDeadError(
vllm-llava  | vllm.engine.async_llm_engine.AsyncEngineDeadError: Background loop has errored already.
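The trace follows `broadcast_tensor_dict` → `broadcast_object` → the shared-memory broadcaster's `enqueue`. The write is refused there because the pickled object is larger than the fixed per-message size. The exception then escapes the background engine loop, so every later request fails with `AsyncEngineDeadError: Background loop has errored already.` That appears to be why the container had to be restarted. The toy class below is a hypothetical stand-in for that check, not vLLM's implementation. It uses one pre-allocated buffer that never grows, so an oversized object can only be rejected.

```python
import pickle


class FixedSizeSlot:
    """Toy single-slot buffer of `max_chunk_bytes` bytes, pre-allocated the
    way a shared-memory segment would be. Illustration only."""

    def __init__(self, max_chunk_bytes: int):
        self.max_chunk_bytes = max_chunk_bytes
        self.buffer = bytearray(max_chunk_bytes)  # fixed size, never grows
        self.length = 0

    def enqueue(self, obj) -> None:
        serialized_obj = pickle.dumps(obj)
        if len(serialized_obj) > self.max_chunk_bytes:
            # Same condition as in the trace: the object cannot fit in the
            # pre-allocated buffer, so the write fails outright.
            raise RuntimeError(
                f"len(serialized_obj)={len(serialized_obj)} larger than the "
                f"allowed value {self.max_chunk_bytes}, "
                "Please increase the max_chunk_bytes parameter.")
        self.buffer[:len(serialized_obj)] = serialized_obj
        self.length = len(serialized_obj)

    def dequeue(self):
        # Single slot: returns the most recently enqueued object.
        return pickle.loads(bytes(self.buffer[:self.length]))
```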
ywang96 commented 3 months ago

@youkaichao FYI, I think shm_broadcast is a bit tricky when it comes to multi-modal inputs...

ywang96 commented 3 months ago

@faileon I talked to Kaichao offline, and this won't actually be an issue if you build the Docker image from the main branch and serve the model with it, thanks to the recently merged https://github.com/vllm-project/vllm/pull/6183.

Feel free to raise another issue if you still see any other error from the latest main branch!
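One general way a shared-memory broadcaster can handle large multimodal payloads is to keep small messages in the fixed-size buffer and send oversized ones over a secondary channel. In that case, only a flag is written to shared memory. The sketch below illustrates the idea with a `queue.Queue` standing in for the side channel. The class is hypothetical, and this is an assumption for illustration, not a description of what #6183 actually changed.

```python
import pickle
import queue


class OverflowingBroadcaster:
    """Toy single-slot broadcaster: messages that fit go into a pre-allocated
    buffer; larger ones go over a side channel, with only a flag byte written
    to the buffer. Hypothetical, for illustration only."""

    def __init__(self, max_chunk_bytes: int):
        self.max_chunk_bytes = max_chunk_bytes
        # Byte 0 is the overflow flag; the payload starts at byte 1.
        self.buffer = bytearray(1 + max_chunk_bytes)
        self.length = 0
        self.side_channel = queue.Queue()  # stand-in for e.g. a socket

    def enqueue(self, obj) -> None:
        data = pickle.dumps(obj)
        if len(data) > self.max_chunk_bytes:
            self.buffer[0] = 1  # tell the reader to use the side channel
            self.side_channel.put(data)
        else:
            self.buffer[0] = 0
            self.buffer[1:1 + len(data)] = data
            self.length = len(data)

    def dequeue(self):
        # Assumes strict enqueue/dequeue alternation (single slot).
        if self.buffer[0] == 1:
            return pickle.loads(self.side_channel.get())
        return pickle.loads(bytes(self.buffer[1:1 + self.length]))
```

The trade-off is an extra copy through the slower channel for large messages only. Small, frequent metadata stays on the fast shared-memory path.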

ywang96 commented 3 months ago

Closing this one as it's fixed in #6183