noamgat / lm-format-enforcer

Enforce the output format (JSON Schema, Regex etc) of a language model

Limited length JSON schema string is 30x slower than plain str #42

Closed: elonen closed this issue 6 months ago

elonen commented 6 months ago

Here's a benchmark of the same text generation with Pydantic, comparing plain str fields to constr fields:

Prompt: ### Instruction:
You are an obedient assistant.

### Input:
Write a knock-knock joke in JSON format. The joke should have a preamble, a question, a name, a who, and a punchline, in that order.

### Response:

Answer, with plain str:
{
  "preamble": "Knock knock!",
  "question": "Who's there?",
  "name": "Banana",
  "who": "Banana who?",
  "punchline": "Banana-na-na-na-na-na-na-na-na-na-na!"
}
Answer, with constr:
{
  "preamble": "Knock knock!",
  "question": "Who's there?",
  "name": "Banana",
  "who": "Banana who?",
  "punchline": "Banana-na-na-na-na-na-na-na-na-na-na!"
}
Plain: 2.49 s, Constr: 72.87 s

Relevant code:

import timeit
import pydantic
from lmformatenforcer import JsonSchemaParser

# llm, get_prompt and vllm_with_character_level_parser are set up earlier in the
# benchmark script (vLLM model loading and prompt templating).

question = r'Write a knock-knock joke in JSON format. The joke should have a preamble, a question, a name, a who, and a punchline, in that order.'
prompt = get_prompt(question)
print("Prompt:", prompt)

def plain_str():
    # Unconstrained string fields: every value is a plain JSON string.
    class AnswerFormat1(pydantic.BaseModel):
        preamble: str
        question: str
        name: str
        who: str
        punchline: str

    print("Answer, with plain str:")
    result = vllm_with_character_level_parser(llm, prompt, JsonSchemaParser(AnswerFormat1.schema()))
    print(result)

def with_constr():
    # Same schema, but every string field carries a max_length constraint.
    class AnswerFormat2(pydantic.BaseModel):
        preamble: pydantic.constr(max_length=100) # type: ignore
        question: pydantic.constr(max_length=100) # type: ignore
        name: pydantic.constr(max_length=100) # type: ignore
        who: pydantic.constr(max_length=100) # type: ignore
        punchline: pydantic.constr(max_length=100) # type: ignore

    print("Answer, with constr:")
    result = vllm_with_character_level_parser(llm, prompt, JsonSchemaParser(AnswerFormat2.schema()))
    print(result)

plain_time = timeit.timeit('plain_str()', number=1, globals=globals())
constr_time = timeit.timeit('with_constr()', number=1, globals=globals())

print(f"Plain: {plain_time:.2f} s, Constr: {constr_time:.2f} s")
noamgat commented 6 months ago

Yes, length-constrained strings cannot take advantage of the "json freetext token" shortcut, so they fall back to the full character-level mode. It would be possible to develop length-specific caching to make this faster; I'll leave this issue open so people can vote on it and we can see how much demand there is for it.
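
For reference, here's a rough sketch of what such length-specific caching could look like; the names and structure are hypothetical and not the lm-format-enforcer internals. The idea is to bucket tokens once by the length of their decoded text and memoize the allowed-token set per remaining character budget, so most generation steps reduce to a cache lookup, similar to what the freetext shortcut already does for unconstrained strings.

import string

# Rough approximation of the characters legal inside a JSON string without escaping.
JSON_STRING_CHARS = set(string.printable) - set('"\\\n\r\t\x0b\x0c')

def bucket_tokens_by_length(vocab: dict[int, str]) -> dict[int, set[int]]:
    """Group JSON-string-safe token ids by the length of their decoded text."""
    buckets: dict[int, set[int]] = {}
    for token_id, text in vocab.items():
        if text and all(c in JSON_STRING_CHARS for c in text):
            buckets.setdefault(len(text), set()).add(token_id)
    return buckets

class LengthAwareTokenCache:
    """Memoize one allowed-token set per remaining-character budget."""
    def __init__(self, vocab: dict[int, str]):
        self._buckets = bucket_tokens_by_length(vocab)  # computed once per tokenizer
        self._cache: dict[int, frozenset[int]] = {}

    def allowed(self, remaining: int) -> frozenset[int]:
        # Any token whose decoded text fits in the remaining budget is legal;
        # the union is computed once per distinct budget and then reused.
        if remaining not in self._cache:
            allowed: set[int] = set()
            for length, ids in self._buckets.items():
                if length <= remaining:
                    allowed |= ids
            self._cache[remaining] = frozenset(allowed)
        return self._cache[remaining]

Since only max_length distinct budgets can occur during a run, the expensive pure-Python scan over the vocabulary would happen a bounded number of times instead of on every generation step.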

elonen commented 6 months ago

~Hmm, actually, this seems to work:~

             # print("Filtering whitespace characters")
             allowed_characters = "".join(c for c in allowed_characters if c not in WHITESPACE_CHARACTERS)
@@ -102,11 +102,12 @@ class JsonSchemaParser(CharacterLevelParser):
             current_parser = self.object_stack[-1]
             if isinstance(current_parser, StringParsingState):
                 if not current_parser.allowed_strings and current_parser.seen_opening_quote and not current_parser.seen_closing_quote \
-                    and current_parser.min_length is None and current_parser.max_length is None:
+                    and current_parser.min_length is None:
                     # Performance optimization: When we are parsing a string that is not from a list of allowed strings, most tokens
                     # are legal. The exploration can be more costly than the LM itself for large tokenizers (because this is pure python),
                     # so we signal that we are in a "freetext" mode, and reuse the allowed token list throughout the run.
-                    return 'json_freetext'
+                    if current_parser.max_length is None or len(current_parser.parsed_string) < current_parser.max_length:
+                        return 'json_freetext'
         return None
Answer, with plain str:
{
  "preamble": "Knock knock!",
  "question": "Who's there?",
  "name": "Banana",
  "who": "Banana who?",
  "punchline": "Banana-na-na-na-na-na-na-na-na-na-na!"
}
Answer, with constr:
{
  "preamble": "Knock knock!",
  "question": "Who's there?",
  "name": "Banana",
  "who": "Banana who?",
  "punchline": "Banana-na-na-na-"     (<-- constr max 16 chars)
}
Plain: 2.50 s, Constr: 2.15 s

~Is there something that will break despite it looking ok in this test?~

EDIT: Yes, something breaks. It cuts the first line short with constr max 8 chars:

Answer, with constr:
{
  "preamble": "Knock knock

Apparently the first test only worked because the tokens that the freetext shortcut led to happened to sum up to exactly 16 characters.
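
In other words, the patched check only looks at the string length before a token is added, so a multi-character token chosen in freetext mode can blow past max_length, leaving the parser with no legal way to close the string. A toy illustration with made-up token boundaries:

max_length = 8
parsed = "Knock k"               # 7 chars, 7 < 8, so the freetext shortcut fires
next_token_text = "nock"         # the model picks a multi-character token
parsed += next_token_text        # "Knock knock", 11 chars, already past max_length
assert len(parsed) > max_length  # the constraint is violated mid-token, so the
                                 # parser cannot emit a closing quote and the
                                 # generation gets cut short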