sysid / sse-starlette

BSD 3-Clause "New" or "Revised" License

GPU memory footprint #72

Closed: SevenMpp closed this issue 10 months ago

SevenMpp commented 10 months ago

The streaming service is built with FastAPI and tested with Postman, curl, or a client program. After a request has been fully answered, the GPU memory it used is still occupied and is never reclaimed. How can I fix this?

Here is the code:

```python
import asyncio
from typing import List

from fastapi import HTTPException
from sse_starlette.sse import EventSourceResponse

# app, model, tokenizer and the ChatCompletion*/ChatMessage/DeltaMessage
# pydantic models are defined elsewhere in the script (not shown here).


@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
    global model, tokenizer

    # The last message must come from the user; it becomes the query.
    if request.messages[-1].role != "user":
        raise HTTPException(status_code=400, detail="Invalid request")
    query = request.messages[-1].content
    print("query: ", query)
    prev_messages = request.messages[:-1]
    # A leading system message is prepended to the query.
    if len(prev_messages) > 0 and prev_messages[0].role == "system":
        query = prev_messages.pop(0).content + query

    # Pair the remaining messages into [user, assistant] history turns.
    history = []
    if len(prev_messages) % 2 == 0:
        for i in range(0, len(prev_messages), 2):
            if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant":
                history.append([prev_messages[i].content, prev_messages[i+1].content])

    if request.stream:
        # Streaming: hand the async generator to sse-starlette.
        generate = predict(query, history, request.model)
        return EventSourceResponse(generate, media_type="text/event-stream")
        # response = EventSourceResponse(generate)
        # asyncio.create_task(manage_response(response))
        # return response

    # Non-streaming: run the whole generation and return a single response.
    response, _ = model.chat(tokenizer, query, history=history)
    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=ChatMessage(role="assistant", content=response),
        finish_reason="stop"
    )

    return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion")
```
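
The endpoint converts an OpenAI-style `messages` list into the `query` string and the `[user, assistant]` `history` pairs that `model.chat` and `model.stream_chat` take. A minimal stdlib sketch of just that conversion, assuming plain dicts instead of the script's pydantic `ChatMessage` objects; `split_messages` is a hypothetical name, and a `ValueError` stands in for the 400 response:

```python
from typing import Dict, List, Tuple

Message = Dict[str, str]  # hypothetical stand-in for ChatMessage: {"role": ..., "content": ...}


def split_messages(messages: List[Message]) -> Tuple[str, List[List[str]]]:
    """Mirror the endpoint: last user message -> query, earlier turns -> history."""
    if not messages or messages[-1]["role"] != "user":
        # The endpoint raises HTTPException(status_code=400) here.
        raise ValueError("Invalid request")
    query = messages[-1]["content"]
    prev = list(messages[:-1])
    # A leading system prompt is prepended to the query, as in the issue's code.
    if prev and prev[0]["role"] == "system":
        query = prev.pop(0)["content"] + query

    history: List[List[str]] = []
    # Pairs are only built when an even number of messages remains.
    if len(prev) % 2 == 0:
        for i in range(0, len(prev), 2):
            if prev[i]["role"] == "user" and prev[i + 1]["role"] == "assistant":
                history.append([prev[i]["content"], prev[i + 1]["content"]])
    return query, history


msgs = [
    {"role": "system", "content": "Be brief. "},
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
    {"role": "user", "content": "What is SSE?"},
]
query, history = split_messages(msgs)
print(query, history)
```

As in the original code, a remaining message list of odd length (or one that does not alternate user/assistant) silently produces an empty or partial history rather than an error.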

```python
async def predict(query: str, history: List[List[str]], model_id: str):
    global model, tokenizer

    # First chunk: announce the assistant role with no content yet.
    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(role="assistant"),
        finish_reason=None
    )
    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))

    current_length = 0

    # stream_chat yields the full response so far; send only the new suffix.
    for new_response, _ in model.stream_chat(tokenizer, query, history):
        if len(new_response) == current_length:
            continue

        new_text = new_response[current_length:]
        current_length = len(new_response)

        choice_data = ChatCompletionResponseStreamChoice(
            index=0,
            delta=DeltaMessage(content=new_text),
            finish_reason=None
        )
        chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
        yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))

    # Final chunk: empty delta with finish_reason="stop".
    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(),
        finish_reason="stop"
    )
    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
    # torch_gc()
    yield '[DONE]'
    await asyncio.sleep(0.0001)
```

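
`predict` treats each value from `model.stream_chat` as the full response generated so far and forwards only the part that is new since the previous step. A stdlib sketch of that slicing logic; `fake_stream_chat` is a hypothetical stand-in for the model, and the real code additionally wraps each delta in a `ChatCompletionResponse` chunk before yielding it:

```python
from typing import Iterable, Iterator, List


def fake_stream_chat(pieces: List[str]) -> Iterator[str]:
    """Hypothetical stand-in for model.stream_chat: yields the cumulative text so far."""
    text = ""
    for piece in pieces:
        text += piece
        yield text


def deltas(cumulative: Iterable[str]) -> Iterator[str]:
    """Yield only the newly generated suffix of each cumulative response, as predict() does."""
    current_length = 0
    for new_response in cumulative:
        if len(new_response) == current_length:
            continue  # nothing new since the previous step
        new_text = new_response[current_length:]
        current_length = len(new_response)
        yield new_text


parts = list(deltas(fake_stream_chat(["Hel", "lo", "", ", world"])))
print(parts)
```

Joining the deltas on the client side reconstructs the full response, which is why each SSE chunk only needs to carry the suffix.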
sysid commented 10 months ago

Since sse-starlette is not actively involved in memory management, I do not think your issue is related to sse-starlette. However, if you have reason to disagree, please reopen the issue with your reasoning.
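
Following the maintainer's point that memory management is up to the application, one way to make sure per-request cleanup (such as the commented-out `torch_gc()` in `predict`) always runs is to wrap the generator body in `try`/`finally`. The `finally` block executes when the stream is exhausted and also when the generator is closed early, for example with `aclose()`. This is a stdlib-only sketch: `release_gpu_cache` and `stream_tokens` are hypothetical names, and the issue does not show what `torch_gc()` actually does.

```python
import asyncio
from typing import AsyncIterator, List

cleanup_calls: List[str] = []


def release_gpu_cache() -> None:
    """Hypothetical placeholder for torch_gc(); a real version might call torch.cuda.empty_cache()."""
    cleanup_calls.append("released")


async def stream_tokens(tokens: List[str]) -> AsyncIterator[str]:
    try:
        for tok in tokens:
            yield tok
            await asyncio.sleep(0)  # let the event loop run other tasks between chunks
        yield "[DONE]"
    finally:
        # Runs when the generator is exhausted or closed early (aclose(), or an
        # exception such as cancellation raised inside it).
        release_gpu_cache()


async def main() -> List[str]:
    # 1) The consumer reads the whole stream.
    received = [t async for t in stream_tokens(["a", "b"])]

    # 2) The consumer stops after one item and closes the generator explicitly.
    gen = stream_tokens(["a", "b", "c"])
    first = await gen.__anext__()
    await gen.aclose()
    return received + [first]


result = asyncio.run(main())
print(result, cleanup_calls)
```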

SevenMpp commented 10 months ago

I thought sse-starlette keeps a long-lived (keep-alive) connection open, so when a request is streamed, the memory used to process it is still not reclaimed after the response has completed. That is why I think it is relevant. What do you think?
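
A keep-alive connection and the lifetime of the Python objects created for a request are separate things: once the streaming generator has finished, objects referenced only by its frame can be freed even while the underlying HTTP connection stays open. The stdlib sketch below checks this with a `weakref` to an object that lives only inside the generator; `WorkBuffer` and `stream_response` are hypothetical stand-ins for per-request tensors and for `predict`. For GPU memory specifically, PyTorch's caching allocator keeps freed blocks reserved for reuse, so `nvidia-smi` can keep reporting them as used after the tensors themselves are released; `torch.cuda.empty_cache()` returns unused cached blocks to the driver.

```python
import asyncio
import gc
import weakref
from typing import AsyncIterator, List


class WorkBuffer:
    """Hypothetical stand-in for per-request data (e.g. tensors held during generation)."""

    def __init__(self, size: int) -> None:
        self.data = bytearray(size)


probes: List[weakref.ref] = []


async def stream_response(n: int) -> AsyncIterator[str]:
    buf = WorkBuffer(1024)            # only this generator's frame references the buffer
    probes.append(weakref.ref(buf))   # a weak reference does not keep it alive
    for i in range(n):
        yield f"chunk {i}"
        await asyncio.sleep(0)


async def main() -> List[str]:
    chunks = [c async for c in stream_response(3)]
    gc.collect()  # collect any reference cycles so only real references keep objects alive
    return chunks


chunks = asyncio.run(main())
print(chunks)
print("buffer freed:", bool(probes) and probes[0]() is None)
```

If a similar probe on real per-request objects shows they are still alive after the stream ends, something in the application (a global, a cache, a lingering task) is still holding a reference to them.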