huggingface / text-generation-inference

Large Language Model Text Generation Inference
http://hf.co/docs/text-generation-inference
Apache License 2.0

TGI does not always preserve order of grammar's JSON keys/Pydantic arguments #1956

Open MoritzLaurer opened 1 month ago

MoritzLaurer commented 1 month ago

System Info

Tests were run on dedicated Inference Endpoints with Idefics2; the TGI version was probably 2.0.2.

Reproduction

The following prompt with a grammar returns JSON with the keys in a different order than the Pydantic schema defines them. Correct ordering matters for chain-of-thought prompting: the reasoning field must be generated before the answer field for the reasoning to influence the answer.

See this internal discussion for context. The issue seems to come from serialization/deserialization steps throughout the pipeline, which don't enforce ordering.
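
For illustration only (this is not TGI's actual code), a minimal Python sketch of the failure mode: Python dicts round-trip through JSON in insertion order, but a single intermediate step that sorts keys silently reorders the fields the grammar will enforce:

import json

# ordered the way a CoT prompt needs it: reasoning first, answer second
schema_props = {"reasoning": "string", "contains_diagnosis_diabetes": "string"}

# a faithful round-trip keeps the order
assert list(json.loads(json.dumps(schema_props))) == ["reasoning", "contains_diagnosis_diabetes"]

# but one step that sorts keys (as some serializers do by default) scrambles it
scrambled = json.loads(json.dumps(schema_props, sort_keys=True))
print(list(scrambled))  # ['contains_diagnosis_diabetes', 'reasoning']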

Reproduce with an idefics2-8b-chatty model on a dedicated endpoint and a grammar:

from typing import Literal

from pydantic import BaseModel

# define the structured output you want from the model
class OutputSchema(BaseModel):
    reasoning: str
    contains_diagnosis_diabetes: Literal["Yes", "No"]  # restrict output to "Yes" or "No"

print(OutputSchema.schema())

# simplify the schema so that we can pass it directly into the prompt
def simplify_schema_string(model: type[BaseModel]) -> dict:
    return {prop: details["type"] for prop, details in model.schema()["properties"].items()}

# print the concise schema, e.g. {'reasoning': 'string', 'contains_diagnosis_diabetes': 'string'}
schema_simplified = simplify_schema_string(OutputSchema)
print(schema_simplified)
from transformers import AutoProcessor

# create the prompt and format it with the chat template
repository = "HuggingFaceM4/idefics2-8b-chatty"  # the model deployed on the endpoint
processor = AutoProcessor.from_pretrained(repository)

prompt = f"""\
Your task is to determine if the images contain a diagnosis for diabetes or not. 

Respond with the following JSON schema: 
- "reasoning": First reason step by step if the images contain a diagnosis for diabetes.  
- "contains_diagnosis_diabetes": After your reasoning, return "Yes" or "No". 

Use the following JSON schema:\n{schema_simplified}
"""

messages = [
    {
        "role": "user", "content": [
            {"type": "image"},
            {"type": "text", "text": prompt},
        ]
    },
]

prompt_with_template = processor.apply_chat_template(messages, add_generation_prompt=True)

print(prompt_with_template, "\n")
generation_params = dict(
    # add grammar / JSON schema
    grammar={
        "type": "json",
        "value": OutputSchema.schema(),
    },
    # these parameters further help guide the token generation process
    top_p=0.80,
    top_k=None,
    temperature=0.6,
    repetition_penalty=2.0, 
    do_sample=True,
    max_new_tokens=512,
    return_full_text=False,
    seed=42,
    max_time=None, 
    stream=False,
    details=False,
    use_cache=False,
    wait_for_model=False,
)
import requests
import huggingface_hub

# `endpoint` is the dedicated Inference Endpoint object created earlier
API_URL = endpoint.url
headers = {
    "Accept": "application/json",
    "Authorization": f"Bearer {huggingface_hub.get_token()}",
    "Content-Type": "application/json",
}

def query(payload):
    response = requests.post(API_URL, headers=headers, json=payload)
    return response.json()
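
The loop below iterates over prompts_with_images, which is built outside this snippet. As a purely hypothetical sketch of one way to construct it, assuming the rendered chat template contains an <image> placeholder and that the endpoint accepts inline markdown image URLs (the URLs here are placeholders):

# hypothetical: one prompt per image, swapping the template's image
# placeholder for an inline markdown image URL
image_urls = [
    "https://example.com/medical_doc_1.png",
    "https://example.com/medical_doc_2.png",
]
prompts_with_images = [
    prompt_with_template.replace("<image>", f"![]({url})") for url in image_urls
]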

output_lst = []
for prompt in prompts_with_images:
    output = query({
        "inputs": prompt,
        "parameters": {
            **generation_params
        }
    })
    output_lst.append(output)

The output returns contains_diagnosis_diabetes first and the reasoning second, which defeats the purpose of chain-of-thought.
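
That is, the endpoint returns output shaped like this (illustrative, not a verbatim response):

[{'generated_text': '{"contains_diagnosis_diabetes": "No", "reasoning": "..."}'}]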

Expected behavior

The grammar should enforce the exact ordering of the JSON keys/Pydantic arguments.

@drbh explained: "Regarding TGI, we do have a couple of serialization/deserialization steps before the value is converted to an FSM, so it's likely that the order is not preserved in one of those steps. JSON by design doesn't guarantee ordering, but Python dictionaries do preserve order, so if you avoid serializing to JSON the ordering is preserved. Unfortunately, we rely on sending the grammar as JSON over HTTP and internally over gRPC, therefore ordering is not guaranteed. In order to ensure the ordering we'd need to capture the output regex that to_regex produces and send that as the grammar, or some other regex grammar."

This issue is not urgent, but a solution would be valuable in the medium term.

MoritzLaurer commented 1 month ago

Did some more testing and can confirm that converting the JSON schema to a regex before passing it to TGI seems to solve the ordering issue:

# convert the JSON schema to a regex; the regex spells out the keys in schema
# order, so the ordering survives JSON serialization over HTTP/gRPC
import json
from outlines.fsm.json_schema import build_regex_from_schema

schema_string = json.dumps(OutputSchema.schema(), indent=2)
schema_regex = build_regex_from_schema(schema_string)
generation_params = dict(
    # pass the grammar as a regex instead of a JSON schema
    grammar={
        "type": "regex",        # was: "json"
        "value": schema_regex,  # was: OutputSchema.schema()
    },
    },
    # these parameters further help guide the token generation process
    top_p=0.80,
    top_k=None,
    temperature=0.6,
    repetition_penalty=2.0,  # repetition penalty is helpful to avoid that the model gets stuck in generating the same token
    do_sample=True,
    max_new_tokens=512,
    return_full_text=False,
    seed=42,
    max_time=None, 
    stream=False,
    details=False,
    use_cache=False,
    wait_for_model=False,
)
... 
# endpoint API produces output in correct ordering
[{'generated_text': '{"reasoning":"In order to confirm whether there\'s any indication of diabetic conditions in these two documents from KENT BROWN & WILLIAMSON TOBACCO CORP., I would need more specific information about their content and context.",   \n        "contains_diagnosis_diabetes":"No"}'}]
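
A quick sanity check (a sketch, assuming the response structure above) is to parse the output and assert the key order:

import json

parsed = json.loads(output_lst[0][0]["generated_text"])
# with the regex grammar, the keys come back in schema order
assert list(parsed) == ["reasoning", "contains_diagnosis_diabetes"]
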
Jacobsolawetz commented 1 month ago

Nice stuff @MoritzLaurer! Hope you're doing great man!

github-actions[bot] commented 5 days ago

This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.