zeno-ml / zeno-build

Build, evaluate, understand, and fix LLM-based apps
MIT License

Suggestion for the error handler of `_throttled_openai_chat_completion_acreate` #169

Closed: zhaochenyang20 closed this issue 1 year ago

zhaochenyang20 commented 1 year ago

https://github.com/zeno-ml/zeno-build/blob/3674429930e8fe17f23ed9d91c2b5307fb9ed668/zeno_build/models/providers/openai_utils.py#L113

I was using Zeno to call the OpenAI API in batches, and I think this function could be improved as follows:

import asyncio
import logging
from typing import Any

import aiolimiter
import openai

# One log message per retriable error type (the wording is illustrative).
# Note: `openai.error.TimeoutError` does not exist (see the follow-up
# comment below); `openai.error.Timeout` is the correct name.
# `InvalidRequestError` is handled separately and is not retried.
ERROR_MESSAGES = {
    openai.error.Timeout: "OpenAI API timeout, retrying.",
    openai.error.RateLimitError: "OpenAI API rate limit exceeded, sleeping.",
    openai.error.APIConnectionError: "OpenAI API connection error, retrying.",
    openai.error.ServiceUnavailableError: "OpenAI service unavailable: {e}",
    openai.error.APIError: "OpenAI API error: {e}",
}

async def _throttled_openai_chat_completion_acreate(
    model: str,
    messages: list[dict[str, str]],
    temperature: float,
    max_tokens: int,
    top_p: float,
    limiter: aiolimiter.AsyncLimiter,
) -> dict[str, Any]:
    # This function is modified from https://github.com/zeno-ml/zeno-build/blob/main/zeno_build/models/providers/openai_utils.py#L113 # noqa E501
    async with limiter:
        for _ in range(3):
            try:
                return await openai.ChatCompletion.acreate(
                    model=model,
                    messages=messages,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    top_p=top_p,
                )
            except openai.error.InvalidRequestError:
                logging.warning(
                    "OpenAI API Invalid Request: Prompt was filtered"
                )
                return {
                    "choices": [
                        {
                            "message": {
                                "content": "Invalid Request: Prompt was filtered"  # noqa E501
                            }
                        }
                    ]
                }
            except tuple(ERROR_MESSAGES.keys()) as e:
                if isinstance(
                    e,
                    (
                        openai.error.ServiceUnavailableError,
                        openai.error.APIError,
                    ),
                ):
                    logging.warning(ERROR_MESSAGES[type(e)].format(e=e))
                else:
                    logging.warning(ERROR_MESSAGES[type(e)])
                await asyncio.sleep(10)
        return {"choices": [{"message": {"content": ""}}]}
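For context, the batching pattern around a helper like this can be sketched as follows. This is a minimal, stdlib-only illustration: `asyncio.Semaphore` stands in for `aiolimiter.AsyncLimiter`, and `fake_acreate` is a hypothetical stub in place of the real OpenAI call, so the structure can be shown without network access.

```python
import asyncio


async def fake_acreate(messages: list[dict[str, str]]) -> dict:
    # Hypothetical stand-in for the throttled OpenAI call above.
    await asyncio.sleep(0)
    return {
        "choices": [
            {"message": {"content": "echo: " + messages[-1]["content"]}}
        ]
    }


async def generate_batch(prompts: list[str], max_concurrency: int = 8) -> list[str]:
    # A semaphore caps the number of in-flight requests; AsyncLimiter
    # would instead cap the request *rate* (e.g. requests per minute).
    sem = asyncio.Semaphore(max_concurrency)

    async def one(prompt: str) -> dict:
        async with sem:
            return await fake_acreate([{"role": "user", "content": prompt}])

    # gather preserves input order, so responses line up with prompts.
    responses = await asyncio.gather(*(one(p) for p in prompts))
    return [r["choices"][0]["message"]["content"] for r in responses]


print(asyncio.run(generate_batch(["hello", "world"])))
# → ['echo: hello', 'echo: world']
```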
zhaochenyang20 commented 1 year ago

BTW, `openai.ChatCompletion.acreate` actually supports more useful parameters. Adding them would be better. 🤔

https://platform.openai.com/docs/api-reference/chat/create
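As a sketch of what passing extra parameters through might look like (the helper name `build_chat_kwargs` is hypothetical; `n`, `stop`, `presence_penalty`, and `frequency_penalty` are parameters documented in the linked API reference):

```python
def build_chat_kwargs(
    model: str,
    messages: list,
    temperature: float,
    max_tokens: int,
    top_p: float,
    n: int = 1,
    stop=None,
    presence_penalty: float = 0.0,
    frequency_penalty: float = 0.0,
) -> dict:
    # Collect keyword arguments for openai.ChatCompletion.acreate,
    # dropping None values so the API's own defaults apply.
    kwargs = dict(
        model=model,
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
        n=n,
        stop=stop,
        presence_penalty=presence_penalty,
        frequency_penalty=frequency_penalty,
    )
    return {k: v for k, v in kwargs.items() if v is not None}
```

The dict would then be splatted into the call, e.g. `await openai.ChatCompletion.acreate(**build_chat_kwargs(...))`.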

neubig commented 1 year ago

Thanks @zhaochenyang20! Please send a pull request and I'll review it. If possible, it would be nice if you could send it for both `chat_completion` and `completion`.

zhaochenyang20 commented 1 year ago

Sure! I should have time to do this over the weekend. 😂

zhaochenyang20 commented 1 year ago

BTW, there is actually no TimeoutError 🤔😂

`module 'openai.error' has no attribute 'TimeoutError'`
zhaochenyang20 commented 1 year ago

Here goes the PR: https://github.com/zeno-ml/zeno-build/pull/170

neubig commented 1 year ago

Fixed by #171