modelscope / agentscope

Start building LLM-empowered multi-agent applications in an easier way.
https://doc.agentscope.io/
Apache License 2.0

Implement model-oriented format function in OpenAI and Post API chat wrapper #381

Closed DavdGao closed 1 month ago

DavdGao commented 1 month ago

Background

Considering that the PostAPIChatWrapper will be used with many different models (e.g. gpt-4, gemini, glm-4, and so on), and that the OpenAI API can serve different model services, we should choose the prompt strategy in both OpenAIChatWrapper and PostAPIChatWrapper according to the model name.
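As a minimal sketch of the idea (the prefixes and strategy names below are illustrative assumptions, not agentscope's actual dispatch):

```python
def choose_prompt_strategy(model_name: str) -> str:
    """Pick a prompt strategy from the model name (illustrative only)."""
    if model_name.startswith("gpt-"):
        return "openai"   # native system/user/assistant role messages
    if model_name.startswith("gemini"):
        return "gemini"   # user/model roles only
    if model_name.startswith("glm-"):
        return "glm"      # GLM-style chat format
    return "common"       # fall back: merge history into one user message
```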

Description

- Format the prompt according to the model name. For example, a prompt consisting only of system messages at the beginning should be merged into a single user message:

  ```python
  formatted_prompt = [
      {
          "role": "user",
          "content": "You're a helpful assistant\n",
      },
  ]
  ```

- Modify the unit tests accordingly (see the sketch after this list).
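A minimal sketch of such a unit test, using a hypothetical `format_system_only` helper as a stand-in for the wrapper's actual format logic (the helper is not agentscope's API):

```python
import unittest


def format_system_only(system_content: str) -> list:
    # Hypothetical stand-in for the wrapper's format logic.
    return [{"role": "user", "content": system_content + "\n"}]


class TestSystemOnlyPrompt(unittest.TestCase):
    def test_system_messages_merged_into_user_message(self):
        prompt = format_system_only("You're a helpful assistant")
        self.assertEqual(
            prompt,
            [{"role": "user", "content": "You're a helpful assistant\n"}],
        )


if __name__ == "__main__":
    unittest.main()
```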

## Checklist

Please check the following items before the code is ready to be reviewed.

- [x]  Code has passed all tests
- [x]  Docstrings have been added/updated in Google Style
- [x]  Documentation has been updated
- [x]  Code is ready for review
yawzhe commented 2 weeks ago

I defined my own model wrapper and found that it is really just a parameter-passing shim with no streaming support; nothing I write in it takes effect. I check the request mode inside and add `stream` myself, but it does nothing. Even if I delete all the code after `request_kwargs`, nothing changes; even when there are errors, they have no effect. How is this custom model wrapper actually supposed to be defined? I can only get non-streaming output. For streaming I followed the OpenAI wrapper, changed the `stream` check, and used an iterator, but it simply does not work and raises no error; the output is always non-streaming.

```python
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2024/8/28 9:54
# @File : local_model.py
import json
import time
from typing import Any, List, Sequence, Union

import requests
from loguru import logger

from agentscope.constants import (
    _DEFAULT_MAX_RETRIES,
    _DEFAULT_MESSAGES_KEY,
    _DEFAULT_RETRY_INTERVAL,
)
from agentscope.message import Msg
from agentscope.models import ModelResponse, ModelWrapperBase


class GewuModelWrapper(ModelWrapperBase):

    model_type = 'ge_api_chat'

    def __init__(
        self,
        config_name: str = "",
        model_name: str = "",
        api_url: str = "",
        headers: dict = None,
        max_length: int = 2048,
        timeout: int = 30,
        # json_args: dict = None,
        # post_args: dict = None,
        max_retries: int = _DEFAULT_MAX_RETRIES,
        messages_key: str = _DEFAULT_MESSAGES_KEY,
        retry_interval: int = _DEFAULT_RETRY_INTERVAL,
        **kwargs: Any,
    ) -> None:
        super().__init__(config_name=config_name, model_name=model_name)

        self.api_url = api_url
        self.headers = headers
        self.max_length = max_length
        self.timeout = timeout
        self.json_args = {}
        self.post_args = {}
        self.max_retries = max_retries
        self.messages_key = messages_key
        self.retry_interval = retry_interval

    def _parse_response(self, response: dict) -> ModelResponse:
        """Parse the response JSON data into a ModelResponse."""
        return ModelResponse(raw=response)

    def __call__(self, input_: str, **kwargs: Any) -> ModelResponse:
        """Call the model with requests.post.

        Args:
            input_ (`str`):
                The input string to the model.

        Returns:
            `ModelResponse`: The response of the model and related
            information (e.g. cost, time, the number of tokens, etc.).

        Note:
            `parse_func`, `fault_handler` and `max_retries` are reserved
            for `_response_parse_decorator` to parse and check the
            response generated by the model wrapper. Their usages are
            listed as follows:
                - `parse_func` is a callable function used to parse and
                check the response generated by the model, which takes
                the response as input.
                - `max_retries` is the maximum number of retries when
                `parse_func` raises an exception.
                - `fault_handler` is a callable function which is called
                when the response generated by the model is invalid
                after `max_retries` retries.
        """
        # step 1: prepare keyword arguments
        post_args = {**self.post_args, **kwargs}
        self.json_args["model_name"] = "gewu_14b_v1"
        self.json_args["model"] = "gewu_14b_v1"

        request_kwargs = {
            "url": self.api_url,
            "json": {self.messages_key: input_, **self.json_args},
            "headers": self.headers or {},
            **post_args,
        }

        # step 2: send the post request, retrying on failure
        response = None
        for i in range(1, self.max_retries + 1):
            response = requests.post(**request_kwargs)
            if response.status_code == requests.codes.ok:
                break
            if i < self.max_retries:
                logger.warning(
                    f"Failed to call the model with "
                    f"requests.codes == {response.status_code}, retry "
                    f"{i + 1}/{self.max_retries} times",
                )
                time.sleep(i * self.retry_interval)

        # step 3: record the model API invocation, which will be skipped
        # if `FileManager.save_api_invocation` is `False`
        self._save_model_invocation(
            arguments=request_kwargs,
            response=response.json(),
        )

        # step 4: parse the response
        if response.status_code == requests.codes.ok:
            return self._parse_response(response.json())

        logger.error(json.dumps(request_kwargs, indent=4))
        raise RuntimeError(
            f"Failed to call the model with {response.json()}",
        )

class GewuAPIChatWrapper(GewuModelWrapper):
    """A post API model wrapper compatible with the OpenAI chat API,
    e.g. vLLM, FastChat."""

    model_type: str = "ge_api_chat"

    def _parse_response(self, response: dict) -> ModelResponse:
        return ModelResponse(
            text=response["choices"][0]["message"]["content"],
        )
    def format(
        self,
        *args: Union[Msg, Sequence[Msg]],
    ) -> List[dict]:
        """Format the input messages into a list of dicts compatible
        with the OpenAI Chat API.

        Args:
            args (`Union[Msg, Sequence[Msg]]`):
                The input arguments to be formatted, where each argument
                should be a `Msg` object or a list of `Msg` objects.
                In distributed mode, placeholders are also allowed.

        Returns:
            `List[dict]`:
                The formatted messages.
        """
        # Format according to the potential model field in the json_args
        return ModelWrapperBase.format_for_common_chat_models(*args)

if __name__ == "__main__":
    pass
```
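On the streaming question: the `__call__` above issues a single blocking `requests.post` and returns `response.json()`, so setting `stream` in the JSON body alone can never yield tokens incrementally; the request itself must be made with `stream=True` and the body consumed as an iterator. A minimal sketch of that technique (not agentscope's API; assumes an OpenAI-style endpoint that emits SSE lines of the form `data: {...}` ending with `data: [DONE]`):

```python
import json
from typing import Generator, List

import requests


def stream_chat(
    api_url: str,
    messages: List[dict],
    model: str,
    headers: dict = None,
    timeout: int = 30,
) -> Generator[str, None, None]:
    """Yield content deltas from an SSE-style streaming chat endpoint."""
    payload = {"model": model, "messages": messages, "stream": True}
    with requests.post(
        api_url,
        json=payload,
        headers=headers or {},
        timeout=timeout,
        stream=True,  # keep the connection open and read incrementally
    ) as resp:
        resp.raise_for_status()
        for raw in resp.iter_lines(decode_unicode=True):
            if not raw or not raw.startswith("data:"):
                continue  # skip keep-alives and non-data lines
            data = raw[len("data:"):].strip()
            if data == "[DONE]":
                break
            chunk = json.loads(data)
            delta = chunk["choices"][0].get("delta", {})
            if "content" in delta:
                yield delta["content"]
```

A wrapper that wants streamed output would then iterate, e.g. `for token in stream_chat(...): print(token, end="", flush=True)`, instead of waiting for a single `response.json()`.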