vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Usage]: Setting max_tokens in chat completion request class returns an empty output #3851

Closed: NrKhader closed this issue 2 months ago

NrKhader commented 6 months ago

Your current environment

I am using vllm version 0.3.0. I am using the ChatCompletionRequest class to create the request for my chat completion endpoint. Whenever I set max_tokens to any number and send a large prompt close in length to the max_model_len, the endpoint returns an empty output with status_code=200.

Can you tell me how max_tokens is used or how it relates to max_model_len, or whether I'm using it incorrectly altogether?

How would you like to use vllm

I am using vllm to create a chat_completions endpoint for inference purposes

hmellor commented 6 months ago

max_tokens comes from https://platform.openai.com/docs/api-reference/chat/create#chat-create-max_tokens

If prompt_tokens + max_tokens > max_model_len, then I believe the output will be cropped so that prompt_tokens + generated_tokens == max_model_len
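
In pseudocode, the cropping I'd expect looks something like this (illustrative only, not the actual vLLM code):

    def effective_generation_budget(prompt_tokens: int, max_tokens: int, max_model_len: int) -> int:
        # Tokens left in the context window after the prompt.
        remaining = max_model_len - prompt_tokens
        # Generation stops at whichever limit is hit first.
        return min(max_tokens, remaining)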

NrKhader commented 6 months ago

Thank you for responding @hmellor. I understood that much, but what happened is that no output was generated at all, just a list of empty strings.

Something like this:

    RequestFuncOutput(generated_text='', success=False, latency=0, ttft=0, prompt_len=995),
    RequestFuncOutput(generated_text='', success=False, latency=0, ttft=0, prompt_len=52),
    RequestFuncOutput(generated_text='', success=False, latency=0, ttft=0, prompt_len=11),
    RequestFuncOutput(generated_text='', success=False, latency=0, ttft=0, prompt_len=61),
    RequestFuncOutput(generated_text='', success=False, latency=0, ttft=0, prompt_len=11),
    RequestFuncOutput(generated_text='', success=False, latency=0, ttft=0, prompt_len=30),
    RequestFuncOutput(generated_text='', success=False, latency=0, ttft=0, prompt_len=230),
    RequestFuncOutput(generated_text='', success=False, latency=0, ttft=0, prompt_len=837),
    RequestFuncOutput(generated_text='', success=False, latency=0, ttft=0, prompt_len=21),
    RequestFuncOutput(generated_text='', success=False, latency=0, ttft=0, prompt_len=124),
    RequestFuncOutput(generated_text='', success=False, latency=0, ttft=0, prompt_len=248),
    RequestFuncOutput(generated_text='', success=False, latency=0, ttft=0, prompt_len=72),
    RequestFuncOutput(generated_text='', success=False, latency=0, ttft=0, prompt_len=733),
    RequestFuncOutput(generated_text='', success=False, latency=0, ttft=0, prompt_len=528),

hmellor commented 6 months ago

Could you share a reproducer so I can try and reproduce the issue myself?

NrKhader commented 6 months ago

You can reproduce this with any model. I created a chat_completions endpoint like this:

    # Imports assumed for this snippet (not shown in the original):
    #   from typing import AsyncGenerator, Dict, List, Union
    #   from fastapi import Request
    #   from fastapi.responses import JSONResponse, StreamingResponse
    #   from vllm.entrypoints.openai.protocol import ChatCompletionRequest, ErrorResponse
    #   from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
    async def chat_completions(
        self, messages: Union[str, List[Dict[str, str]]] = TEST_MESSAGES
    ) -> AsyncGenerator[str, None]:
        # Build an OpenAI-style chat completion request.
        request = ChatCompletionRequest(
            model="mistral",
            messages=messages,
            temperature=0,
            n=1,
            max_tokens=1024,
            stream=True,
        )
        # Reuse vLLM's OpenAI-compatible serving layer on top of the engine.
        openai_serving_chat = OpenAIServingChat(
            self.engine,
            "mistral",
            response_role="assistant",
            chat_template="chat_template.txt",
        )
        raw_request = Request(scope={"type": "http"})
        generator = await openai_serving_chat.create_chat_completion(
            request, raw_request=raw_request
        )
        # If the serving layer returned an error, surface it to the client as JSON.
        if isinstance(generator, ErrorResponse):
            return JSONResponse(
                content=generator.model_dump(), status_code=generator.code
            )
        if request.stream:
            return StreamingResponse(content=generator, media_type="text/event-stream")
        else:
            return JSONResponse(content=generator.model_dump())

and I tried to send a long prompt that is a little shorter than the max_model_len.

The issue is with the max_tokens=1024 parameter. When I traced it in the code, this was the condition causing it, but the error was not surfacing because of how the exception is handled there, so I ended up with an empty string as output.

This is not the correct behavior: max_tokens should only dictate the limit if max_model_len allows for it, so when the prompt is very long the model should just generate output until it reaches max_model_len.

In short, if (prompt length + output) equals max_model_len and the output is shorter than max_tokens, this issue shouldn't appear; instead, the model should generate a bit of output and then stop.

I tried to explain this as best I could.

NrKhader commented 6 months ago

@hmellor Please check my response above ⬆️

hmellor commented 6 months ago

Ok so the error is that prompt_tokens + max_tokens is exceeding max_model_len.

https://github.com/vllm-project/vllm/blob/0ce0539d4750f9ebcd9b19d7085ca3b934b9ec67/vllm/entrypoints/openai/serving_chat.py#L65-L79
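
Paraphrasing the linked lines (names simplified; see the permalink above for the exact code), the flow is roughly:

    # Rough paraphrase of the linked serving_chat.py code, not a verbatim copy.
    try:
        # Tokenize the prompt and validate the request against the model's context window.
        # If prompt_tokens + max_tokens > max_model_len, a ValueError is raised with the
        # "maximum context length" message.
        ...
    except ValueError as e:
        return self.create_error_response(str(e))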

This looks like it should be returning an ErrorResponse, which should be picked up by the following code in the snippet you provided:

        if isinstance(generator, ErrorResponse):
            return JSONResponse(
                content=generator.model_dump(), status_code=generator.code
            )

What is the type of generator when you're raising this error?

NrKhader commented 6 months ago

The generator type is <class 'vllm.entrypoints.openai.protocol.ErrorResponse'>

hmellor commented 6 months ago

Ok so you should be entering that if block and the response should contain the error.

ErrorResponse contains a message field, so your application should be able to read it like a normal response:

https://github.com/vllm-project/vllm/blob/0ce0539d4750f9ebcd9b19d7085ca3b934b9ec67/vllm/entrypoints/openai/protocol.py#L13-L18
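
From those lines, ErrorResponse is roughly a Pydantic model of the following shape (check the permalink for the exact definition):

    # Approximate shape of ErrorResponse; see the permalink above for the authoritative version.
    class ErrorResponse(BaseModel):
        object: str = "error"
        message: str
        type: str
        param: Optional[str] = None
        code: int

So the calling script can check whether the JSON body has "object": "error" and read message, rather than treating the body as a normal completion.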

hmellor commented 6 months ago

What is str(e) in

https://github.com/vllm-project/vllm/blob/0ce0539d4750f9ebcd9b19d7085ca3b934b9ec67/vllm/entrypoints/openai/serving_chat.py#L78-L79

when the error is triggered?

NrKhader commented 6 months ago

Ok, regarding the error response: yes, it returns the error message normally. The issue is with my other script that uses this endpoint.

This is the content of str(e):

This model's maximum context length is 4096 tokens. However, you requested 5060 tokens (4036 in the messages, 1024 in the completion). Please reduce the length of the messages or completion.

NrKhader commented 6 months ago

My issue is not with the error message but with why it happened, i.e. with the behavior itself.

My understanding is that the model should only truncate the output when it exceeds the max_tokens

hmellor commented 6 months ago

My understanding is that the model should only truncate the output when it exceeds the max_tokens

This is correct.

However, in your case prompt_tokens + max_tokens exceeds max_model_len, which is what triggers the error.
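
If you need to send prompts that close to the context limit, one client-side workaround (a sketch, assuming a Mistral chat model and a 4096-token max_model_len) is to count the prompt tokens yourself and clamp max_tokens before building the ChatCompletionRequest:

    # Sketch of a client-side clamp; the model name and MAX_MODEL_LEN are assumptions.
    from transformers import AutoTokenizer

    MAX_MODEL_LEN = 4096  # context length of the served model

    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")

    def clamp_max_tokens(messages, requested_max_tokens=1024):
        # Count prompt tokens with the same chat template the server applies.
        prompt_ids = tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=True
        )
        budget = MAX_MODEL_LEN - len(prompt_ids)
        # Leave at least one token for generation; if the budget is gone,
        # the prompt itself has to be trimmed.
        return max(1, min(requested_max_tokens, budget))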

antoniolanza1996 commented 4 months ago

@NrKhader have you been able to solve it?

It seems to be related to VRAM usage / running out of memory. Indeed, I had the same problem with Mistral-7B on 2 T4 GPUs with more than 3K input tokens. However, as soon as I switched to an A100 GPU, all of these prompts received non-empty answers.