open-compass / T-Eval

[ACL2024] T-Eval: Evaluating Tool Utilization Capability of Large Language Models Step by Step
https://open-compass.github.io/T-Eval/
Apache License 2.0

API model #34

Open Fenglly opened 6 months ago

Fenglly commented 6 months ago

If I want to test the qwen model with the API, can I just use the GPTAPI class and replace the model URL with the qwen one?

zehuichen123 commented 6 months ago

Sorry, we do not currently support custom API models. You can subclass BaseAPIModel in lagent and write the request code yourself. By the way, we will soon release a template file showing how to customize an API model class in lagent.

zehuichen123 commented 6 months ago

Here is some unfinished reference code:

import json
import os
import time
from concurrent.futures import ThreadPoolExecutor, wait
from logging import getLogger
from threading import Lock
from typing import Dict, List, Optional, Union
import requests

from .base_api import BaseAPIModel, APITemplateParser

class CustomAPI(BaseAPIModel):
    """Model wrapper around Custom API models.

    Args:
        model_type (str): The name (type) of the model.
        model_url (str): The URL of the requested API model.
        query_per_second (int): The maximum number of queries allowed per
            second. Defaults to 1.
        retry (int): Number of retries if the API call fails. Defaults to 2.
        key (str or List[str]): Key(s) for the API model. When set to 'ENV',
            the key is meant to be read from an environment variable (not
            implemented in this draft). If it is a list, the keys are meant
            to be used in a round-robin manner. Defaults to 'ENV'.
        meta_template (Dict, optional): The model's meta prompt
            template if needed, in case the requirement of injecting or
            wrapping of any meta instructions.
        gen_params: Default generation configuration which could be overridden
            on the fly of generation.
    """

    def __init__(self,
                 model_type: str,
                 model_url: str,
                 query_per_second: int = 1,
                 retry: int = 2,
                 key: Optional[Union[str, List[str]]] = 'ENV',
                 meta_template: Optional[List[Dict]] = [
                     dict(role='system', api_role='system'),
                     dict(role='user', api_role='user'),
                     dict(role='assistant', api_role='assistant')
                 ],
                 **gen_params):
        self.url = model_url
        super().__init__(
            model_type=model_type,
            meta_template=meta_template,
            query_per_second=query_per_second,
            retry=retry,
            **gen_params)
        # Kept explicitly because it is sent as the `model` field below.
        self.model_type = model_type
        self.logger = getLogger(__name__)
        # Counter sent with every request; wraps around in `_generate`.
        self._session_id = 0
        # NOTE: the keys are only stored here. This unfinished draft neither
        # resolves 'ENV' from the environment nor sends a key with requests.
        if key is None:
            self.keys = None
        elif isinstance(key, str):
            self.keys = [key]
        else:
            self.keys = key

    def _generate(self,
                  inputs: Union[str, List],
                  max_out_len: Optional[int] = None,
                  temperature: Optional[float] = None) -> str:
        """Generate a result for a single input.

        Args:
            inputs (str or List): A string or PromptDict.
                The PromptDict should be organized in OpenCompass'
                API format. This draft only handles plain strings.
            max_out_len (int): The maximum length of the output.
            temperature (float): What sampling temperature to use,
                between 0 and 2. Higher values like 0.8 will make the output
                more random, while lower values like 0.2 will make it more
                focused and deterministic. Not forwarded to the API yet.

        Returns:
            str: The generated string.
        """
        assert isinstance(inputs, str)
        header = {'content-type': 'application/json'}
        num_retries = 0
        while num_retries < self.retry:
            self._session_id = (self._session_id + 1) % 1000000
            data = dict(
                model=self.model_type,
                session_id=self._session_id,
                prompt=inputs,
                sequence_start=True,
                sequence_end=True,
                max_tokens=max_out_len,
            )
            try:
                raw_response = requests.post(
                    self.url, headers=header, data=json.dumps(data))
            except requests.ConnectionError:
                self.logger.warning('Got connection error, retrying...')
                num_retries += 1
                continue
            try:
                response = raw_response.json()
            except requests.JSONDecodeError:
                self.logger.warning('JsonDecode error, got %s',
                                    raw_response.content)
                num_retries += 1
                continue
            try:
                # Completion-style endpoints nest the text under `choices`.
                if 'completion' in self.url:
                    return response['choices'][0]['text'].strip()
                return response['text'].strip()
            except KeyError:
                self.logger.warning('Unexpected response format: %s',
                                    response)
                num_retries += 1

        raise RuntimeError('Calling API model failed after retrying for '
                           f'{num_retries} times. Check the logs for '
                           'details.')
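
The retry and response-parsing logic above can be tried without lagent or a running server. The following is only a minimal sketch under stated assumptions: `extract_text`, `call_with_retry`, and the `send` callable are hypothetical names introduced for illustration (they are not part of lagent or T-Eval), and `send` stands in for the `requests.post` call so the loop can be exercised offline.

```python
import json
from typing import Callable, Dict, Optional


def extract_text(url: str, response: Dict) -> str:
    """Pull the generated text out of a decoded JSON response.

    Mirrors the branch in `_generate`: completion-style URLs nest the text
    under `choices`; other endpoints are assumed to use a top-level `text`.
    """
    if 'completion' in url:
        return response['choices'][0]['text'].strip()
    return response['text'].strip()


def call_with_retry(send: Callable[[str], str], url: str, payload: Dict,
                    retry: int = 2) -> str:
    """Try `send` up to `retry` times until a usable response comes back.

    `send` is a hypothetical stand-in for the HTTP call: it receives the
    JSON-encoded payload and returns the raw response body as a string.
    """
    last_error: Optional[Exception] = None
    for _ in range(retry):
        try:
            raw = send(json.dumps(payload))
            return extract_text(url, json.loads(raw))
        except (ConnectionError, json.JSONDecodeError, KeyError,
                IndexError) as err:
            last_error = err  # remember the failure and try again
    raise RuntimeError(
        f'Calling API model failed after {retry} attempts') from last_error


# A fake sender that fails once, then returns a completion-style body.
attempts = []


def flaky_send(body: str) -> str:
    attempts.append(body)
    if len(attempts) == 1:
        raise ConnectionError('simulated network failure')
    return json.dumps({'choices': [{'text': ' hello '}]})


result = call_with_retry(flaky_send, 'http://localhost:8000/v1/completions',
                         {'prompt': 'hi'})
print(result)
```

Passing the sender in as a function keeps the retry loop testable on its own; in the class above, the same role is played by the `requests.post` call.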
Fenglly commented 6 months ago

Thanks! I will try.