stanfordnlp / dspy

DSPy: The framework for programming—not prompting—foundation models
https://dspy-docs.vercel.app/
MIT License

[Feature Request] Better Typed Predictors for Reliable JSON Generation #1001

Open StevenGuo42 opened 2 months ago

StevenGuo42 commented 2 months ago

Typed Predictors output JSON as plain text. They do not use the function-calling method/template, and they do not enforce that the output is valid JSON. This makes generating structured output challenging (see my previous issue #957; other related issues: #1024).

Some possible solutions:

AriMKatz commented 2 months ago

For the last one: sglang or outlines. I don't think Instructor does deterministic generation.

Also, maybe some way to add hints/few-shot examples if generation fails (this really helps Haiku).

StevenGuo42 commented 2 months ago

@AriMKatz Thanks for the suggestion. I added both to the issue.

juicesharp commented 1 month ago

About half of the responses that failed validation actually contain a valid JSON object, but with a prefix like "Here is my response: ```json VALID_JSON" or "Written doc formatted as Json object {VALID JSON}". A simple way to inject or customize the post-processing of a Prediction inside the pipeline, for example with your own clean-up function, would drastically mitigate the problem.
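A clean-up function of the kind described above could be sketched with the standard library alone. The name `extract_first_json_object` is hypothetical (it is not part of DSPy), and the sketch assumes the wanted payload is the first complete JSON object in the text; it skips any prose or code-fence prefix by trying `json.JSONDecoder.raw_decode` at each `{`.

```python
import json
from typing import Any, Optional

def extract_first_json_object(text: str) -> Optional[dict[str, Any]]:
    """Hypothetical clean-up helper: return the first complete JSON object
    embedded in `text`, ignoring any prose or code fence around it."""
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            # raw_decode parses one JSON value beginning at `start` and
            # reports where it ended, so trailing text is simply ignored.
            obj, _end = decoder.raw_decode(text, start)
            return obj
        except json.JSONDecodeError:
            start = text.find("{", start + 1)  # try the next opening brace
    return None


reply = 'Here is my response: ```json\n{"title": "Brief", "pages": 2}\n```'
print(extract_first_json_object(reply))  # {'title': 'Brief', 'pages': 2}
```

Returning the first object is a simplifying assumption; a completion containing several JSON objects would need a more careful selection rule.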

One possible approach would be to let subclasses of dspy.Module override or extend part of the post-processing pipeline, for example if Module had a method like def _post_prediction_process(self, pred: Prediction) -> Prediction:

class Module(BaseModule, metaclass=ProgramMeta):
    ...
    # first change: a no-op hook that subclasses (or instances) can override
    def _post_prediction_process(self, prediction):
        return prediction

class TypedPredictor(dspy.Module):
    ...
    def forward(self, **kwargs) -> dspy.Prediction:
        modified_kwargs = kwargs.copy()
        signature = self._prepare_signature()
        for try_i in range(self.max_retries):
            result = self.predictor(**modified_kwargs, new_signature=signature)
            # second change: let the hook clean up the raw prediction before it is validated
            result = self._post_prediction_process(result)
            ...
        ...

class WriteBrief(dspy.Module):
    def __init__(self):
        super().__init__()
        self.tp = dspy.TypedPredictor(WriteBriefSignature)
        # inject custom _post_prediction_process into the process pipeline
        self.tp._post_prediction_process = self._post_prediction_process

    def forward(self, requirements):
        return self.tp(requirements=requirements)

    def _post_prediction_process(self, prediction):
        # analyze and clean up prediction
        ...
        return prediction

Even beyond this issue, such a hook makes a lot of sense.
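Independently of DSPy's internals, the idea of the hook can be illustrated with a plain-Python retry loop. Everything here is hypothetical: `generate` stands in for the LM call, `parse` for the typed validation step, and `post_process` for the user-supplied clean-up hook, which runs on every raw completion before it is validated.

```python
import json
from typing import Any, Callable

def predict_with_hook(
    generate: Callable[[int], str],
    parse: Callable[[str], Any],
    post_process: Callable[[str], str] = lambda text: text,
    max_retries: int = 3,
) -> Any:
    """Hypothetical retry loop: each raw completion goes through the
    post_process hook before parse() validates it."""
    errors: list[str] = []
    for attempt in range(max_retries):
        raw = generate(attempt)
        cleaned = post_process(raw)  # the hook runs before validation
        try:
            return parse(cleaned)
        except ValueError as exc:  # json.JSONDecodeError is a ValueError
            errors.append(f"attempt {attempt}: {exc}")
    raise ValueError("all attempts failed:\n" + "\n".join(errors))


def strip_prefix(text: str) -> str:
    """Example hook: keep only the span from the first '{' to the last '}'."""
    start, end = text.find("{"), text.rfind("}")
    return text[start:end + 1] if start != -1 and end > start else text


calls: list[int] = []

def fake_lm(attempt: int) -> str:
    # Stand-in for the LM: always answers with a chatty prefix and suffix.
    calls.append(attempt)
    return 'Here is my response: {"summary": "ok"} Hope that helps!'

print(predict_with_hook(fake_lm, json.loads, strip_prefix))  # {'summary': 'ok'}
print(calls)  # [0]: the hook rescued the first completion, so no retry was spent
```

Without the hook (`predict_with_hook(fake_lm, json.loads)`), the same fake model exhausts all three attempts and the loop raises.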

juicesharp commented 1 month ago

Another idea: after the object fails to deserialize, and before making a "rephrase" attempt, DSPy could simply check the failed response with a regex, something like:

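from re import search

# Both helpers are excerpts of methods (hence `self`). The pattern matches a
# fenced block with an optional language tag and captures its body.
def _is_valid_json_block(self, code):
    pattern = r"```[\w\s]*\n([\s\S]*?)\n```"
    # Only checks that a fenced block exists, not that it holds valid JSON
    return search(pattern, code) is not None

def _extract_content_within_backticks(self, code):
    pattern = r"```[\w\s]*\n([\s\S]*?)\n```"
    match = search(pattern, code)
    # Return the content within the backticks, or None if there is no fenced block
    return match.group(1) if match else None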

Then, if the result is not None, validate it again against the Pydantic model. The fairly high chance of success, with no extra LLM call, should justify making this the default strategy.
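The flow described above (validate; on failure, extract the fenced block and validate once more before spending another LLM call) might look like the following sketch. `validate_or_recover` is a hypothetical name; `validate` stands in for the Pydantic check (here `json.loads`, so the example needs only the standard library), and returning `None` signals that the caller should fall back to the usual rephrase/retry.

```python
import json
import re
from typing import Any, Callable, Optional

# Same pattern as above: a fenced block with an optional language tag.
FENCED_BLOCK = re.compile(r"```[\w\s]*\n([\s\S]*?)\n```")

def validate_or_recover(text: str, validate: Callable[[str], Any]) -> Optional[Any]:
    """Hypothetical sketch: validate `text`; if that fails, retry on the body
    of its first fenced block. None means 'ask the LLM again'."""
    try:
        return validate(text)  # happy path: the completion is already clean
    except ValueError:  # Pydantic v2's ValidationError is also a ValueError
        pass
    match = FENCED_BLOCK.search(text)
    if match is None:
        return None  # nothing to recover locally
    try:
        return validate(match.group(1))  # second chance, no extra LLM call
    except ValueError:
        return None


print(validate_or_recover('Here is my response:\n```json\n{"a": 1}\n```', json.loads))  # {'a': 1}
```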

mikeedjones commented 1 month ago

As opposed to looking for backticks, maybe a more flexible approach would be to look for the longest valid JSON string? Any of these could also be implemented in TypedPredictor by overriding the model_validate_json class method, as that looks like what gets used to parse the model output.

Something like the below:

from typing import Type, Any
from pydantic import BaseModel, ValidationError
import json

# no point parsing a string which isn't longer than the total length of all the
# keys in the model
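def get_min_length(model: Type[BaseModel]) -> int:
    min_length = 0
    for key, field in model.model_fields.items():
        try:
            is_model = issubclass(field.annotation, BaseModel)
        except TypeError:  # non-class annotations such as Optional[str] or list[int]
            is_model = False
        if is_model:
            min_length += get_min_length(field.annotation)
        min_length += len(key)
    return min_length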

class Address(BaseModel):
    street: str
    city: str
    state: str
    zip: str

class Company(BaseModel):
    name: str
    address: Address

class User(BaseModel):
    name: str
    email: str
    company: Company
    home_address: Address

    @classmethod
    def model_validate_json(
        cls,
        json_data: str,
        *,
        strict: bool | None = None,
        context: dict[str, Any] | None = None
    ) -> "User":
        __tracebackhide__ = True
        try:
            return cls.__pydantic_validator__.validate_json(json_data, strict=strict, context=context)
        except ValidationError:
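            # Fall back to a brute-force search for the longest substring that
            # validates (O(n^2) candidate substrings, each parsed from scratch).
            min_length = get_min_length(cls)
            for substring_length in range(len(json_data), min_length - 1, -1):
                for start in range(len(json_data) - substring_length + 1):
                    substring = json_data[start:start + substring_length]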
                    try:
                        res = cls.__pydantic_validator__.validate_json(substring, strict=strict, context=context)
                        return res
                    except ValidationError:
                        pass
        raise ValueError("Could not find valid json")

dummy_data = {
    "name": "John Doe",
    "email": "johndoe@example.com",
    "company": {
        "name": "Example Corp",
        "address": {
            "street": "123 Example St",
            "city": "Example City",
            "state": "EX",
            "zip": "12345"
        }
    },
    "home_address": {
        "street": "456 Home St",
        "city": "Home City",
        "state": "HM",
        "zip": "67890"
    }
}

json_string  = f"""
No problem, here is the json you requested:
{json.dumps(dummy_data, indent=4)}
Hope that helps!
"""

User.model_validate_json(json_string)
juicesharp commented 1 month ago

Sure, I just gave an example of a strategy that could pay off if it were implemented in DSPy: a small feature with a decent amount of value.

P.S. Unfortunately, the more I look into the DSPy codebase, the more obvious it becomes that the people who originally wrote DSPy never expected anyone to want to modify, extend, or alter the framework's behavior.

mikeedjones commented 1 month ago

Yes, I actually think this is one of the few cases, among the many open "I want to tweak this behavior" issues, where you can inject custom code into the guts of DSPy, by overriding the model_validate_json method.

BlueKiji77 commented 1 month ago

I tried this, expecting it to work; I am not very experienced at programming. I have been toying with the idea of copying the typed predictor classes into my code and modifying them that way, but the internals of DSPy are hard for me to wrap my head around.

Wouldn't having the option to pass a JSON output parser to the predictor be a good solution? The parser would parse the LLM's generation and check it for Output conformity before each retry.

StevenGuo42 commented 1 month ago

The problem with manipulating the generated text is that it may contain more than one JSON object, only one of which is the actual result we are looking for, and the longest one is not necessarily the one we want. Example:

Output: {"mapping":{"apple":"fruit","banana":"fruit","tomato":"vegetable","cabbage":"vegetable","human":null}}. Respond with a single JSON object. JSON Schema: {"properties": {"mapping": {"description": "The mapping from the first variable to the second variable", "title": "Mapping", "type": "object"}}, "required": ["mapping"], "title": "Output", "type": "object"}

Also, I believe the better approach is to use the models as intended, with the specific templates they were trained on, and to generate valid results on the first try using structured-output libraries.
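The example above can be checked mechanically. The sketch below (standard library only; `json_object_candidates` and `looks_like_output` are hypothetical helpers) collects every complete JSON object in the completion by trying `json.JSONDecoder.raw_decode` at each `{`, keeps those that pass a simple stand-in for the Output model (a `mapping` key holding an object), and compares a "longest" rule with a "first" rule. Two candidates pass: the real answer and a fragment of the echoed JSON Schema, and the fragment is the longer of the two.

```python
import json

# The completion quoted above, verbatim.
completion = (
    'Output: {"mapping":{"apple":"fruit","banana":"fruit","tomato":"vegetable",'
    '"cabbage":"vegetable","human":null}}. Respond with a single JSON object. '
    'JSON Schema: {"properties": {"mapping": {"description": "The mapping from '
    'the first variable to the second variable", "title": "Mapping", "type": '
    '"object"}}, "required": ["mapping"], "title": "Output", "type": "object"}'
)

def json_object_candidates(text: str) -> list[tuple[int, int, dict]]:
    """Hypothetical helper: every complete JSON object in `text`, nested ones
    included, as (start, end, object) triples in order of appearance."""
    decoder = json.JSONDecoder()
    found = []
    for start, char in enumerate(text):
        if char != "{":
            continue
        try:
            obj, end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            continue
        found.append((start, end, obj))
    return found

def looks_like_output(obj: dict) -> bool:
    # Stand-in for validating against the Output model: "mapping" must be an object.
    return isinstance(obj.get("mapping"), dict)

valid = [c for c in json_object_candidates(completion) if looks_like_output(c[2])]
first = valid[0][2]
longest = max(valid, key=lambda c: c[1] - c[0])[2]

print(len(valid))                  # 2
print(first["mapping"]["apple"])   # fruit
print(longest["mapping"]["type"])  # object (a fragment of the echoed schema)
```

Here the first valid object happens to be the answer, but only because the model put its answer before the echoed schema; neither rule is guaranteed to pick the intended object in general.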

juicesharp commented 1 month ago

This changes nothing. You can extract all of them and validate each piece against the corresponding Pydantic schema; that is still better than forcing an additional LLM call. Even if this approach worked only for cases with a single Output field, it would cover 99% of the cases, with a success rate of around 50% or more.

P.S. We would not be discussing this at all if the framework provided a clear way to inject a custom "transformer" function like the post_process_prediction(prediction: Prediction) -> Prediction I described above, which would let you transform a prediction before it is rejected by the built-in validation or sent on another round of begging the LLM to format its output better.

mikeedjones commented 1 month ago

It looks like @thomasahle, in https://github.com/stanfordnlp/dspy/pull/451, hardcoded the function below into TypedPredictor; it cleans up the completion before passing it to Pydantic's model_validate_json. https://github.com/stanfordnlp/dspy/blob/9c1fff94bb48b55a2ddaab5ce8c23f7ad3af61e3/dspy/functional/functional.py#L262

import ujson

def _unwrap_json(output):
    output = output.strip()
    if output.startswith("```"):
        if not output.startswith("```json"):
            raise ValueError("json output should start with ```json")
        if not output.endswith("```"):
            raise ValueError("Don't write anything after the final json ```")
        output = output[7:-3].strip()
    if not output.startswith("{") or not output.endswith("}"):
        raise ValueError("json output should start and end with { and }")
    return ujson.dumps(ujson.loads(output))  # ujson is a bit more robust than the standard json

The errors raised in that function stop you from doing anything smart in an overridden model_validate_json, because they fire before it is ever called.

One change that could work: pass model_validate_json to _unwrap_json and check whether the unaltered output can be parsed before attempting to clean up the completion. Basically:

import ujson
from pydantic import ValidationError

def _unwrap_json(output, native_parser):
    try:
        # Let the (possibly overridden) model_validate_json try the raw completion first.
        native_parser(output)
        # Return the string unchanged, like the fallback path below, so the
        # caller can pass it to model_validate_json as before.
        return output
    except (ValueError, ValidationError):
        output = output.strip()
        if output.startswith("```"):
            if not output.startswith("```json"):
                raise ValueError("json output should start with ```json")
            if not output.endswith("```"):
                raise ValueError("Don't write anything after the final json ```")
            output = output[7:-3].strip()
        if not output.startswith("{") or not output.endswith("}"):
            raise ValueError("json output should start and end with { and }")
        return ujson.dumps(ujson.loads(output))  # ujson is a bit more robust than the standard json
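The effect of the proposed ordering can be demonstrated outside DSPy. The sketch below ports both versions to the standard `json` module (instead of `ujson`) under hypothetical names, and uses a lenient `lenient_parser` that tolerates a chatty prefix, standing in for an overridden `model_validate_json`. The strict checks reject the completion before any custom parser can run; trying the native parser first lets it through.

```python
import json
from typing import Any, Callable

def unwrap_strict(output: str) -> str:
    # Stdlib port of the current checks: reject anything not shaped like bare JSON.
    output = output.strip()
    if output.startswith("```"):
        if not output.startswith("```json"):
            raise ValueError("json output should start with ```json")
        if not output.endswith("```"):
            raise ValueError("Don't write anything after the final json ```")
        output = output[7:-3].strip()
    if not output.startswith("{") or not output.endswith("}"):
        raise ValueError("json output should start and end with { and }")
    return json.dumps(json.loads(output))

def unwrap_native_first(output: str, native_parser: Callable[[str], Any]) -> str:
    # Proposed order: give the (possibly overridden) parser the raw text first.
    try:
        native_parser(output)
        return output
    except ValueError:
        return unwrap_strict(output)

def lenient_parser(text: str) -> dict:
    # Hypothetical stand-in for an overridden model_validate_json:
    # parse the JSON object that starts at the first "{".
    start = text.find("{")
    if start == -1:
        raise ValueError("no JSON object found")
    obj, _end = json.JSONDecoder().raw_decode(text, start)
    return obj


chatty = 'Sure! Here is the JSON: {"name": "Ada"}'
try:
    unwrap_strict(chatty)
except ValueError as exc:
    print(exc)  # json output should start and end with { and }
print(lenient_parser(unwrap_native_first(chatty, lenient_parser)))  # {'name': 'Ada'}
```

Returning the unchanged string keeps the return type the same as the fallback path, at the cost of parsing the completion twice on the happy path.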
jamesschinnerplxs commented 1 month ago

A failure mode I encountered is JSON truncated by a token limit. In some cases it may still be useful to try to parse it instead of failing completely. Being able to specify a custom parser (which could also be useful for other serialization formats) would let the user try to handle partial JSON results.
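As a rough illustration of what such a custom parser might do with truncated output, the sketch below (standard library only; `close_truncated_json` and `parse_truncated_json` are hypothetical names, not DSPy functions) closes any open string, object, and array; if that is still not valid JSON, it cuts back to successively earlier commas and retries. It is a best-effort heuristic: the result can silently lack fields, so it should still be validated against the expected schema.

```python
import json
from typing import Any, Optional

def close_truncated_json(text: str) -> str:
    """Append the quote and brackets needed to close whatever `text` left open."""
    closers: list[str] = []
    in_string = False
    escaped = False
    for char in text:
        if in_string:
            if escaped:
                escaped = False
            elif char == "\\":
                escaped = True
            elif char == '"':
                in_string = False
        elif char == '"':
            in_string = True
        elif char == "{":
            closers.append("}")
        elif char == "[":
            closers.append("]")
        elif char in "}]" and closers:
            closers.pop()
    if in_string and escaped:
        text = text[:-1]  # drop a dangling backslash so the closing quote is not escaped
    return text + ('"' if in_string else "") + "".join(reversed(closers))


def parse_truncated_json(text: str) -> Optional[Any]:
    """Parse `text`, repairing a truncated tail if needed; None if hopeless."""
    # Try the full text first, then cut back to each earlier comma so that a
    # half-written key or value (e.g. '"age": ') is dropped.
    cut_points = [len(text)] + [i for i in range(len(text) - 1, -1, -1) if text[i] == ","]
    for cut in cut_points:
        try:
            return json.loads(close_truncated_json(text[:cut]))
        except json.JSONDecodeError:
            continue
    return None


print(parse_truncated_json('{"name": "Ada", "tags": ["x", "y"'))  # {'name': 'Ada', 'tags': ['x', 'y']}
print(parse_truncated_json('{"name": "Ada", "age": '))            # {'name': 'Ada'}
```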