stanfordnlp / dspy

DSPy: The framework for programming—not prompting—foundation models
https://dspy-docs.vercel.app/
MIT License
18.49k stars 1.42k forks

Request for OpenAI Structured Output Support #1365

Open Gsizm opened 3 months ago

Gsizm commented 3 months ago

Hello DSPy Team,

I want to request the addition of support for OpenAI's structured output in your library.

OpenAI recently introduced Structured Outputs in their API, which is meant to guarantee that responses are valid JSON matching a supplied JSON Schema, even for complex schemas (https://openai.com/index/introducing-structured-outputs-in-the-api/). This feature would remove a metric ton of headaches by solving many of the challenges around JSON validation and parsing.

Could you please consider adding this support as soon as possible? This enhancement would be super beneficial for the whole DSPy community.

grant-d commented 2 months ago

You can already use it:

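import dspy

# Hypothetical signature for illustration; substitute your own dspy.Signature subclass.
class PredictSignature(dspy.Signature):
    """Answer the question."""
    question: str = dspy.InputField()
    answer: str = dspy.OutputField()

call_config: dict = {}
call_config["response_format"] = {
    "type": "json_schema",  # json_object | text
    "json_schema": {
        "name": "my_schema",
        "strict": True,
        "schema": {
            "type": "object",
            "properties": {"foo": {"type": "string"}, "bar": {"type": "number"}},
            # Strict mode requires every property to be listed in "required"
            # and additionalProperties to be False.
            "required": ["foo", "bar"],
            "additionalProperties": False,
        },
    },
}

# Predict keeps extra keyword arguments as its config and passes them to the LM call.
predict = dspy.Predict(
    PredictSignature,
    **call_config,
)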

# Output
# {"foo":....}
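Even with Structured Outputs, the model's answer arrives as a JSON string that the caller still has to parse. A minimal stdlib-only sketch, assuming the flat schema used above (a string `foo` and a number `bar`); `parse_structured` is a hypothetical helper, not part of DSPy or the OpenAI SDK, and every failure surfaces as a `ValueError` (which `json.JSONDecodeError` subclasses):

```python
import json

# JSON Schema primitive types used in the schema above, mapped to Python types.
_JSON_TYPES = {"string": str, "number": (int, float)}


def parse_structured(raw: str, schema: dict) -> dict:
    """Parse a JSON string and check it against a flat object schema (hypothetical helper)."""
    data = json.loads(raw)
    if not isinstance(data, dict):
        raise ValueError("expected a JSON object")
    properties = schema["properties"]
    extra = set(data) - set(properties)
    if extra and not schema.get("additionalProperties", True):
        raise ValueError(f"unexpected fields: {sorted(extra)}")
    for name, spec in properties.items():
        if name not in data:
            raise ValueError(f"missing field: {name}")
        value = data[name]
        # bool is a subclass of int in Python, so reject it explicitly.
        if isinstance(value, bool) or not isinstance(value, _JSON_TYPES[spec["type"]]):
            raise ValueError(f"field {name!r} is not of type {spec['type']}")
    return data


schema = {
    "type": "object",
    "properties": {"foo": {"type": "string"}, "bar": {"type": "number"}},
    "required": ["foo", "bar"],
    "additionalProperties": False,
}
print(parse_structured('{"foo": "hello", "bar": 1.5}', schema))
```

This only covers the two primitive types in the example; nested objects or arrays would need a real JSON Schema validator.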
grant-d commented 2 months ago

Even though this (above) works, I am now running into a problem: when DSPy rewrites the prompts, it sometimes drops the word "json", resulting in the (thus, expected) error: 'messages' must contain the word 'json' in some form, to use 'response_format' of type 'json_object'.

Is there a way to force DSPy to always include certain text? For example, would CoT.hint stay immutable, or is everything subject to the mutation algorithm?

[edit]: TypedPredictor does not expose **config, so I cannot initialize it the way I show in my code above; hence my question.
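One possible workaround for the dropped "json" keyword, sketched under the assumption that you can intercept the chat messages before they reach the API (for example in a custom LM wrapper): check whether any message mentions JSON and, if none does, append a short note to the system message. `ensure_json_mention` is a hypothetical helper, not a DSPy API.

```python
def ensure_json_mention(messages: list[dict]) -> list[dict]:
    """Return a copy of `messages` in which at least one message mentions "json"."""
    # The API check is for the word "json" in any form, so compare case-insensitively.
    if any("json" in str(m.get("content", "")).lower() for m in messages):
        return [dict(m) for m in messages]
    note = "Respond with a JSON object."
    patched = [dict(m) for m in messages]  # shallow copies; the caller's list is untouched
    for m in patched:
        if m.get("role") == "system":
            m["content"] = f"{m.get('content', '')}\n\n{note}".strip()
            return patched
    # No system message yet: prepend one that carries the note.
    return [{"role": "system", "content": note}] + patched


messages = [{"role": "user", "content": "Summarize the ticket."}]
print(ensure_json_mention(messages))
```

Because the note is added outside the optimizable instructions, prompt rewriting cannot remove it.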

imranarshad commented 2 months ago

I was able to extend Predict and write a new adapter to work with OpenAI structured output, and also with Claude-style prompts. It has built-in reasoning along with reflection.

Prompt inputs are XML-based, and the output is a Pydantic model.
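Stripped of the DSPy template plumbing, the XML-based input formatting done by `StructuredOutputAdapter.query` in the code below amounts to wrapping each input in an upper-cased tag derived from the field name. A minimal standalone sketch; `format_inputs` and its `{name: (description, value)}` input shape are illustrative assumptions, not part of DSPy:

```python
import re


def slugify(string: str, delimiter: str = "_") -> str:
    """Same normalization as in the code below: ASCII only, runs of non-word chars -> delimiter."""
    ascii_only = string.encode("ascii", errors="ignore").decode()
    return re.sub(r"[\W_]+", delimiter, ascii_only).strip(delimiter).lower()


def format_inputs(fields: dict[str, tuple[str, object]]) -> str:
    """Wrap each (description, value) pair in an XML tag inside a single <INPUTS> block."""
    blocks = []
    for name, (desc, value) in fields.items():
        tag = slugify(name).upper()
        blocks.append(f'<{tag} desc="{desc}">\n{str(value).strip()}\n</{tag}>')
    return "<INPUTS>\n" + "\n\n".join(blocks) + "\n</INPUTS>"


print(format_inputs({
    "query": ("user query about the movie", "Who directed Heat?"),
    "released_after": ("year after which the movie was released", 1990),
}))
```

Like the adapter, this does not escape double quotes inside descriptions, so descriptions should avoid them.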

import json
import logging
import textwrap
import traceback
import re
from typing import Union, Any

import dsp
import dspy
from dsp import BaseTemplate
from dspy import ensure_signature, signature_to_template, make_signature
from dsp.primitives.demonstrate import Example
from dspy.predict.parameter import Parameter
from openai.lib._pydantic import to_strict_json_schema
from pydantic import BaseModel, Field

def slugify(string, delimiter='_'):
    return re.sub(r'[\W_]+', delimiter, string.encode('ascii', errors='ignore').decode()).strip(delimiter).lower()

class PredictX(dspy.Predict, Parameter):
    def __init__(self, signature, response_format: dict, reflect, auto_reflect, reasoning, max_retries, **config):
        config["response_format"] = response_format
        self.reflect = reflect
        self.reasoning = reasoning
        self.max_retries = max_retries
        self.auto_reflect = auto_reflect
        self.reflecting = False
        super().__init__(signature, **config)

    def forward(self, **kwargs):
        assert not dsp.settings.compiling, "It's no longer ever the case that .compiling is True"

        # Extract the three privileged keyword arguments.
        new_signature = ensure_signature(kwargs.pop("new_signature", None))
        signature = ensure_signature(kwargs.pop("signature", self.signature))
        demos = kwargs.pop("demos", self.demos)
        config = dict(**self.config, **kwargs.pop("config", {}))

        # Get the right LM to use.
        lm = kwargs.pop("lm", self.lm) or dsp.settings.lm
        assert lm is not None, "No LM is loaded."

        # If temperature is 0.0 but its n > 1, set temperature to 0.7.
        temperature = config.get("temperature")
        temperature = lm.kwargs["temperature"] if temperature is None else temperature

        num_generations = config.get("n")
        if num_generations is None:
            num_generations = lm.kwargs.get("n", lm.kwargs.get("num_generations", 1))

        if (temperature is None or temperature <= 0.15) and num_generations > 1:
            config["temperature"] = 0.7
            # print(f"#> Setting temperature to 0.7 since n={num_generations} and prior temperature={temperature}.")

        if new_signature is not None:
            signature = new_signature

        if not all(k in kwargs for k in signature.input_fields):
            present = [k for k in signature.input_fields if k in kwargs]
            missing = [k for k in signature.input_fields if k not in kwargs]
            print(f"WARNING: Not all input fields were provided to module. Present: {present}. Missing: {missing}.")
        _demos = dsp.Example(demos=demos, **kwargs)

        if self.reflecting:
            config['system'] = """You are an expert clinical documentation reviewer. You critically review the AI-generated response and identify issues."""
        else:
            config['system'] = """You are a helpful assistant. Your task is to understand the inputs provided in the INPUTS XML tag and produce a response as per OUTPUT_SCHEMA. For the given inputs, write your detailed, objective, data-driven reasoning in the StructuredOutput.reasoning field, to be sure you understand what you are doing."""

        # Retry the structured call. The counter is local, so every call to forward()
        # gets the full max_retries budget instead of sharing one across calls.
        retries_left = self.max_retries
        while True:
            try:
                pred = structured_output_generate(lm, signature, _demos, **config)
                # pprint(f"#### Response Rating: {pred.rate_your_response}")
                # print(f"#### Response Critique: {pred.criticise_response}")
                break
            except Exception:
                retries_left -= 1
                if retries_left <= 0:
                    raise
                print(f"Trying again. {retries_left} retries left.")
        # rate_your_response only exists when that field is enabled in StructuredModel.
        if self.reflect or (self.auto_reflect and getattr(pred, "rate_your_response", 5) <= 3):
            # Disable reflection for the nested call so it does not recurse again,
            # and restore the flags afterwards so later calls behave the same way.
            saved_flags = (self.reflect, self.auto_reflect, self.reflecting)
            self.reflect = False
            self.auto_reflect = False
            self.reflecting = True
            # pprint(f"Quality of Response: {pred.criticise_response}.\n\nReflecting...")

            instructions = textwrap.dedent("""You are tasked with reviewing and reflecting on an AI-generated response provided in LLM_RESPONSE to find mistakes and evaluate how well it followed the ORIGINAL_PROMPT. Your goal is to identify issues, explain how to fix them, and rewrite the correct response.

Here are the inputs you will be working with:

:prompt_inputs

Evaluate the response on:
1. Adherence to the original instructions provided in ORIGINAL_PROMPT
2. Make use of critique_on_response to identify issues in the response as well.
3. Accuracy of information
4. Completeness of the response
5. Proper use of any specified formats or structures

List each issue as numbered points:
1. What the issue is. Be specific with citations of the data
2. Describe how you will fix it

In StructuredOutput.reasoning, list each issue with evidence and its corresponding fix on separate lines. Do not make assumptions unless explicitly stated in the inputs.

For the corrected response, do not copy text from the LLM_RESPONSE; instead, create a new, improved version that addresses all the identified issues.
""")
            del config['system']
            template = signature_to_template(signature, adapter=StructuredOutputAdapter)
            modified_kwargs = {
                'original_prompt': template(_demos, show_guidelines=False),
                'LLM_Response': pred.content.dict(),
                # 'critique_on_response': pred.criticise_response
            }

            _signature = {
                'original_prompt': dspy.InputField(desc="Original Prompt"),
                'LLM_Response': dspy.InputField(desc="LLM Response"),
                # 'critique_on_response': dspy.InputField(desc="self critique on response"),
                'output_with_reasoning': signature.fields['output_with_reasoning']
            }
            new_signature = make_signature(_signature, instructions=instructions, signature_name=f"Reflect.{signature.__name__}")
            try:
                reflected_pred = self.forward(new_signature=new_signature, **modified_kwargs)
            finally:
                self.reflect, self.auto_reflect, self.reflecting = saved_flags
            reflected_pred.reasoning = pred.reasoning + "\n\nUpon reflecting:\n" + reflected_pred.reasoning
            pred = reflected_pred

        if kwargs.pop("_trace", True) and dsp.settings.trace is not None:
            trace = dsp.settings.trace
            trace.append((self, {**kwargs}, pred))

        return pred

def TypedCOTPredict(signature, out_type: BaseModel, reflect: bool=False, auto_reflect=False, reasoning:str = None, *, max_retries=3) -> dspy.Module:
    class StructuredModel(BaseModel):
        reasoning: str = Field(..., description=f"Your observations about the data in INPUTS and reasoning with citations, and how you will follow the task instructions. Use new lines, instead of writing all in one sentence",)
        content: out_type
        # criticise_response: str = Field(description="critique your response. Identify any mistakes, inaccuracies, or deviations from the instructions with evidence.")
        # rate_your_response: int = Field(..., description="Given your critique, rate your response on a scale of 1 to 5, where 1 is the worst and 5 is the best. More mistakes, inaccuracies, or deviations from the instructions will result in a lower score.")

        def get(self, field_name: str):
            return self.dict()

    response_format = {
        "type": "json_schema",
        "json_schema": {
            "schema": to_strict_json_schema(StructuredModel),
            "name": out_type.__name__,
            "strict": True,
        }
    }
    signature = signature.append('output_with_reasoning', dspy.OutputField(), StructuredModel)
    new_signature = make_signature(signature.model_fields, signature.instructions, signature_name=out_type.__name__)
    return PredictX(new_signature, response_format=response_format, reflect=reflect, auto_reflect=auto_reflect, reasoning=reasoning, max_retries=max_retries)

class StructuredOutputAdapter(BaseTemplate):
    def query(self, example: Example, is_demo: bool = False) -> str:
        """Retrieves the input variables from the example and formats them into a query string."""
        result: list[str] = []

        # If not a demo, find the last field that doesn't have a value set in `example` and set it to ""
        # This creates the "Output:" prefix at the end of the prompt.
        if not is_demo:
            has_value = [
                field.input_variable in example
                and example[field.input_variable] is not None
                and example[field.input_variable] != ""
                for field in self.fields
            ]

            if not any(has_value):
                # An explicit exception still fires when Python runs with -O (asserts disabled).
                raise ValueError("No input variables found in the example")

            for i in range(1, len(has_value)):
                if has_value[i - 1] and not any(has_value[i:]):
                    example[self.fields[i].input_variable] = ""
                    break

        for field in self.fields[:-1]:
            if field.input_variable in example and example[field.input_variable] is not None:
                if field.input_variable in self.format_handlers:
                    format_handler = self.format_handlers[field.input_variable]
                else:
                    def format_handler(x):
                        return str(x).strip()

                formatted_value = format_handler(example[field.input_variable])

                result.append(f"<{slugify(field.name).upper()} desc=\"{field.description}\">\n{formatted_value}\n</{slugify(field.name).upper()}>")

        return "<INPUTS>\n" + "\n\n".join([r for r in result if r]) + "\n</INPUTS>"

    def guidelines(self, show_guidelines=True) -> str:
        """Returns the task guidelines as described in the lm prompt"""
        if (not show_guidelines) or (hasattr(dsp.settings, "show_guidelines") and not dsp.settings.show_guidelines):
            return ""

        result = "Follow the following format.\n\n"

        example = dsp.Example()
        for field in self.fields:
            example[field.input_variable] = field.description
        example.augmented = True

        result += self.query(example)
        return result

    def extract(
            self,
            example: Union[Example, dict[str, Any]],
            raw_pred: str,
    ) -> Example:
        """Extracts the answer from the LM raw prediction using the template structure

        Args:
            example (Union[Example, dict[str, Any]]): Contains the input variables that raw_pred was completed on.
            raw_pred (str): LM generated string

        Returns:
            Example: The example with the output variables filled in
        """

        example = dsp.Example(example)
        try:
            example['output_with_reasoning'] = json.loads(raw_pred)
        except json.JSONDecodeError:
            logging.error(f"Failed to decode JSON: {raw_pred}")
            # Re-raise the original error; JSONDecodeError cannot be raised bare (it needs msg, doc, pos).
            raise

        return example

    def __call__(self, example, show_guidelines=True) -> str:
        example = dsp.Example(example)
        output_fields = []

        rdemos = [
            self.query(demo, is_demo=True)
            for demo in example.demos
            if (
                    (not demo.get("augmented", False))
                    and (  # validate that the training example has the same primitive input var as the template
                            self.fields[-1].input_variable in demo and demo[self.fields[-1].input_variable] is not None
                    )
            )
        ]

        ademos = [self.query(demo, is_demo=True) for demo in example.demos if demo.get("augmented", False)]

        # Move the rdemos to ademos if rdemo has all the fields filled in
        rdemos_ = []
        new_ademos = []
        for rdemo in rdemos:
            if all((field.name in rdemo) for field in self.fields if field.input_variable in example):
                new_ademos.append(rdemo)
            else:
                rdemos_.append(rdemo)

        ademos = new_ademos + ademos
        rdemos = rdemos_

        example["augmented"] = True

        query = self.query(example)
        instructions = self.instructions
        if ':prompt_inputs' in instructions:
            # Inline the XML inputs where the instructions ask for them instead of appending them at the end.
            instructions = instructions.replace(':prompt_inputs', query)
            parts = [instructions, *rdemos, self.guidelines(show_guidelines), *ademos]
        else:
            parts = [instructions, *rdemos, self.guidelines(show_guidelines), *ademos, query]

        prompt = "\n\n---\n\n".join([p.strip() for p in parts if p])

        return prompt.strip()

def structured_output_generate(lm, signature, example, **kwargs):
    template = signature_to_template(signature, adapter=StructuredOutputAdapter)
    prompt = template(example, show_guidelines=False)
    # Generate and extract the fields.
    if 'openai' in lm.provider:
        lm.system_prompt = kwargs['system']
        del kwargs['system']
        completions = lm(prompt, **kwargs)
    else:
        # kwargs['system'] = signature.instructions
        # new_signature = make_signature(signature.model_fields, '' , signature_name=signature.__name__)
        # template = signature_to_template(new_signature, adapter=StructuredOutputAdapter)
        prompt += f"\n\nYour response should strictly follow JSON SCHEMA as given below - a valid json object on one line. Ensure that JSON output uses double quotes for strings, escape special characters like double quotes, backslashes, newlines, and tabs so it can be parsed by json.loads. DO NOT wrap your response in ```json and ```. \n<OUTPUT_SCHEMA>\n" + json.dumps(kwargs['response_format']['json_schema']['schema'], indent=2) + f"\n</OUTPUT_SCHEMA>"
        del kwargs['response_format']

        completions = lm(prompt, **kwargs)

    completions = [template.extract(example, p) for p in completions]

    assert all(set(signature.input_fields).issubset(set(c.keys())) for c in completions), "Missing input keys."
    _type = signature.fields['output_with_reasoning'].annotation

    _completions = [_type(**value) for c in completions
                    for key, value in c.items() if
                    key in signature.output_fields]

    return _completions[0]
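The retry loop in `PredictX.forward` boils down to: call the model, parse the JSON, and try again on failure. A standalone sketch of that pattern with the counter kept local to each call, so one call's failures never reduce the retries available to the next. `generate_with_retries` and the zero-argument `call` callable are hypothetical stand-ins for `structured_output_generate`, not DSPy APIs:

```python
import json
from typing import Callable, Optional


def generate_with_retries(call: Callable[[], str], max_retries: int = 3) -> dict:
    """Call `call()` until its output parses as a JSON object, trying at most `max_retries` times."""
    last_error: Optional[Exception] = None
    for attempt in range(1, max_retries + 1):  # the counter is local to this call
        try:
            data = json.loads(call())
            if not isinstance(data, dict):
                raise ValueError("expected a JSON object")
            return data
        except ValueError as e:  # json.JSONDecodeError is a subclass of ValueError
            last_error = e
            print(f"Attempt {attempt} failed; {max_retries - attempt} retries left.")
    raise RuntimeError(f"no valid JSON object after {max_retries} attempts") from last_error


# Usage with a fake model that returns malformed output once, then valid JSON.
responses = iter(["not json", '{"reasoning": "ok", "content": {}}'])
print(generate_with_retries(lambda: next(responses)))
```

Catching only `ValueError` (rather than every `Exception`, as the code above does) keeps network or authentication errors from being retried silently; whether that is preferable depends on the deployment.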

Usage

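from typing import Literal

import dspy
from pydantic import BaseModel, Field

# TypedCOTPredict is defined in the code above.

class Movie(BaseModel):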
    class Actor(BaseModel):
        name: str
        gender: Literal['male', 'female']
        role: str = Field(..., description="Role played by the actor in the movie (e.g. lead, supporting, cameo)")

    title: str
    year: int
    actors: list[Actor] = Field(..., description="List of actors in the movie")

class ExtractMovieDetails(dspy.Signature):
    """Extracts details about a movie based on the user query.

    You will be working with the following inputs:
    :prompt_inputs

    Your response should be factually accurate and relevant to the user query.
    """
    # DSPy reads field descriptions from `desc=`; a plain `description=` is not used in the prompt.
    query: str = dspy.InputField(desc="user query about the movie")
    released_after: int = dspy.InputField(desc="year after which the movie was released")

prediction = TypedCOTPredict(ExtractMovieDetails, out_type=Movie)(query="in which last movie Tom Cruise did not play the lead role? and who was the lead?", released_after=2000)

print(prediction) 

You will notice that I have not used dspy.OutputField in the signature, because my use case required me to choose the output schema on the fly.

Not my best work, but it is working for me in production with OpenAI structured output and Claude.

@okhat please feel free to point out any areas it can be improved

DISCLAIMER: I have not tested it with any optimizers yet; I plan to do so this week.

haznai commented 2 months ago

@imranarshad fascinating code! Any luck using the modules with optimizers?

NumberChiffre commented 1 month ago

@imranarshad Great work! Does this work with MIPROv2?

manubhardwaj commented 3 weeks ago

With OpenAI, it looks like DSPy/LiteLLM currently use chat.completions.create(...) under the hood, whereas OpenAI's introduction to Structured Outputs recommends the client.beta.chat.completions.parse(...) helper instead.

This is how I "forced" an OpenAI-compatible response_format into dspy 2.5, in order to get back a structured object in the response. YMMV.

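import pydantic
import dspy

def get_schema(cls: type[pydantic.BaseModel]) -> dict:
    """Build the name + schema part of a `json_schema` response_format for a Pydantic model."""
    schema = {
        'name': cls.__name__,
        'schema': cls.model_json_schema(),  # Pydantic v2; on v1 use cls.schema()
    }
    # Strict mode rejects schemas that allow extra keys (top level only here).
    schema['schema'].update({'additionalProperties': False})
    return schema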

Usage:

class SQL(pydantic.BaseModel):
    sql: str

response_schema = get_schema(SQL)
response_format = { 'type': 'json_schema', 'json_schema': {'strict': True}}
response_format['json_schema'].update(response_schema)

lm = dspy.LM(model='openai/gpt-4o', response_format=response_format, api_key=...)
dspy.configure(lm=lm)

Result:

>>> import json
>>> [json.loads(s) for s in lm('SQL to describe a table named abc. Only the SQL.')]
[{'sql': 'DESCRIBE abc;'}]
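`get_schema` only sets `additionalProperties` on the top-level object. For nested models, Pydantic v2 places the sub-schemas under `$defs`, and OpenAI's strict mode expects every object schema to set `additionalProperties: false` and list all of its properties as `required`. A minimal recursive sketch operating on a plain dict such as the one returned by `model_json_schema()`; `make_strict` is a hypothetical helper (the code earlier in this thread instead imports the OpenAI SDK's private `to_strict_json_schema`), and it handles only those two requirements, not every keyword restriction of strict mode:

```python
import copy


def make_strict(schema: dict) -> dict:
    """Return a copy of a JSON Schema where every object is closed and fully required."""
    schema = copy.deepcopy(schema)  # leave the caller's schema untouched

    def visit(node):
        if isinstance(node, dict):
            if node.get("type") == "object" and "properties" in node:
                node["additionalProperties"] = False
                node["required"] = list(node["properties"])
            for value in node.values():
                visit(value)
        elif isinstance(node, list):
            for item in node:
                visit(item)

    visit(schema)
    return schema


nested = {
    "type": "object",
    "properties": {"movie": {"$ref": "#/$defs/Movie"}},
    "$defs": {
        "Movie": {
            "type": "object",
            "properties": {"title": {"type": "string"}, "year": {"type": "integer"}},
        }
    },
}
print(make_strict(nested)["$defs"]["Movie"]["required"])
```

The result can be passed as the `schema` value in `get_schema` in place of the top-level-only update.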
shivamkhatri commented 2 weeks ago

Hello, could this please be supported?