jxnl / instructor

structured outputs for llms
https://python.useinstructor.com/
MIT License

Literal while streaming does not work #879

Closed ahuang11 closed 1 week ago

ahuang11 commented 1 month ago


Describe the bug

When streaming a model that has a Literal field, it crashes with a validation error:

ValidationError: 1 validation error for PartialUserInfo
name
  Input should be 'Bob', 'Alice' or 'John' [type=literal_error, input_value='', input_type=str]
    For further information visit https://errors.pydantic.dev/2.8/v/literal_error

To Reproduce

import asyncio
from typing import Literal

import instructor
from instructor import Partial
from openai import AsyncOpenAI
from pydantic import BaseModel

# Define your desired output structure
class UserInfo(BaseModel):
    name: Literal["Bob", "Alice", "John"]
    age: int

# Patch the OpenAI client
client = instructor.from_openai(AsyncOpenAI())

async def main():
    # Extract structured data from natural language, streaming partial objects
    user_info = await client.chat.completions.create(
        model="gpt-3.5-turbo",
        response_model=Partial[UserInfo],
        messages=[{"role": "user", "content": "John Doe is 30 years old."}],
        stream=True,
    )

    # Raises ValidationError as soon as `name` arrives as an incomplete string such as ''
    async for chunk in user_info:
        print(chunk)

if __name__ == "__main__":
    asyncio.run(main())

Expected behavior

Validation should happen once the stream is finished.

Related: https://github.com/jxnl/instructor/pull/362


jxnl commented 1 month ago

What's the ideal check? I don't think there's a good answer unless you want to add a custom before-validator.
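One possible "check", sketched under the assumption that a streamed Literal field should simply count as "not received yet" until its value matches one of the allowed choices. `drop_incomplete_literal` is a hypothetical helper, not an instructor API; it only uses `typing.get_args`, and in a real model it could be called from a pydantic `@field_validator("name", mode="before")` on the partial model.

```python
from typing import Any, Literal, get_args

Name = Literal["Bob", "Alice", "John"]


def drop_incomplete_literal(value: Any, literal_type: Any) -> Any:
    """Hypothetical before-validation step for a streamed Literal field.

    While streaming, the field may arrive as a prefix such as '' or 'Jo'.
    Returning None treats it as missing instead of failing validation.
    """
    allowed = get_args(literal_type)  # ('Bob', 'Alice', 'John')
    if isinstance(value, str) and value not in allowed:
        return None
    return value


# Simulate the values a streamed `name` field passes through.
for partial_value in ["", "Jo", "John"]:
    print(repr(partial_value), "->", repr(drop_incomplete_literal(partial_value, Name)))
```

The catch is that this also turns a genuinely invalid final value (for example "John Doe") into None instead of an error, and if one choice were a prefix of another (say "Jo" and "John"), the shorter one would be accepted early. A strict validation of the final object after the stream ends would still be needed.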

roeybc commented 1 month ago

I might be missing something, but during the JSON parsing in the PartialBase class, from_json is called with partial_mode="trailing-strings". If we change it to "on" instead, incomplete strings will be discarded. Something like this:

obj = from_json(
    (potential_object or "{}").encode(), partial_mode="on"
)

I'll be happy to tackle it if I didn't miss anything :D
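To make the difference between the two modes concrete, here is a toy, standard-library-only stand-in for partial JSON parsing. It is an assumption-laden sketch, not jiter's actual algorithm: `parse_partial_object` is a hypothetical name, it only handles flat objects, and it works by trying to close the buffer with a few suffixes. Fed the same chunks, "trailing-strings" passes through the `{'name': ''}` state seen in the traceback, while "on" never exposes an incomplete `name`.

```python
import json
from typing import Literal, get_args

Name = Literal["Bob", "Alice", "John"]


def parse_partial_object(buffer: str, mode: str, last: dict) -> dict:
    """Toy partial parser for a *flat* JSON object (a stand-in, not jiter).

    mode="trailing-strings": keep a string value that is still being streamed.
    mode="on": drop the key whose string value is still incomplete.
    If the buffer cannot be closed, fall back to the last parsed object.
    """
    # "" = already complete, "}" = object still open, '"}' = a string value still open
    for suffix in ("", "}", '"}'):
        try:
            obj = json.loads(buffer + suffix)
        except json.JSONDecodeError:
            continue
        if suffix == '"}' and mode == "on":
            obj.popitem()  # the last key is the one whose string was unfinished
        return obj
    return last


chunks = ['{"na', 'me": "', "Jo", 'hn"', ', "age": ', "30}"]

for mode in ("trailing-strings", "on"):
    buffer, state, states = "", {}, []
    for chunk in chunks:
        buffer += chunk
        state = parse_partial_object(buffer, mode, state)
        states.append(state)
    print(mode, states)

# The '' state from "trailing-strings" mode is what triggers the literal_error.
print("" in get_args(Name))  # False
```

The output shows the trade-off raised later in the thread: "on" makes Literal fields safe, but a free-text field would also stay missing until its string is closed.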

demux79 commented 2 weeks ago

@roeybc That would be very helpful, thanks. I am facing the same issue at the moment.

jxnl commented 2 weeks ago

This would be a great PR, cc @ivanleomk.

roeybc commented 2 weeks ago

Awesome, on it 😀

ivanleomk commented 2 weeks ago

Just saw the PR, @roeybc. Let me go test it out. Thanks for the submission!

ivanleomk commented 2 weeks ago

Hmm, OK, this is a good solution, but it also breaks our ability to stream partial string fields. Is this a trade-off we want to make, @jxnl?

import instructor
from pydantic import BaseModel
from openai import AsyncOpenAI
import asyncio

# Define your desired output structure
class UserInfo(BaseModel):
    name: str
    age: int

# Patch the OpenAI client
client = instructor.from_openai(AsyncOpenAI())

async def generate():
    # Extract structured data from natural language
    user_info = client.chat.completions.create_partial(
        model="gpt-3.5-turbo",
        response_model=UserInfo,
        messages=[{"role": "user", "content": "John Doe is 30 years old."}],
        stream=True,
    )

    async for chunk in user_info:
        print(chunk)

if __name__ == "__main__":
    asyncio.run(generate())

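We can see this in the output below, printed while the completion is generated: name stays None until the whole string 'John Doe' has arrived, instead of filling in gradually.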

name=None age=None
name=None age=None
name=None age=None
name=None age=None
name=None age=None
name=None age=None
name='John Doe' age=None
name='John Doe' age=None
name='John Doe' age=None
name='John Doe' age=30
name='John Doe' age=30

Any suggestions for getting around this, @roeybc?
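One possible middle ground, sketched here as an assumption rather than what the PR does: keep "trailing-strings" so free-text fields still stream, but withhold only Literal fields (including `Optional[Literal[...]]`) until their value is one of the allowed choices. `strip_incomplete_literals` and `literal_choices` are hypothetical helpers that would run on the parsed dict before `model_validate`; they only use `typing` introspection, so the demo uses a plain annotated class with a made-up `nickname` field rather than a pydantic model.

```python
import types
from typing import Any, Literal, Optional, Union, get_args, get_origin, get_type_hints

# typing.Union covers Optional[...]; types.UnionType covers `X | None` on Python 3.10+
_UNION_TYPES = (Union, getattr(types, "UnionType", Union))


def literal_choices(annotation: Any) -> Optional[tuple]:
    """Allowed values if `annotation` is Literal[...] or a union containing one, else None."""
    origin = get_origin(annotation)
    if origin is Literal:
        return get_args(annotation)
    if origin in _UNION_TYPES:
        choices = [
            choice
            for arg in get_args(annotation)
            if get_origin(arg) is Literal
            for choice in get_args(arg)
        ]
        return tuple(choices) or None
    return None


def strip_incomplete_literals(obj: dict, model: type) -> dict:
    """Drop Literal fields whose streamed value is not (yet) an allowed choice.

    Other fields, e.g. free-text str, pass through untouched, so they keep streaming.
    """
    hints = get_type_hints(model)
    cleaned = {}
    for key, value in obj.items():
        choices = literal_choices(hints.get(key))
        if choices is not None and value is not None and value not in choices:
            continue  # still a prefix such as '' or 'Jo': treat as not received yet
        cleaned[key] = value
    return cleaned


# Hypothetical model: `nickname` is a free-text field added only for illustration.
class UserInfo:
    name: Literal["Bob", "Alice", "John"]
    nickname: str
    age: int


print(strip_incomplete_literals({"name": "Jo", "nickname": "Joh"}, UserInfo))  # {'nickname': 'Joh'}
print(strip_incomplete_literals({"name": "John", "nickname": "Johnny", "age": 30}, UserInfo))
```

Like any lenient partial check, this drops a genuinely invalid final value instead of rejecting it, so the completed object would still need a strict validation against the real model once the stream ends.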

demux79 commented 2 weeks ago

Fair point. Maybe set it to "on" automatically if the model contains Literal fields? For my own use, I created a monkey patch that I apply when my model has Literal fields:

from typing import Literal
from jiter import from_json
from instructor.dsl.partial import PartialBase

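def monkey_patch_instructor_partial(
    partial_mode: Literal["off", "on", "trailing-strings"] = "trailing-strings"
):
    # Swap PartialBase's chunk parsers for versions that call jiter.from_json with `partial_mode`:
    # "on" drops incomplete trailing strings (needed for Literal fields),
    # "trailing-strings" keeps them (instructor's default, so str fields stream as they grow).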

    def patched_model_from_chunks(cls, json_chunks, **kwargs):
        potential_object = ""
        partial_model = cls.get_partial_model()
        for chunk in json_chunks:
            potential_object += chunk
            obj = from_json(
                (potential_object or "{}").encode(),
                partial_mode=partial_mode,  # value chosen when the patch is applied
            )
            obj = partial_model.model_validate(obj, strict=None, **kwargs)
            yield obj

    async def patched_model_from_chunks_async(cls, json_chunks, **kwargs):
        potential_object = ""
        partial_model = cls.get_partial_model()
        async for chunk in json_chunks:
            potential_object += chunk
            obj = from_json(
                (potential_object or "{}").encode(),
                partial_mode=partial_mode,  # value chosen when the patch is applied
            )
            obj = partial_model.model_validate(obj, strict=None, **kwargs)
            yield obj

    # Replace the original methods with the patched ones
    PartialBase.model_from_chunks = classmethod(patched_model_from_chunks)
    PartialBase.model_from_chunks_async = classmethod(patched_model_from_chunks_async)

def enable_instructor_literal_patch():
    monkey_patch_instructor_partial(partial_mode="on")

def disable_instructor_literal_patch():
    monkey_patch_instructor_partial(partial_mode="trailing-strings")
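The idea of switching to "on" automatically when the model has Literal fields could be sketched as a check that runs before patching. `has_literal` and `partial_mode_for` are hypothetical helpers; they only inspect top-level type hints (`Literal[...]` or a union such as `Optional[Literal[...]]`), so Literals nested inside sub-models or lists are not detected.

```python
import types
from typing import Any, Literal, Optional, Union, get_args, get_origin, get_type_hints

# typing.Union covers Optional[...]; types.UnionType covers `X | None` on Python 3.10+
_UNION_TYPES = (Union, getattr(types, "UnionType", Union))


def has_literal(annotation: Any) -> bool:
    """True for Literal[...] and for unions that contain a Literal."""
    origin = get_origin(annotation)
    if origin is Literal:
        return True
    if origin in _UNION_TYPES:
        return any(has_literal(arg) for arg in get_args(annotation))
    return False


def partial_mode_for(model: type) -> str:
    """Pick "on" if any top-level field is a Literal, else keep "trailing-strings"."""
    hints = get_type_hints(model)
    return "on" if any(has_literal(tp) for tp in hints.values()) else "trailing-strings"


class WithLiteral:
    name: Optional[Literal["Bob", "Alice", "John"]]
    age: int


class WithoutLiteral:
    name: str
    age: int


print(partial_mode_for(WithLiteral))     # on
print(partial_mode_for(WithoutLiteral))  # trailing-strings
```

With the monkey patch above in scope, one could call `monkey_patch_instructor_partial(partial_mode=partial_mode_for(UserInfo))`. Note that the patch replaces the methods on `PartialBase` itself, so the chosen mode applies to every streamed model until it is patched again, not just to `UserInfo`.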