jxnl / instructor

Structured outputs for LLMs
https://python.useinstructor.com/
MIT License
7.54k stars 602 forks source link

Adding a lightweight prompt abstraction to the SchemaClass #24

Closed jxnl closed 1 year ago

jxnl commented 1 year ago

Sure! Here's the updated proposal where PromptConfig has the model as a required argument and all other attributes as optional. The default model is set to "gpt3.5-turbo-0613":

from pydantic import BaseModel
from typing import Optional

class OpenAISchema(BaseModel):
    """Base model that can submit itself as an OpenAI function-calling schema.

    Subclasses may define an inner ``PromptConfig`` class to supply default
    prompt text and completion parameters for :meth:`create`.
    """

    class PromptConfig:
        # Default completion parameters; only ``model`` is required.
        # Fixed model id: "gpt-3.5-turbo-0613" ("gpt3.5-turbo-0613" is not a
        # valid OpenAI model name).
        model: str = "gpt-3.5-turbo-0613"
        # The original used bare annotations, which create no class
        # attribute, so ``cls.PromptConfig.system`` raised AttributeError.
        # Defaulting every optional field to None fixes the lookups below.
        system: Optional[str] = None
        message: Optional[str] = None
        temperature: Optional[float] = None
        max_tokens: Optional[int] = None

    @classmethod
    def from_response(cls, response):
        """Parse *response* (a ChatCompletion result) into an instance of *cls*."""
        # Implementation based on the actual response format.
        raise NotImplementedError

    @classmethod
    def create(cls, message=None, *args, force_function=False, **kwargs):
        """Call the chat-completion API with this schema attached as a function.

        Args:
            message: Optional extra user message appended to the prompt.
            force_function: When True, force the model to call this schema's
                function via the ``function_call`` parameter.
            **kwargs: Forwarded to ``openai.ChatCompletion.create``; explicit
                ``messages``/``model``/``temperature``/``max_tokens`` here
                override the ``PromptConfig`` defaults.

        Returns:
            Whatever :meth:`from_response` builds from the completion.
        """
        messages = kwargs.get("messages", [])

        # Seed the conversation from PromptConfig only when the caller
        # supplied no messages of their own.
        if not messages and hasattr(cls, "PromptConfig"):
            if cls.PromptConfig.system:
                messages.append({
                    "role": "system",
                    "content": cls.PromptConfig.system
                })
            if cls.PromptConfig.message:
                messages.append({
                    "role": "user",
                    "content": cls.PromptConfig.message
                })

        if message:
            messages.append({
                "role": "user",
                "content": message
            })

        if force_function:
            kwargs['function_call'] = {"name": cls.openai_schema["name"]}

        kwargs['messages'] = messages

        if hasattr(cls, "PromptConfig"):
            kwargs.setdefault('model', cls.PromptConfig.model)
            # Only forward optional parameters that were actually configured;
            # the original setdefault() injected explicit None values.
            if cls.PromptConfig.temperature is not None:
                kwargs.setdefault('temperature', cls.PromptConfig.temperature)
            if cls.PromptConfig.max_tokens is not None:
                kwargs.setdefault('max_tokens', cls.PromptConfig.max_tokens)

        completion = openai.ChatCompletion.create(
            functions=[cls.openai_schema],
            **kwargs
        )
        return cls.from_response(completion)

class Search(OpenAISchema):
    # Implementation remains the same (elided in this proposal sketch);
    # ``pass`` keeps the snippet syntactically valid.
    pass

class MultiSearch(OpenAISchema):
    """Schema that segments one email/request into multiple search queries."""

    class PromptConfig:
        system = "You are a capable algorithm designed to correctly segment search requests."
        message = "Correctly segment the following search request"
        # Fixed model id: "gpt-3.5-turbo-0613" ("gpt3.5-turbo-0613" is not a
        # valid OpenAI model name).
        model = "gpt-3.5-turbo-0613"
        temperature = 0.5
        max_tokens = 1000

    # Implementation remains the same

# Example of usage:
# Segment the raw request into search queries, then execute them all.
request = (
    "Please send me the video from last week about the investment case study"
    " and also documents about your GPDR policy."
)
queries = MultiSearch.create(request)
queries.execute()

This revision makes the PromptConfig more flexible and easier to use, with the default model set and all other parameters optional. This configuration can be overridden on a per-class basis, as shown in the MultiSearch.PromptConfig example.

fpingham commented 1 year ago

Hey! Saw the Twitter thread on abstractions for prompts. Would love to help! Is this along the lines of what you were thinking of? Trying to understand what you were imagining the abstractions might look like.

from pydantic import BaseModel
from typing import Optional, List
from enhancements import chain_of_thought, answer_correctly
from config import FEW_SHOT_EXAMPLES, BASE_SYSTEM_MESSAGE

# Registry mapping enhancement names to prompt-transforming callables.
ENHANCEMENTS = {
    'chain_of_thought': chain_of_thought, # adds 'Let's think step by step' to assistant message etc.
    'answer_correctly': answer_correctly, # adds 'Your task is to provide a correct answer to the user, which you make sure is correct' to system message etc.
    # ... further enhancements are registered here (the original bare `...`
    # inside the dict literal was a SyntaxError).
}

