OpenAI-style API for open large language models: use LLMs just like ChatGPT! Supports LLaMA, LLaMA-2, BLOOM, Falcon, Baichuan, Qwen, Xverse, SqlCoder, CodeLLaMA, ChatGLM, ChatGLM2, ChatGLM3, etc. A unified backend interface for open-source large models.
The following items must be checked before submission
[X] Make sure you are using the latest code from the repository (git pull); some issues have already been addressed and fixed.
[X] I have read the project documentation and the FAQ, searched the existing issues / discussions, and found no similar problem or solution.
# This script benefits from https://github.com/xusenlinzy/api-for-open-llm. Thanks for their wonderful work.
import json
import os
import time
import traceback
import uuid
from abc import ABC
from argparse import ArgumentParser
from contextlib import asynccontextmanager
from enum import Enum, IntEnum
from functools import lru_cache, partial
from threading import Thread
from types import MethodType
from typing import (
Any,
AsyncIterator,
Dict,
Iterator,
List,
Optional,
Tuple,
Type,
Union,
cast,
)
import gc
import anyio
import pydantic
import torch
import uvicorn
from anyio.streams.memory import MemoryObjectSendStream
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from llava.conversation import conv_templates
from llava.mm_utils import (
expand2square,
get_model_name_from_path,
load_pretrained_model,
tokenizer_image_token,
)
from llava.model.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX, key_info
from loguru import logger
from openai.types.chat import (
ChatCompletion,
ChatCompletionChunk,
ChatCompletionMessage,
ChatCompletionMessageParam,
ChatCompletionToolChoiceOptionParam,
)
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from openai.types.chat.chat_completion_message import FunctionCall
from openai.types.chat.completion_create_params import ResponseFormat
from openai.types.completion_usage import CompletionUsage
from PIL import Image
from pydantic import BaseModel
from sse_starlette import EventSourceResponse
from starlette.concurrency import iterate_in_threadpool, run_in_threadpool
from transformers import PreTrainedModel, PreTrainedTokenizer, TextIteratorStreamer
class Role(str, Enum):
USER = "user"
ASSISTANT = "assistant"
SYSTEM = "system"
FUNCTION = "function"
TOOL = "tool"
class ErrorResponse(BaseModel):
object: str = "error"
message: str
code: int
class ErrorCode(IntEnum):
"""
https://platform.openai.com/docs/guides/error-codes/api-errors
"""
VALIDATION_TYPE_ERROR = 40001
INVALID_AUTH_KEY = 40101
INCORRECT_AUTH_KEY = 40102
NO_PERMISSION = 40103
INVALID_MODEL = 40301
PARAM_OUT_OF_RANGE = 40302
CONTEXT_OVERFLOW = 40303
RATE_LIMIT = 42901
QUOTA_EXCEEDED = 42902
ENGINE_OVERLOADED = 42903
INTERNAL_ERROR = 50001
CUDA_OUT_OF_MEMORY = 50002
GRADIO_REQUEST_ERROR = 50003
GRADIO_STREAM_UNKNOWN_ERROR = 50004
CONTROLLER_NO_WORKER = 50005
CONTROLLER_WORKER_TIMEOUT = 50006
class ChatCompletionCreateParams(BaseModel):
messages: List[ChatCompletionMessageParam]
"""A list of messages comprising the conversation so far.
[Example Python code](https://cookbook.openai.com/examples/how_to_format_inputs_to_chatgpt_models).
"""
model: str
"""ID of the model to use.
See the
[model endpoint compatibility](https://platform.openai.com/docs/models/model-endpoint-compatibility)
table for details on which models work with the Chat API.
"""
frequency_penalty: Optional[float] = 0.0
"""Number between -2.0 and 2.0.
Positive values penalize new tokens based on their existing frequency in the
text so far, decreasing the model's likelihood to repeat the same line verbatim.
[See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/gpt/parameter-details)
"""
function_call: Optional[FunctionCall] = None
"""Deprecated in favor of `tool_choice`.
Controls which (if any) function is called by the model. `none` means the model
will not call a function and instead generates a message. `auto` means the model
can pick between generating a message or calling a function. Specifying a
particular function via `{"name": "my_function"}` forces the model to call that
function.
`none` is the default when no functions are present. `auto` is the default if
functions are present.
"""
functions: Optional[List] = None
"""Deprecated in favor of `tools`.
A list of functions the model may generate JSON inputs for.
"""
logit_bias: Optional[Dict[str, int]] = None
"""Modify the likelihood of specified tokens appearing in the completion.
Accepts a JSON object that maps tokens (specified by their token ID in the
tokenizer) to an associated bias value from -100 to 100. Mathematically, the
bias is added to the logits generated by the model prior to sampling. The exact
effect will vary per model, but values between -1 and 1 should decrease or
increase likelihood of selection; values like -100 or 100 should result in a ban
or exclusive selection of the relevant token.
"""
max_tokens: Optional[int] = None
"""The maximum number of [tokens](/tokenizer) to generate in the chat completion.
The total length of input tokens and generated tokens is limited by the model's
context length.
[Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken)
for counting tokens.
"""
n: Optional[int] = 1
"""How many chat completion choices to generate for each input message."""
presence_penalty: Optional[float] = 0.0
"""Number between -2.0 and 2.0.
Positive values penalize new tokens based on whether they appear in the text so
far, increasing the model's likelihood to talk about new topics.
[See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/gpt/parameter-details)
"""
response_format: Optional[ResponseFormat] = None
"""An object specifying the format that the model must output.
Used to enable JSON mode.
"""
seed: Optional[int] = None
"""This feature is in Beta.
If specified, our system will make a best effort to sample deterministically,
such that repeated requests with the same `seed` and parameters should return
the same result. Determinism is not guaranteed, and you should refer to the
`system_fingerprint` response parameter to monitor changes in the backend.
"""
stop: Optional[Union[str, List[str]]] = None
"""Up to 4 sequences where the API will stop generating further tokens."""
temperature: Optional[float] = 0.9
"""What sampling temperature to use, between 0 and 2.
Higher values like 0.8 will make the output more random, while lower values like
0.2 will make it more focused and deterministic.
We generally recommend altering this or `top_p` but not both.
"""
tool_choice: Optional[ChatCompletionToolChoiceOptionParam] = None
"""
Controls which (if any) function is called by the model. `none` means the model
will not call a function and instead generates a message. `auto` means the model
can pick between generating a message or calling a function. Specifying a
particular function via
`{"type": "function", "function": {"name": "my_function"}}` forces the model to
call that function.
`none` is the default when no functions are present. `auto` is the default if
functions are present.
"""
tools: Optional[List] = None
"""A list of tools the model may call.
Currently, only functions are supported as a tool. Use this to provide a list of
functions the model may generate JSON inputs for.
"""
top_p: Optional[float] = 1.0
"""
An alternative to sampling with temperature, called nucleus sampling, where the
model considers the results of the tokens with top_p probability mass. So 0.1
means only the tokens comprising the top 10% probability mass are considered.
We generally recommend altering this or `temperature` but not both.
"""
user: Optional[str] = None
"""
A unique identifier representing your end-user, which can help OpenAI to monitor
and detect abuse.
[Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids).
"""
stream: Optional[bool] = False
"""If set, partial message deltas will be sent, like in ChatGPT.
Tokens will be sent as data-only
[server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format)
as they become available, with the stream terminated by a `data: [DONE]`
message.
[Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).
"""
# Additional parameters
repetition_penalty: Optional[float] = 1.03
"""The parameter for repetition penalty. 1.0 means no penalty.
See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
"""
typical_p: Optional[float] = None
"""Typical Decoding mass.
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information.
"""
watermark: Optional[bool] = False
"""Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
"""
best_of: Optional[int] = 1
ignore_eos: Optional[bool] = False
use_beam_search: Optional[bool] = False
stop_token_ids: Optional[List[int]] = None
skip_special_tokens: Optional[bool] = True
spaces_between_special_tokens: Optional[bool] = True
min_p: Optional[float] = 0.0
def torch_gc() -> None:
r"""
Collects GPU memory.
"""
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
@asynccontextmanager
async def lifespan(app: FastAPI): # collects GPU memory
yield
torch_gc()
app = FastAPI(lifespan=lifespan)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.post("/v1/chat/completions")
async def create_chat_completion(
request: ChatCompletionCreateParams, raw_request: Request
):
if len(request.messages) < 1 or request.messages[-1]["role"] == Role.ASSISTANT:
raise HTTPException(status_code=400, detail="Invalid request")
request = await handle_request(request, engine.template.stop)
request.max_tokens = request.max_tokens or 1024
params = model_dump(request, exclude={"messages"})
params.update(dict(prompt_or_messages=request.messages, echo=False))
logger.debug(f"==== request ====\n{params}")
iterator_or_completion = await run_in_threadpool(
engine.create_chat_completion, params
)
if isinstance(iterator_or_completion, Iterator):
# It's easier to ask for forgiveness than permission
first_response = await run_in_threadpool(next, iterator_or_completion)
# If no exception was raised from first_response, we can assume that
# the iterator is valid, and we can use it to stream the response.
def iterator() -> Iterator:
yield first_response
yield from iterator_or_completion
send_chan, recv_chan = anyio.create_memory_object_stream(10)
return EventSourceResponse(
recv_chan,
data_sender_callable=partial(
get_event_publisher,
request=raw_request,
inner_send_chan=send_chan,
iterator=iterator(),
),
)
else:
return iterator_or_completion
server_error_msg = (
"**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
)
@torch.inference_mode()
def generate_stream(
model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
params: Dict[str, Any],
):
input_ids = params.get("inputs")
image_tensor = params.get("image_tensor")
has_image = params.get("has_image", False)
model_name = params.get("model", "llm")
temperature = float(params.get("temperature", 1.0))
top_p = float(params.get("top_p", 1.0))
top_k = int(params.get("top_k", 40))
max_new_tokens = int(params.get("max_tokens", 1024))
stop_token_ids = params.get("stop_token_ids") or []
if tokenizer.eos_token_id not in stop_token_ids:
stop_token_ids.append(tokenizer.eos_token_id)
stop_strings = params.get("stop", [])
input_echo_len = len(input_ids)
device = model.device
input_ids = torch.tensor(input_ids, dtype=torch.long, device=device).unsqueeze(0)
if has_image:
image_tensor = torch.tensor(
image_tensor, dtype=torch.bfloat16, device=device
).unsqueeze(0)
generation_kwargs = dict(
input_ids=input_ids,
images=image_tensor,
do_sample=True,
temperature=temperature,
top_p=top_p,
top_k=top_k,
max_new_tokens=max_new_tokens,
pad_token_id=tokenizer.pad_token_id,
)
if temperature <= 1e-5:
generation_kwargs["do_sample"] = False
generation_kwargs.pop("top_k")
streamer = TextIteratorStreamer(
tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True
)
generation_kwargs["streamer"] = streamer
if "GenerationMixin" not in str(model.generate.__func__):
model.generate = MethodType(PreTrainedModel.generate, model)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
generated_text, func_call_found = "", False
completion_id: str = f"cmpl-{str(uuid.uuid4())}"
created: int = int(time.time())
previous_text = ""
for i, new_text in enumerate(streamer):
generated_text += new_text
generated_text, stop_found = apply_stopping_strings(
generated_text, stop_strings
)
if generated_text and generated_text[-1] != "�":
delta_text = generated_text[len(previous_text) :]
previous_text = generated_text
yield {
"id": completion_id,
"object": "text_completion",
"created": created,
"model": model_name,
"delta": delta_text,
"text": generated_text,
"logprobs": None,
"finish_reason": "function_call" if func_call_found else None,
"usage": {
"prompt_tokens": input_echo_len,
"completion_tokens": i,
"total_tokens": input_echo_len + i,
},
}
if stop_found:
break
yield {
"id": completion_id,
"object": "text_completion",
"created": created,
"model": model_name,
"delta": "",
"text": generated_text,
"logprobs": None,
"finish_reason": "stop",
"usage": {
"prompt_tokens": input_echo_len,
"completion_tokens": i,
"total_tokens": input_echo_len + i,
},
}
class DefaultEngine(ABC):
"""Model engine implemented on top of native transformers."""
def __init__(
self,
model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
image_processor,
device: Union[str, torch.device],
model_name: str,
) -> None:
"""
Initialize the DefaultEngine class.
Args:
model (PreTrainedModel): The pre-trained model.
tokenizer (PreTrainedTokenizer): The tokenizer for the model.
image_processor: The image processor used to preprocess input images.
device (Union[str, torch.device]): The device to use for inference.
model_name (str): The name of the model.
"""
self.model = model
self.tokenizer = tokenizer
self.image_processor = image_processor
self.device = model.device if hasattr(model, "device") else torch.device(device)
self.model_name = model_name.lower()
self.template = YiAITemplate()
self._prepare_for_generate()
def _prepare_for_generate(self) -> None:
"""
Prepare the object for text generation.
1. Sets the generate stream function.
2. Sets the context length from the model config.
"""
self.generate_stream_func = generate_stream
self.context_len = get_context_length(self.model.config)
def convert_to_inputs(
self, prompt_or_messages: Union[List[ChatCompletionMessageParam], str]
) -> Tuple[List[int], Optional[Any], str, bool]:
"""
Convert the prompt or messages into input format for the model.
Args:
prompt_or_messages: The prompt or messages to be converted.
Returns:
Tuple of (input_ids, image_tensor, stop_str, has_image).
"""
conv = conv_templates["mm_default"].copy()
stop_str = conv.sep
image_file = ""
has_image = False
for message in prompt_or_messages:
role = message["role"]
content = message["content"]
if role != "user":
conv.append_message(conv.roles[1], content)
continue
elif isinstance(content, str):
conv.append_message(conv.roles[0], content)
continue
num_images = 0
for item in content:
if item["type"] == "image_url":
num_images += 1
image_file = item["image_url"]["url"]
has_image = True
elif item["type"] == "text":
prompt = item["text"]
if image_file != "" and image_file is not None:
query = DEFAULT_IMAGE_TOKEN * num_images + "\n" + prompt
conv.append_message(conv.roles[0], query)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX)
if image_file != "" and image_file is not None:
if image_file.startswith("http"): # url
from io import BytesIO
import requests
response = requests.get(image_file)
response.raise_for_status()  # fail here rather than leaving `image` unbound on a non-200 reply
image = Image.open(BytesIO(response.content))
else:
image = Image.open(image_file) # local path
if getattr(self.model.config, "image_aspect_ratio", None) == "pad":
image = expand2square(
image, tuple(int(x * 255) for x in self.image_processor.image_mean)
)
image_tensor = self.image_processor.preprocess(image)["pixel_values"][0]
else:
image_tensor = None
return input_ids, image_tensor, stop_str, has_image
def _generate(self, params: Dict[str, Any]) -> Iterator[dict]:
"""
Generates text based on the given parameters.
Args:
params (Dict[str, Any]): A dictionary containing the parameters for text generation.
Yields:
Iterator: A dictionary containing the generated text and error code.
"""
prompt_or_messages = params.get("prompt_or_messages")
input_ids, image_tensor, stop_str, has_image = self.convert_to_inputs(
prompt_or_messages
)
params.update(
dict(inputs=input_ids, image_tensor=image_tensor, has_image=has_image)
)
params["stop"].append(stop_str)
try:
for output in self.generate_stream_func(self.model, self.tokenizer, params):
output["error_code"] = 0
yield output
except torch.cuda.OutOfMemoryError as e:
yield {
"text": f"{server_error_msg}\n\n({e})",
"error_code": ErrorCode.CUDA_OUT_OF_MEMORY,
}
except (ValueError, RuntimeError) as e:
traceback.print_exc()
yield {
"text": f"{server_error_msg}\n\n({e})",
"error_code": ErrorCode.INTERNAL_ERROR,
}
def _create_chat_completion_stream(
self, params: Dict[str, Any]
) -> Iterator[ChatCompletionChunk]:
"""
Creates a chat completion stream.
Args:
params (Dict[str, Any]): The parameters for generating the chat completion.
Yields:
Dict[str, Any]: The output of the chat completion stream.
"""
_id, _created, _model = None, None, None
has_function_call = False
for i, output in enumerate(self._generate(params)):
if output["error_code"] != 0:
yield output
return
_id, _created, _model = output["id"], output["created"], output["model"]
if i == 0:
choice = ChunkChoice(
index=0,
delta=ChoiceDelta(role="assistant", content=""),
finish_reason=None,
logprobs=None,
)
yield ChatCompletionChunk(
id=f"chat{_id}",
choices=[choice],
created=_created,
model=_model,
object="chat.completion.chunk",
)
finish_reason = output["finish_reason"]
if len(output["delta"]) == 0 and finish_reason != "function_call":
continue
delta = ChoiceDelta(content=output["delta"])
choice = ChunkChoice(
index=0,
delta=delta,
finish_reason=finish_reason,
logprobs=None,
)
yield ChatCompletionChunk(
id=f"chat{_id}",
choices=[choice],
created=_created,
model=_model,
object="chat.completion.chunk",
)
if not has_function_call:
choice = ChunkChoice(
index=0,
delta=ChoiceDelta(),
finish_reason="stop",
logprobs=None,
)
yield ChatCompletionChunk(
id=f"chat{_id}",
choices=[choice],
created=_created,
model=_model,
object="chat.completion.chunk",
)
def _create_chat_completion(
self, params: Dict[str, Any]
) -> Union[ChatCompletion, JSONResponse]:
"""
Creates a chat completion based on the given parameters.
Args:
params (Dict[str, Any]): The parameters for generating the chat completion.
Returns:
ChatCompletion: The generated chat completion.
"""
last_output = None
for output in self._generate(params):
last_output = output
if last_output["error_code"] != 0:
return create_error_response(last_output["error_code"], last_output["text"])
finish_reason = "stop"
message = ChatCompletionMessage(
role="assistant",
content=last_output["text"].strip(),
)
choice = Choice(
index=0,
message=message,
finish_reason=finish_reason,
logprobs=None,
)
usage = model_parse(CompletionUsage, last_output["usage"])
return ChatCompletion(
id=f"chat{last_output['id']}",
choices=[choice],
created=last_output["created"],
model=last_output["model"],
object="chat.completion",
usage=usage,
)
def create_chat_completion(
self,
params: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Union[Iterator[ChatCompletionChunk], ChatCompletion]:
params = params or {}
params.update(kwargs)
return (
self._create_chat_completion_stream(params)
if params.get("stream", False)
else self._create_chat_completion(params)
)
@property
def stop(self):
"""
Gets the stop property of the prompt adapter.
Returns:
The stop property of the prompt adapter, or None if it does not exist.
"""
return self.template.stop if hasattr(self.template, "stop") else None
class YiAITemplate(ABC):
"""https://huggingface.co/01-ai/Yi-34B-Chat/blob/main/tokenizer_config.json"""
name = "yi"
system_prompt: Optional[str] = ""
allow_models = ["yi"]
stop = {
"strings": ["<|endoftext|>", "<|im_end|>"],
"token_ids": [
2,
6,
7,
8,
], # "<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|im_sep|>"
}
function_call_available: Optional[bool] = False
def apply_chat_template(
self,
conversation: List[ChatCompletionMessageParam],
add_generation_prompt: bool = True,
) -> str:
"""
Converts a Conversation object or a list of dictionaries with `"role"` and `"content"` keys to a prompt.
Args:
conversation (List[ChatCompletionMessageParam]): A Conversation object or list of dicts
with "role" and "content" keys, representing the chat history so far.
add_generation_prompt (bool, *optional*): Whether to end the prompt with the token(s) that indicate
the start of an assistant message. This is useful when you want to generate a response from the model.
Note that this argument will be passed to the chat template, and so it must be supported in the
template for this argument to have any effect.
Returns:
`str`: A prompt, which is ready to pass to the tokenizer.
"""
# Compilation function uses a cache to avoid recompiling the same template
compiled_template = _compile_jinja_template(self.template)
return compiled_template.render(
messages=conversation,
add_generation_prompt=add_generation_prompt,
system_prompt=self.system_prompt,
)
@property
def template(self) -> str:
return (
"{% for message in messages %}"
"{{ '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n' }}"
"{% endfor %}"
"{% if add_generation_prompt %}"
"{{ '<|im_start|>assistant\\n' }}"
"{% endif %}"
)
def postprocess_messages(
self, messages: List[ChatCompletionMessageParam]
) -> List[Dict[str, Any]]:
return messages
def parse_assistant_response(
self, output: str
) -> Tuple[str, Optional[Union[str, Dict[str, Any]]]]:
return output, None
@lru_cache
def _compile_jinja_template(chat_template: str):
"""
Compile a Jinja template from a string.
Args:
chat_template (str): The string representation of the Jinja template.
Returns:
jinja2.Template: The compiled Jinja template.
Examples:
>>> template_string = "Hello, {{ name }}!"
>>> template = _compile_jinja_template(template_string)
"""
try:
from jinja2.exceptions import TemplateError
from jinja2.sandbox import ImmutableSandboxedEnvironment
except ImportError:
raise ImportError("apply_chat_template requires jinja2 to be installed.")
def raise_exception(message):
raise TemplateError(message)
jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
jinja_env.globals["raise_exception"] = raise_exception
return jinja_env.from_string(chat_template)
async def handle_request(
request: Union[ChatCompletionCreateParams], stop: Dict[str, Any] = None
) -> Union[Union[ChatCompletionCreateParams], JSONResponse]:
error_check_ret = check_requests(request)
if error_check_ret is not None:
raise error_check_ret
# stop settings
_stop, _stop_token_ids = [], []
if stop is not None:
_stop_token_ids = stop.get("token_ids", [])
_stop = stop.get("strings", [])
request.stop = request.stop or []
if isinstance(request.stop, str):
request.stop = [request.stop]
if request.functions:
request.stop.append("Observation:")
request.stop = list(set(_stop + request.stop))
request.stop_token_ids = request.stop_token_ids or []
request.stop_token_ids = list(set(_stop_token_ids + request.stop_token_ids))
return request
def check_requests(
request: Union[ChatCompletionCreateParams],
) -> Optional[JSONResponse]:
# Check all params
if request.max_tokens is not None and request.max_tokens <= 0:
return create_error_response(
ErrorCode.PARAM_OUT_OF_RANGE,
f"{request.max_tokens} is less than the minimum of 1 - 'max_tokens'",
)
if request.n is not None and request.n <= 0:
return create_error_response(
ErrorCode.PARAM_OUT_OF_RANGE,
f"{request.n} is less than the minimum of 1 - 'n'",
)
if request.temperature is not None and request.temperature < 0:
return create_error_response(
ErrorCode.PARAM_OUT_OF_RANGE,
f"{request.temperature} is less than the minimum of 0 - 'temperature'",
)
if request.temperature is not None and request.temperature > 2:
return create_error_response(
ErrorCode.PARAM_OUT_OF_RANGE,
f"{request.temperature} is greater than the maximum of 2 - 'temperature'",
)
if request.top_p is not None and request.top_p < 0:
return create_error_response(
ErrorCode.PARAM_OUT_OF_RANGE,
f"{request.top_p} is less than the minimum of 0 - 'top_p'",
)
if request.top_p is not None and request.top_p > 1:
return create_error_response(
ErrorCode.PARAM_OUT_OF_RANGE,
f"{request.top_p} is greater than the maximum of 1 - 'top_p'",
)
if request.stop is None or isinstance(request.stop, (str, list)):
return None
else:
return create_error_response(
ErrorCode.PARAM_OUT_OF_RANGE,
f"{request.stop} is not valid under any of the given schemas - 'stop'",
)
def create_error_response(code: int, message: str) -> JSONResponse:
return JSONResponse(
model_dump(ErrorResponse(message=message, code=code)), status_code=500
)
async def get_event_publisher(
request: Request,
inner_send_chan: MemoryObjectSendStream,
iterator: Union[Iterator, AsyncIterator],
):
async with inner_send_chan:
try:
async for chunk in iterate_in_threadpool(iterator):
if isinstance(chunk, BaseModel):
chunk = model_json(chunk)
elif isinstance(chunk, dict):
chunk = json.dumps(chunk, ensure_ascii=False)
await inner_send_chan.send(dict(data=chunk))
if await request.is_disconnected():
raise anyio.get_cancelled_exc_class()()
await inner_send_chan.send(dict(data="[DONE]"))
except anyio.get_cancelled_exc_class() as e:
logger.info("disconnected")
with anyio.move_on_after(1, shield=True):
logger.info(
f"Disconnected from client (via refresh/close) {request.client}"
)
raise e
def create_generate_model(args):
"""get generate model for chat or completion."""
model_path = os.path.expanduser(args.model_path)
key_info["model_path"] = model_path
get_model_name_from_path(model_path)
tokenizer, model, image_processor, _ = load_pretrained_model(model_path)
logger.info("Using default engine")
return DefaultEngine(
model, tokenizer, image_processor, "cuda", model_name=args.model_name
)
# --------------- Pydantic v2 compatibility ---------------
PYDANTIC_V2 = pydantic.VERSION.startswith("2.")
def model_json(model: pydantic.BaseModel, **kwargs) -> str:
if PYDANTIC_V2:
return model.model_dump_json(**kwargs)
return model.json(**kwargs) # type: ignore
def model_dump(model: pydantic.BaseModel, **kwargs) -> Dict[str, Any]:
if PYDANTIC_V2:
return model.model_dump(**kwargs)
return cast(
"dict[str, Any]",
model.dict(**kwargs),
)
def model_parse(model: Type[pydantic.BaseModel], data: Any) -> pydantic.BaseModel:
if PYDANTIC_V2:
return model.model_validate(data)
return model.parse_obj(data) # pyright: ignore[reportDeprecated]
# Models don't use the same configuration key for determining the maximum
# sequence length. Store them here so we can sanely check them.
# NOTE: The ordering here is important. Some models have two of these, and we
# have a preference for which value gets used.
SEQUENCE_LENGTH_KEYS = [
"max_sequence_length",
"seq_length",
"max_position_embeddings",
"max_seq_len",
"model_max_length",
]
def get_context_length(config) -> int:
"""Get the context length of a model from a huggingface model config."""
rope_scaling = getattr(config, "rope_scaling", None)
rope_scaling_factor = config.rope_scaling["factor"] if rope_scaling else 1
for key in SEQUENCE_LENGTH_KEYS:
val = getattr(config, key, None)
if val is not None:
return int(rope_scaling_factor * val)
return 2048
def apply_stopping_strings(reply: str, stop_strings: List[str]) -> Tuple[str, bool]:
"""
Apply stopping strings to the reply and check if a stop string is found.
Args:
reply (str): The reply to apply stopping strings to.
stop_strings (List[str]): The list of stopping strings to check for.
Returns:
Tuple[str, bool]: A tuple containing the modified reply and a boolean indicating if a stop string was found.
"""
stop_found = False
for string in stop_strings:
idx = reply.find(string)
if idx != -1:
reply = reply[:idx]
stop_found = True
break
if not stop_found:
# If something like "\nYo" is generated just before "\nYou:" is completed, trim it
for string in stop_strings:
for j in range(len(string) - 1, 0, -1):
if reply[-j:] == string[:j]:
reply = reply[:-j]
break
else:
continue
break
return reply, stop_found
def _get_args():
parser = ArgumentParser()
parser.add_argument("--host", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default=8000, help="Demo server port.")
# model related
parser.add_argument(
"--model-path",
type=str,
default="01-ai/Yi-VL-6B",
)
parser.add_argument("--model-name", type=str, default="yi-vl")
args = parser.parse_args()
return args
if __name__ == "__main__":
args = _get_args()
engine = create_generate_model(args)
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
Type of problem
Model inference and deployment
Operating system
Linux
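Detailed description of the problem
Following your code, I deployed an OpenAI-style server using DefaultEngine. Although I call torch_gc in lifespan, GPU memory keeps growing as requests come in. What could be causing this, and how should I troubleshoot it? The code I used is included above.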
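One point that may be worth checking is when the code after `yield` in `lifespan` actually runs. An async context manager used as a lifespan is entered once at startup and exited once at shutdown, so a cleanup placed after `yield` would not run between requests. The sketch below illustrates this with the standard library only; `fake_gc` and `handle_request_with_cleanup` are hypothetical stand-ins for `torch_gc` and a request handler, not part of the server above.

```python
import asyncio
from contextlib import asynccontextmanager

calls = []  # records, in order, every time the cleanup hook runs


def fake_gc() -> None:
    """Hypothetical stand-in for torch_gc(); it only records the call."""
    calls.append("gc")


@asynccontextmanager
async def lifespan(app):
    # Code before `yield` runs once, at startup.
    yield
    # Code after `yield` runs once, at shutdown, not after each request.
    fake_gc()


async def handle_request_with_cleanup(payload: str) -> str:
    """Hypothetical handler that triggers cleanup after every request."""
    try:
        return payload.upper()
    finally:
        fake_gc()


async def main() -> None:
    async with lifespan(app=None):
        for text in ("a", "b", "c"):
            await handle_request_with_cleanup(text)
        calls.append("requests_done")
    # Leaving the context above triggers the single lifespan cleanup.


asyncio.run(main())
print(calls)  # ['gc', 'gc', 'gc', 'requests_done', 'gc']
```

The first three entries come from the per-request `finally` block; only the last one comes from `lifespan`. If the same holds for the server, calling `torch_gc()` at the end of each request (or each stream) might be the thing to try first.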
Dependencies
Runtime logs or screenshots
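To narrow down where the memory grows, readings could be logged before and after each request. The following is a minimal sketch of that idea, not part of the server above: `MemoryTracker` is a hypothetical helper, and the memory reader is injected so the example runs without a GPU. On a real deployment the reader would presumably be `torch.cuda.memory_allocated` (or `torch.cuda.memory_reserved`).

```python
from typing import Callable, Dict, List


class MemoryTracker:
    """Hypothetical helper: records a memory reading around each request."""

    def __init__(self, read_bytes: Callable[[], int]) -> None:
        # Assumption: on a GPU server, pass torch.cuda.memory_allocated here.
        self.read_bytes = read_bytes
        self.records: List[Dict[str, int]] = []

    def measure(self, fn: Callable[[], object]) -> object:
        """Run `fn` and store the memory reading before and after it."""
        before = self.read_bytes()
        try:
            return fn()
        finally:
            after = self.read_bytes()
            self.records.append(
                {"before": before, "after": after, "delta": after - before}
            )

    def total_growth(self) -> int:
        """Sum of per-request deltas; steady growth suggests a leak."""
        return sum(r["delta"] for r in self.records)


# Demo with a fake allocator that keeps 10 bytes per request.
leaked: List[bytes] = []


def fake_read() -> int:
    return sum(len(b) for b in leaked)


def leaky_request() -> str:
    leaked.append(b"x" * 10)  # simulates memory that is never released
    return "ok"


tracker = MemoryTracker(fake_read)
for _ in range(3):
    tracker.measure(leaky_request)

print([r["delta"] for r in tracker.records])  # [10, 10, 10]
print(tracker.total_growth())  # 30
```

A delta that stays positive on every request, even after cleanup, would point to something holding references (for example, a generation thread that is still running), rather than to the caching allocator alone.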