HazyResearch / manifest

Prompt programming with FMs.
Apache License 2.0

Support logprobs for OpenAI models #83

Closed · bclyang closed this issue 1 year ago

bclyang commented 1 year ago

Description of the feature request

I'd like to be able to pass the logprobs parameter through to the OpenAI completions endpoint. See the API reference: https://platform.openai.com/docs/api-reference/completions/create.
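
A minimal sketch of the desired usage, assuming manifest's usual Manifest(client_name=...) entry point; the logprobs kwarg is the requested addition and does not exist yet:

from manifest import Manifest

manifest = Manifest(client_name="openai")

# Desired behavior: logprobs (an int, the number of top tokens to score,
# at most 5 on the completions endpoint) is forwarded to OpenAI.
result = manifest.run("Why is the grass green?", logprobs=5)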

russelnelson commented 1 year ago

To support the logprobs parameter, you can modify the OpenAIClient class in openai.py. (Leaving this here as a comment because I wasn't clear on the process for contributing. Happy to make a pull request if helpful.)

Here are the steps and code changes to include logprobs in the API request and then extract them from the API response:

  1. Modify the PARAMS dictionary to include the logprobs parameter with a default value of None.
  2. Update the _run_completion and _arun_completion methods to include the logprobs parameter in the API request.
  3. Add a method called get_logprobs to extract logprobs from the API response.

Here's an example of how those changes might look:

from typing import Any, Dict, List, Optional

class OpenAIClient(Client):
    # ... (previous code)

    # Add the "logprobs" parameter to the PARAMS dictionary. Each entry
    # maps the manifest parameter name to the OpenAI parameter name and
    # its default value.
    PARAMS = {
        # ... (previous parameters)
        "logprobs": ("logprobs", None),
    }

    # ... (other methods)

    def _run_completion(
        self, request_params: Dict[str, Any], retry_timeout: int
    ) -> Dict:
        """Execute completion request.

        Args:
            request_params: request params.
            retry_timeout: retry timeout.

        Returns:
            response as dict.
        """
        # Forward the "logprobs" parameter to the API request if it was set
        logprobs = getattr(self, "logprobs", None)
        if logprobs is not None:
            request_params["logprobs"] = logprobs

        response_dict = super()._run_completion(request_params, retry_timeout)
        return response_dict

    async def _arun_completion(
        self, request_params: Dict[str, Any], retry_timeout: int, batch_size: int
    ) -> Dict:
        """Async execute completion request.

        Args:
            request_params: request params.
            retry_timeout: retry timeout.
            batch_size: batch size for requests.

        Returns:
            response as dict.
        """
        # Forward the "logprobs" parameter to the API request if it was set
        logprobs = getattr(self, "logprobs", None)
        if logprobs is not None:
            request_params["logprobs"] = logprobs

        response_dict = await super()._arun_completion(
            request_params, retry_timeout, batch_size
        )
        return response_dict

    def get_logprobs(
        self, response: Dict[str, Any]
    ) -> Optional[List[Optional[Dict[str, Any]]]]:
        """
        Extract logprobs from the API response.

        Args:
            response: The API response dictionary.

        Returns:
            The per-choice logprobs dict (tokens, token_logprobs,
            top_logprobs, text_offset) for each choice, with None for
            choices without logprobs, or None if the response has no
            choices.
        """
        if "choices" not in response:
            return None

        # Each entry is the per-choice logprobs dict, not a bare float
        return [choice.get("logprobs") for choice in response["choices"]]
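
For reference, a minimal sketch of what the extraction sees, using a fabricated response shaped like the completions payload (the tokens and logprob values below are made up):

from typing import Any, Dict

# Fabricated response in the shape of the completions payload;
# tokens and values are illustrative only.
sample_response: Dict[str, Any] = {
    "choices": [
        {
            "text": " world",
            "logprobs": {
                "tokens": [" world"],
                "token_logprobs": [-0.31],
                "top_logprobs": [{" world": -0.31, " there": -1.92}],
                "text_offset": [5],
            },
        },
        {"text": " there"},  # a choice without logprobs
    ]
}

# Mirrors what get_logprobs returns for this response
logprobs = [choice.get("logprobs") for choice in sample_response["choices"]]
print(logprobs[0]["token_logprobs"])  # [-0.31]
print(logprobs[1])                    # None
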
russelnelson commented 1 year ago

An alternative solution for logprobs would be to make the changes in the shared Request and Response classes instead. That would give every client a uniform field, letting users compare OpenAI's logprobs with those from other models; a rough sketch follows.
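
A hypothetical sketch of that approach; the field names are illustrative, not manifest's actual schema:

from typing import Any, Dict, List, Optional

from pydantic import BaseModel

# Hypothetical sketch only: surface logprobs on shared request/response
# models so every client populates the same field.
class Request(BaseModel):
    prompt: str
    logprobs: Optional[int] = None  # number of top logprobs to request

class Response(BaseModel):
    text: str
    # Per-choice logprobs in whatever shape the backing model provides
    logprobs: Optional[List[Dict[str, Any]]] = None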

lorr1 commented 1 year ago

Fixed when I refactored pydantic types in #84