def enhance_prompts(
        system_message: Optional[str] = None,
        user_message: Optional[str] = None,
        assistant_message: Optional[str] = None,
        few_shot_examples: Optional[List[str]] = None,
        enhancements: Optional[List[str]] = None,
        ):
    """Assemble [system, user, assistant] messages and apply enhancements.

    Falls back to BASE_SYSTEM_MESSAGE when no system message is given, runs
    each named enhancement from ENHANCEMENTS over the messages in order, and
    finally appends the few-shot examples when provided.

    All parameters were already declared Optional; they now also carry None
    defaults so the function can be called with any subset of arguments.
    """
    system_message = system_message or BASE_SYSTEM_MESSAGE

    messages = [system_message, user_message, assistant_message]

    # Guard against enhancements=None — the original iterated it
    # unconditionally and raised TypeError despite the Optional type.
    for en in enhancements or []:
        messages = ENHANCEMENTS[en](messages)

    if few_shot_examples:
        messages = add_few_shot(messages, few_shot_examples)

    return messages

class OpenAISchema(BaseModel):
    """Base model able to call the chat-completion API against its own schema.

    Subclasses may supply an inner ``ModelConfig`` (completion parameters)
    and ``PromptConfig`` (prompt text and enhancement options).
    """

    class ModelConfig:
        # Completion parameters; only ``model`` has a concrete default.
        # Fixed model id: "gpt-3.5-turbo-0613" ("gpt3.5-turbo-0613" is not
        # a valid OpenAI model name).
        model: str = "gpt-3.5-turbo-0613"
        temperature: Optional[float] = None
        max_tokens: Optional[int] = None

    class PromptConfig:
        # Prompt text and options fed to ``enhance_prompts``.  The original
        # used bare annotations (which create no class attribute) and even
        # assigned ``Optional[List[str]]`` itself to two fields; every field
        # now genuinely defaults to None.
        system_message: Optional[str] = None
        user_message: Optional[str] = None
        assistant_message: Optional[str] = None
        enhancements: Optional[List[str]] = None
        few_shot: Optional[List[str]] = None

    @classmethod
    def from_response(cls, response):
        """Parse *response* (a ChatCompletion result) into an instance of *cls*."""
        # Implementation based on the actual response format.
        raise NotImplementedError

    @classmethod
    def create(cls, message=None, *args, force_function=False, **kwargs):
        """Call the chat-completion API with this schema attached as a function.

        Args:
            message: Optional extra user message appended to the prompt.
            force_function: When True, force the model to call this schema's
                function via the ``function_call`` parameter.
            **kwargs: Forwarded to ``openai.ChatCompletion.create``; explicit
                values here override the ``ModelConfig`` defaults.
        """
        messages = kwargs.get("messages", [])

        if not messages and hasattr(cls, "PromptConfig"):
            cfg = cls.PromptConfig
            # ``enhance_prompts(**cls.PromptConfig)`` cannot work — a class
            # is not a mapping — so pass the attributes explicitly.  Use
            # ``extend`` because enhance_prompts returns a *list* of
            # messages; ``append`` would nest that list inside ``messages``.
            messages.extend(enhance_prompts(
                system_message=cfg.system_message,
                user_message=cfg.user_message,
                assistant_message=cfg.assistant_message,
                few_shot_examples=cfg.few_shot,
                enhancements=cfg.enhancements,
            ))

        if message:
            messages.append({
                "role": "user",
                "content": message
            })

        if force_function:
            kwargs['function_call'] = {"name": cls.openai_schema["name"]}

        kwargs['messages'] = messages

        # Completion parameters live on ModelConfig, not PromptConfig (the
        # original looked them up on the wrong inner class).  Only forward
        # the optional ones that were actually configured.
        if hasattr(cls, "ModelConfig"):
            kwargs.setdefault('model', cls.ModelConfig.model)
            if cls.ModelConfig.temperature is not None:
                kwargs.setdefault('temperature', cls.ModelConfig.temperature)
            if cls.ModelConfig.max_tokens is not None:
                kwargs.setdefault('max_tokens', cls.ModelConfig.max_tokens)

        completion = openai.ChatCompletion.create(
            functions=[cls.openai_schema],
            **kwargs
        )
        return cls.from_response(completion)
jxnl commented 1 year ago

i have a prototype now that looks like

    # Prototype of a composable, pipe-style prompt builder.
    from functions import MultiSearch
    from pprint import pprint

    # Each ``|`` stage appears to contribute one piece of the request:
    # completion settings, an "expert" system prompt, the output schema,
    # chain-of-thought, and a tagged user message.
    # NOTE(review): ChatCompletion / ExpertSystem / ChainOfThought /
    # TaggedMessage are not defined in this snippet — presumably part of the
    # prototype; confirm their semantics against that code.
    task = (
        ChatCompletion(name="Acme Inc Email Segmentation", model="gpt3.5-turbo-0613")
        | ExpertSystem(task="Segment emails into search queries")
        | MultiSearch
        | ChainOfThought()
        | TaggedMessage(tag="email", content="Segment emails into search queries")
    )

    # Running the pipeline is expected to yield an instance of the schema.
    obj = task.create()
    assert isinstance(obj, MultiSearch)