noamgat / lm-format-enforcer

Enforce the output format (JSON Schema, Regex, etc.) of a language model
MIT License

JsonSchemaParser allows invalid \' escapes and stops on them #41

Closed: elonen closed this issue 6 months ago

elonen commented 6 months ago

For some reason, my model sometimes tries to escape ' characters inside JSON strings, which is invalid JSON (the only escapes JSON defines are \", \\, \/, \b, \f, \n, \r, \t and \uXXXX). JsonSchemaParser allows this, but generation seems to stop immediately when it happens:

Example:

### Instruction:
You are an obedient assistant.

### Input:
How can you write "I am" shorter? Answer in a JSON variant where single quotes are escaped (\').

### Response:

Answer, With json schema enforcing:
{ "answer": "I\'
Answer, Without json schema enforcing:
{ "short": "\'I\'m\'" }

...generated with:

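from lmformatenforcer import JsonSchemaParser
import pydantic

class AnswerFormat(pydantic.BaseModel):
    answer: str

# get_prompt(), vllm_with_character_level_parser() and llm are defined elsewhere in the reporter's setup (not shown).
question = r'How can you write "I am" shorter? Answer in a JSON variant where single quotes are escaped (\').'
prompt = get_prompt(question)

print("Prompt:")
print(prompt)

# .schema() is the Pydantic v1 spelling; on Pydantic v2 the equivalent is AnswerFormat.model_json_schema().
print("Answer, With json schema enforcing:")
result = vllm_with_character_level_parser(llm, prompt, JsonSchemaParser(AnswerFormat.schema()))
print(result)

# Passing None instead of a parser runs the same generation without enforcement, for comparison.
print("Answer, Without json schema enforcing:")
result = vllm_with_character_level_parser(llm, prompt, None)
print(result)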
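Why \' is rejected: JSON (RFC 8259) only permits \", \\, \/, \b, \f, \n, \r, \t and \uXXXX after a backslash, so an apostrophe must either appear unescaped or be written as \u0027. A quick check with Python's standard json module confirms that the unconstrained answer above is not valid JSON. The helper is_strict_json is written only for this illustration; it is not part of lm-format-enforcer.

```python
import json

def is_strict_json(text: str) -> bool:
    """Return True if `text` parses with Python's standard (strict) json module."""
    try:
        json.loads(text)
    except json.JSONDecodeError:
        return False
    return True

# Raw strings keep the backslash, so these documents literally contain \' or \u0027.
print(is_strict_json(r'{ "short": "\'I\'m\'" }'))  # False: \' is not a JSON escape
print(is_strict_json(r'{ "answer": "I\u0027m" }'))  # True: \u0027 is the legal way to escape '
print(is_strict_json('{ "answer": "I\'m" }'))       # True: \' is a Python escape here; JSON sees I'm
```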
elonen commented 6 months ago

Tried to debug, with limited success. Findings based on my debug prints below:

=> Maybe this is not a bug in JsonSchemaParser after all, but somewhere even higher, like prefix matching?

(...)
? ObjectParsingState can_end: False
? JsonSchemaParser can_end: False
  -> ALL allowed: 3299 characters (parsed_string:|I|)
? StringParsingState can_end: False  (parsed_string:|I|)
### get_allowed_characters: 3299 characters
  -> ALL allowed: 3299 characters (parsed_string:|I|)
add_character: \  (parsed_string:|I|)
  -> pushing json_escaping_parser (parsed_string:|I\|)
     - allowed: nu/"r\bft
### add_character: \  (last_parsed_string:|answer|). New parser: <lmformatenforcer.jsonschemaparser.JsonSchemaParser object at 0x7f030cfd0c10>
### get_allowed_characters: nu/"r\bft

LLM RESULT: { "answer": "I\'

**Analyzer Results:**
  generated_token  generated_token_idx  generated_score leading_token  leading_token_idx  leading_score
0               {                28751          0.75782             {              28751        0.75782
1               "                  345          0.67517             "                345        0.67517
2          answer                24115          0.01738         short              10046        0.41792
3              ":                 1264          0.92942            ":               1264        0.92942
4               "                  345          0.44408             "                345        0.44408
5               I                28737          0.77531             I              28737        0.77531
6              \'                12919          0.87704            \'              12919        0.87704
7            </s>                    2          0.00000             m              28719        0.97276
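Reading the trace: once the backslash is consumed, the escape sub-parser only permits the characters nu/"r\bft, and ' is not among them. The sampled token 12919, however, is the two-character string \', so as a whole it should have been rejected. The </s> with score 0.00000 at step 7 is consistent with the parser having no legal continuation left, though that is an inference from the table rather than something the log states. The sketch below is a hypothetical stand-alone helper (first_illegal_char is not lm-format-enforcer's code) that walks a candidate token character by character through a minimal escape state machine inside an already open JSON string, using the allowed set printed in the trace. It deliberately does not model the four hex digits after \u or anything after the closing quote.

```python
from typing import Optional

# Characters the trace lists as legal right after a backslash: nu/"r\bft
ESCAPE_FOLLOWERS = frozenset('nu/"r\\bft')

def first_illegal_char(token: str) -> Optional[int]:
    """Return the index of the first character of `token` that is illegal
    inside an open JSON string, or None if no illegal character is found."""
    after_backslash = False
    for i, ch in enumerate(token):
        if after_backslash:
            if ch not in ESCAPE_FOLLOWERS:
                return i  # e.g. the ' in \'
            after_backslash = False
        elif ch == "\\":
            after_backslash = True  # the next character must start a valid escape
        elif ch == '"':
            return None  # the string closes here; the rest is outside this sketch's scope
        elif ord(ch) < 0x20:
            return i  # raw control characters must be escaped inside JSON strings
    # A trailing backslash is still legal: the next token may complete the escape.
    return None

print(first_illegal_char("\\'"))  # 1: the ' after the backslash is rejected
print(first_illegal_char("'"))    # None: a bare ' is ordinary string content
print(first_illegal_char("\\n"))  # None: \n is a legal escape
```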
noamgat commented 6 months ago

Thanks for the bug report. \' should not be allowed by the JSON schema parser; the fact that the parser allowed it at timestep 6 is a bug. This seems to be enough for a unit-test reproduction. I will investigate and fix.

noamgat commented 6 months ago

Thanks for the bug report. The problem was a bug in the performance optimization json_freetext_tokens (it's a deviation from the clean solution, but it avoids a big performance hit). It allowed tokens containing character sequences that cannot be part of legal JSON. Can you verify that installing from main

pip install git+https://github.com/noamgat/lm-format-enforcer.git@main

solves the issue?
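To make the description concrete: a "free text" shortcut like json_freetext_tokens precomputes which vocabulary tokens may be appended inside an open JSON string without running the full character-level parser on each one. The sketch below is hypothetical: the function name, the filtering rule, and the toy vocabulary (token ids and decoded text taken from the analyzer tables in this thread) are assumptions, not the library's implementation or its actual fix. It uses a conservative rule: any token containing a double quote, a backslash or a raw control character is left to the slow character-level path, so \' (12919) can never slip through the shortcut.

```python
def freetext_token_ids(vocab: dict[int, str]) -> set[int]:
    """Return ids of tokens that are always safe to append inside an open JSON string."""
    safe = set()
    for token_id, text in vocab.items():
        if not text:
            continue  # empty tokens make no progress; skip them
        if '"' in text or "\\" in text:
            continue  # quotes end the string; backslashes start escapes that need real parsing
        if any(ord(ch) < 0x20 for ch in text):
            continue  # raw control characters are illegal inside JSON strings
        safe.add(token_id)
    return safe

# Toy vocabulary: ids and decoded text as shown in the analyzer tables.
toy_vocab = {
    28737: "I",
    28742: "'",
    28719: "m",
    24115: "answer",
    12919: "\\'",
    345: '"',
}
print(sorted(freetext_token_ids(toy_vocab)))  # [24115, 28719, 28737, 28742]
```

The trade-off of such a rule is that legal escape tokens such as \n also lose the fast path; whether that matters depends on how often they occur, and a real fix could use a finer-grained check.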

elonen commented 6 months ago

Yes 👍 The leading token was \', but the enforcer overrode it correctly:

LLM RESULT: |{ "answer": "I'm" }|
**Analyzer Results:**

   generated_token  generated_token_idx  generated_score leading_token  leading_token_idx  leading_score
0                {                28751          0.75782             {              28751        0.75782
1                "                  345          0.67517             "                345        0.67517
2           answer                24115          0.01738         short              10046        0.41792
3               ":                 1264          0.92942            ":               1264        0.92942
4                "                  345          0.44408             "                345        0.44408
5                I                28737          0.77531             I              28737        0.77531
6                '                28742          0.07784            \'              12919        0.87704
7                m                28719          0.79426             m              28719        0.79426
8                "                28739          0.91152             "              28739        0.91152
9                }                  443          0.99984             }                443        0.99984
10            </s>                    2          0.94004          </s>                  2        0.94004
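Step 6 of the table shows the effect: the model's own favorite was \' (token 12919, score 0.87704), but once that token was disallowed the generator took ' (token 28742, score 0.07784), the best remaining option. The sketch below illustrates masking followed by greedy selection, assuming greedy decoding and using only the two step-6 scores from the table. A real vocabulary has far more entries, the allowed sets are hand-picked for illustration, and masked_argmax is a hypothetical helper showing the general idea of constrained decoding (disallowed scores set to -inf), not lm-format-enforcer's logits-processor code.

```python
import math

def masked_argmax(scores: dict[int, float], allowed: set[int]) -> int:
    """Pick the highest-scoring token id among those the format enforcer allows."""
    masked = {tid: (s if tid in allowed else -math.inf) for tid, s in scores.items()}
    best = max(masked, key=masked.get)
    if masked[best] == -math.inf:
        raise ValueError("no allowed token has a score")
    return best

step6 = {12919: 0.87704, 28742: 0.07784}  # \' and ' at step 6, from the table above
print(masked_argmax(step6, allowed=set(step6)))  # 12919: unconstrained choice
print(masked_argmax(step6, allowed={28742}))     # 28742: \' masked out
```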