Closed DavdGao closed 1 month ago
自己定义的模型包装器,发现里面就是启动一个参数传参的功能,里面也没有流式,自己写根本就无法起作用,我在里面判断请求方式,自己加stream,发现他根本就不起作用,而且,即便是删除reque_argwargs后面的所有代码,也不影响,即便是有错误,也不会有什么用,这个自定义的模型包装器,到底怎么定义,我只能非流式传输,流式的按照openai改了判断stream,使用迭代器,发现根本就不起作用,也不报错,始终是非流式输出

#!/usr/bin/env python
from abc import ABC
from typing import Any, Union, Sequence, List
import time
import json

import requests
from loguru import logger

from agentscope.models import ModelWrapperBase
from agentscope.models import ModelResponse
from agentscope.constants import (
    _DEFAULT_MESSAGES_KEY,
    _DEFAULT_MAX_RETRIES,
    _DEFAULT_RETRY_INTERVAL,
)
from agentscope.message import Msg


class GewuModelWrapper(ModelWrapperBase):
    """Post-API wrapper for the Gewu chat service.

    Sends the formatted messages to ``api_url`` with ``requests.post`` and
    retries on non-2xx responses.

    NOTE(review): this wrapper is strictly non-streaming — the whole body is
    read via ``response.json()``. Streaming would require
    ``requests.post(..., stream=True)`` plus an iterator-based
    ``ModelResponse``; confirm against the agentscope version in use.
    """

    model_type = 'ge_api_chat'

    def __init__(
        self,
        config_name="",
        model_name="",
        api_url="",
        headers: dict = None,
        max_length: int = 2048,
        timeout: int = 30,
        max_retries: int = _DEFAULT_MAX_RETRIES,
        messages_key: str = _DEFAULT_MESSAGES_KEY,
        retry_interval: int = _DEFAULT_RETRY_INTERVAL,
        **kwargs: Any,
    ) -> None:
        """Initialize the wrapper.

        Args:
            config_name (`str`): The name of the model config.
            model_name (`str`): The name of the model.
            api_url (`str`): The URL the POST request is sent to.
            headers (`dict`, optional): Extra HTTP headers for the request.
            max_length (`int`): Stored but not currently sent to the API.
            timeout (`int`): Stored but not currently passed to
                ``requests.post`` — TODO confirm whether it should be.
            max_retries (`int`): Maximum number of POST attempts.
            messages_key (`str`): The JSON key under which the input
                messages are sent.
            retry_interval (`int`): Base seconds between retries
                (multiplied by the attempt number).
        """
        super().__init__(config_name=config_name, model_name=model_name)
        self.api_url = api_url
        self.headers = headers
        self.max_length = max_length
        self.timeout = timeout
        self.json_args = {}
        self.post_args = {}
        self.max_retries = max_retries
        self.messages_key = messages_key
        self.retry_interval = retry_interval

    def _parse_response(self, response: dict) -> ModelResponse:
        """Parse the response json data into ModelResponse"""
        return ModelResponse(raw=response)

    def __call__(self, input_: str, **kwargs: Any) -> ModelResponse:
        """Calling the model with requests.post.
        Args:
            input_ (`str`):
                The input string to the model.
        Returns:
            `dict`: A dictionary that contains the response of the model and
            related
            information (e.g. cost, time, the number of tokens, etc.).
        Note:
            `parse_func`, `fault_handler` and `max_retries` are reserved for
            `_response_parse_decorator` to parse and check the response
            generated by model wrapper. Their usages are listed as follows:
            - `parse_func` is a callable function used to parse and check
            the response generated by the model, which takes the response
            as input.
            - `max_retries` is the maximum number of retries when the
            `parse_func` raise an exception.
            - `fault_handler` is a callable function which is called
            when the response generated by the model is invalid after
            `max_retries` retries.
        """
        # step1: prepare keyword arguments
        # BUG FIX: was `print(**kwargs)`, which raises TypeError for any
        # non-empty kwargs (print() rejects arbitrary keyword arguments).
        logger.debug("GewuModelWrapper called with kwargs: {}", kwargs)
        post_args = {**self.post_args, **kwargs}
        # NOTE(review): the model name is hard-coded here instead of using
        # self.model_name — verify this is intentional.
        self.json_args['model_name'] = 'gewu_14b_v1'
        self.json_args['model'] = 'gewu_14b_v1'
        request_kwargs = {
            "url": self.api_url,
            "json": {self.messages_key: input_, **self.json_args},
            "headers": self.headers or {},
            **post_args,
        }
        # step2: prepare post requests, retrying with a growing back-off
        for i in range(1, self.max_retries + 1):
            response = requests.post(**request_kwargs)
            if response.status_code == requests.codes.ok:
                break
            if i < self.max_retries:
                logger.warning(
                    f"Failed to call the model with "
                    f"requests.codes == {response.status_code}, retry "
                    f"{i + 1}/{self.max_retries} times",
                )
                time.sleep(i * self.retry_interval)
        # step3: record model invocation
        # record the model api invocation, which will be skipped if
        # `FileManager.save_api_invocation` is `False`
        self._save_model_invocation(
            arguments=request_kwargs,
            response=response.json(),
        )
        # step4: parse the response
        if response.status_code == requests.codes.ok:
            return self._parse_response(response.json())
        else:
            logger.error(json.dumps(request_kwargs, indent=4))
            raise RuntimeError(
                f"Failed to call the model with {response.json()}",
            )
class GewuAPIChatWrapper(GewuModelWrapper):
    """A post api model wrapper compatible with openai chat, e.g., vLLM,
    FastChat."""

    # NOTE(review): this is the same model_type string as the parent class;
    # agentscope registers wrappers by model_type, so one of the two will
    # shadow the other in the registry — confirm and make them distinct.
    model_type: str = "ge_api_chat"

    def _parse_response(self, response: dict) -> ModelResponse:
        """Extract the assistant text from an OpenAI-chat-shaped payload.

        Args:
            response (`dict`):
                The decoded JSON body; expected to follow the OpenAI chat
                schema (``choices[0].message.content``).

        Returns:
            `ModelResponse`: The parsed text plus the raw payload.
        """
        # CONSISTENCY FIX: also keep the raw payload, matching the base
        # class's _parse_response (previously `raw` was silently dropped).
        return ModelResponse(
            text=response["choices"][0]["message"]["content"],
            raw=response,
        )

    def format(
        self,
        *args: Union[Msg, Sequence[Msg]],
    ) -> List[dict]:
        """Format the input messages into a list of dict, which is
        compatible to OpenAI Chat API.
        Args:
            args (`Union[Msg, Sequence[Msg]]`):
                The input arguments to be formatted, where each argument
                should be a `Msg` object, or a list of `Msg` objects.
                In distribution, placeholder is also allowed.
        Returns:
            `List[dict]`:
                The formatted messages.
        """
        # Format according to the potential model field in the json_args
        return ModelWrapperBase.format_for_common_chat_models(*args)
# BUG FIX: was `if name == 'main':`, which raises NameError at import time;
# the standard script-entry guard uses the dunder module name.
if __name__ == '__main__':
    pass
Background
Considering the PostAIChatWrapper will be used for many different models (e.g. gpt-4, gemini, glm-4, and so on), and OpenAI API can be used for different model services, we should choose prompt strategies in both OpenAIChatWrapper and PostAPIChatWrapper according to the model name.
Description
In this PR, we modified `format_for_common_chat_models` in the `ModelWrapperBase` class so that a prompt containing only system messages at the beginning is formatted into:
formatted_prompt = [ { "role": "user", "content": "You're a helpful assistant\n" }, ]