noamgat / lm-format-enforcer

Enforce the output format (JSON Schema, Regex etc) of a language model
MIT License
1.39k stars, 61 forks

How do I start enforcing schema only after a keyword is hit during inference? #120

Open · accupham opened this issue 2 months ago

accupham commented 2 months ago

For example, during chain-of-thought reasoning it is sometimes helpful to have the model reason about the answer in free form and then output the JSON. So perhaps inference could start without enforcement, and enforcement would begin once "```\n{" is encountered.
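Here is a library-independent illustration of the proposed switch point: a sketch that detects the moment the trigger appears in a stream of generated text. `TriggerGate` and `find_enforcement_start` are hypothetical helpers, not part of LMFE. An actual integration would presumably swap parsers inside the token-filtering hook rather than inspect text after the fact.

```python
TRIGGER = "`" * 3 + "\n{"  # three backticks, a newline, then '{'


def find_enforcement_start(text: str, trigger: str = TRIGGER) -> int:
    """Index just past the first trigger occurrence, or -1 if it is absent.

    Text before the index would be generated freely; text from the index on
    would be constrained to the JSON schema.
    """
    pos = text.find(trigger)
    return -1 if pos == -1 else pos + len(trigger)


class TriggerGate:
    """Hypothetical streaming helper that flips to 'enforcing' once the trigger appears."""

    def __init__(self, trigger: str = TRIGGER):
        self.trigger = trigger
        self.buffer = ""
        self.enforcing = False

    def feed(self, chunk: str) -> bool:
        if not self.enforcing:
            self.buffer += chunk
            # Chunks may split the trigger, so search the accumulated text.
            self.enforcing = self.trigger in self.buffer
        return self.enforcing


gate = TriggerGate()
states = [gate.feed(c) for c in ["2+2 is 4.\n", "``", "`\n", "{", '"answer": 4}']]
print(states)  # [False, False, False, True, True]
```

The sketch searches the whole buffer to keep things simple. Only the last few characters can complete a new match, so a production version could check just the tail.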

noamgat commented 2 months ago

Interesting. It should be possible to chain several parsers, for example with something like this (pseudocode for LM Format Enforcer):

SequenceParser
      element 1: RegexParser(<regex for any text not containing "```\n{"> + "```\n{")    # any string that ends with your JSON prefix
      element 2: JsonSchemaParser(schema).add_character('{')    # feed '{' up front so JSON parsing starts in the state the regex already reached
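The chaining idea can be modeled in a few lines of plain Python. The classes below are a toy re-implementation written only for illustration. Their names echo LMFE's, but they are not the library's classes: a fixed-string parser stands in for `JsonSchemaParser`, and this `SequenceParser` hands characters to the next element greedily, with no handling of ambiguous handoffs. The sketch shows the two points in the pseudocode: the first element finishes exactly when the trigger has been produced, and the JSON stand-in is fed `'{'` up front because the trigger already emitted it.

```python
import string
from typing import List

# Toy, simplified stand-ins written for illustration; NOT LMFE's classes.
PRINTABLE = string.printable  # stand-in for "any character"
TRIGGER = "`" * 3 + "\n{"     # three backticks, a newline, then '{'


class UntilTriggerParser:
    """Any text, finished by the first occurrence of the trigger (the RegexParser role)."""

    def __init__(self, trigger: str, seen: str = ""):
        self.trigger, self.seen = trigger, seen

    def add_character(self, ch: str) -> "UntilTriggerParser":
        return UntilTriggerParser(self.trigger, self.seen + ch)

    def get_allowed_characters(self) -> str:
        # Once the trigger has been produced, this element accepts nothing more.
        return "" if self.can_end() else PRINTABLE

    def can_end(self) -> bool:
        return self.seen.endswith(self.trigger)


class StringParser:
    """Accepts exactly `target`; stands in for JsonSchemaParser(schema)."""

    def __init__(self, target: str, pos: int = 0):
        self.target, self.pos = target, pos

    def add_character(self, ch: str) -> "StringParser":
        if ch != self.target[self.pos : self.pos + 1]:
            raise ValueError(f"unexpected character {ch!r}")
        return StringParser(self.target, self.pos + 1)

    def get_allowed_characters(self) -> str:
        return self.target[self.pos : self.pos + 1]

    def can_end(self) -> bool:
        return self.pos == len(self.target)


class SequenceParser:
    """Runs elements in order; greedy handoff, no ambiguity handling."""

    def __init__(self, parsers: List):
        self.parsers = parsers

    def add_character(self, ch: str) -> "SequenceParser":
        head, rest = self.parsers[0], self.parsers[1:]
        if ch in head.get_allowed_characters():
            return SequenceParser([head.add_character(ch)] + rest)
        if head.can_end() and rest:
            # The head element is finished: hand the character to the next one.
            return SequenceParser(rest).add_character(ch)
        raise ValueError(f"character {ch!r} not allowed here")

    def get_allowed_characters(self) -> str:
        allowed = ""
        for parser in self.parsers:
            allowed += parser.get_allowed_characters()
            if not parser.can_end():
                break
        return allowed

    def can_end(self) -> bool:
        return all(parser.can_end() for parser in self.parsers)


# The '{' is pre-fed because the trigger already emitted it.
json_part = StringParser('{"answer": 4}').add_character("{")
parser = SequenceParser([UntilTriggerParser(TRIGGER), json_part])
print(parser.can_end())  # False: no JSON yet, so EOS would be blocked
for ch in "2+2 is 4." + TRIGGER + '"answer": 4}':
    parser = parser.add_character(ch)
print(parser.can_end())  # True
```

`can_end()` stays `False` until the JSON is complete. That is the behavior that would block the model from ending its response early.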

RegexParser, SequenceParser and JsonSchemaParser are all existing classes. I did not try this, but in theory it should work. In reality, however, I would expect better results in a multi-turn scenario:

User: Please do xxxxx. Share your chain of thought reasoning as well.
Assistant: .....
User: Based on the arguments above, output your answer in JSON in the following schema: <schema>
Assistant: <LMFE Json Schema Parser active here>

This way, the JSON output is forced to start at a specific point where it makes sense in the conversation. In the first scenario, the LLM might want to end the response and LMFE won't let it (because it hasn't output the JSON yet), which can cause hallucinations.
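This is a minimal sketch of assembling the two-turn conversation, assuming OpenAI-style `role`/`content` message dictionaries. `build_json_turn` and `ANSWER_SCHEMA` are made up for this example. Only the assistant reply generated from these messages would have LMFE's `JsonSchemaParser` active; the reasoning turn is generated without enforcement.

```python
import json
from typing import Dict, List

# Hypothetical schema for illustration; substitute your own.
ANSWER_SCHEMA = {
    "type": "object",
    "properties": {"answer": {"type": "integer"}},
    "required": ["answer"],
}


def build_json_turn(task: str, reasoning: str, schema: dict) -> List[Dict[str, str]]:
    """Messages for the follow-up turn in which JSON output is enforced.

    `reasoning` is the first assistant reply, generated without enforcement.
    Generating the next assistant reply is where a JsonSchemaParser(schema)
    would be active.
    """
    return [
        {"role": "user", "content": f"{task} Share your chain of thought reasoning as well."},
        {"role": "assistant", "content": reasoning},
        {
            "role": "user",
            "content": "Based on the arguments above, output your answer in JSON "
            f"in the following schema: {json.dumps(schema)}",
        },
    ]


messages = build_json_turn("What is 2+2?", "2 plus 2 equals 4.", ANSWER_SCHEMA)
print(messages[-1]["content"])
```

With Hugging Face `transformers`, these messages could be rendered with `tokenizer.apply_chat_template(messages, add_generation_prompt=True)` before generating the constrained reply.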

accupham commented 2 months ago

With regard to the first scenario, what if you did something like this:

SequenceParser:
    element 1:
        UnionParser:
            element 1: RegexParser
            element 2: ForceStopParser
    element 2:
        JsonSchemaParser

I think I would have to modify SequenceParser's can_end so that it can stop when any element can end, instead of requiring all of them, here: https://github.com/noamgat/lm-format-enforcer/blob/f1dd75b13183c4d6cef85d42f65a5d1e7707973b/lmformatenforcer/characterlevelparser.py#L165

Then the LLM could end the response with an EOS token or stop word before any JSON is emitted, when that suits the conversation, avoiding nasty hallucinations.
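The effect of the proposed `can_end` change can be checked with stub objects that expose only `can_end()`, evaluated at the very start of generation. This sketch rests on two assumptions: a union may end whenever any alternative may, and the sequence combines its elements with `all`, as in the line linked above. `Stub`, `Union`, and `Sequence` are hypothetical names, not LMFE classes.

```python
from typing import Callable, Iterable, List


class Stub:
    """Hypothetical stand-in for a parser; only can_end() matters here."""

    def __init__(self, name: str, ends: bool):
        self.name, self.ends = name, ends

    def can_end(self) -> bool:
        return self.ends


class Union:
    def __init__(self, options: List):
        self.options = options

    def can_end(self) -> bool:
        # Assumed: a union may stop if any alternative may stop.
        return any(o.can_end() for o in self.options)


class Sequence:
    def __init__(self, elements: List, combine: Callable[[Iterable[bool]], bool] = all):
        self.elements, self.combine = elements, combine

    def can_end(self) -> bool:
        return self.combine(e.can_end() for e in self.elements)


# State before anything is generated:
element1 = Union([Stub("RegexParser", False), Stub("ForceStopParser", True)])
element2 = Stub("JsonSchemaParser", False)  # no JSON emitted yet

print(Sequence([element1, element2], all).can_end())  # False: EOS is masked
print(Sequence([element1, element2], any).can_end())  # True: EOS is allowed
```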

I would like to open a PR that adds a new parser doing effectively this in one shot. Is this the right direction, or am I overcomplicating things?

noamgat commented 2 months ago

I'm not sure it warrants a PR; it's OK for your own code to contain classes from the LMFE hierarchy. Maybe it could be a sample. Alternatively, create a UnionParser from the two options (one with the JSON and one without).
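Here is a toy model of the `UnionParser` alternative. One branch is free text without JSON; the other is the free-text-then-JSON sequence. These are simplified stand-ins written for illustration, not LMFE's classes, and a fixed-string parser substitutes for `JsonSchemaParser`. One design choice is an assumption of this sketch: the JSON-free branch is forbidden from producing the trigger. As a result, that branch drops out as soon as the JSON prefix is emitted and cannot let the model stop halfway through the JSON.

```python
import string
from typing import List

# Toy stand-ins written for illustration; NOT LMFE's classes.
PRINTABLE = string.printable
TRIGGER = "`" * 3 + "\n{"  # three backticks, a newline, then '{'


class FreeTextParser:
    """require_trigger=True: text finished by the first trigger occurrence.
    require_trigger=False: text that never contains the trigger; may stop any time."""
    def __init__(self, trigger: str, require_trigger: bool, seen: str = ""):
        self.trigger, self.require_trigger, self.seen = trigger, require_trigger, seen

    def add_character(self, ch: str) -> "FreeTextParser":
        return FreeTextParser(self.trigger, self.require_trigger, self.seen + ch)

    def get_allowed_characters(self) -> str:
        if self.require_trigger:
            return "" if self.can_end() else PRINTABLE
        # Forbid only the character that would complete the trigger.
        if self.seen.endswith(self.trigger[:-1]):
            return PRINTABLE.replace(self.trigger[-1], "")
        return PRINTABLE

    def can_end(self) -> bool:
        return self.seen.endswith(self.trigger) if self.require_trigger else True


class StringParser:
    """Accepts exactly `target`; stands in for JsonSchemaParser(schema)."""
    def __init__(self, target: str, pos: int = 0):
        self.target, self.pos = target, pos

    def add_character(self, ch: str) -> "StringParser":
        return StringParser(self.target, self.pos + 1)  # caller checked ch is allowed

    def get_allowed_characters(self) -> str:
        return self.target[self.pos : self.pos + 1]

    def can_end(self) -> bool:
        return self.pos == len(self.target)


class SequenceParser:
    """Runs elements in order with a greedy handoff; callers pass allowed chars only."""
    def __init__(self, parsers: List):
        self.parsers = parsers

    def add_character(self, ch: str) -> "SequenceParser":
        head, rest = self.parsers[0], self.parsers[1:]
        if ch in head.get_allowed_characters():
            return SequenceParser([head.add_character(ch)] + rest)
        return SequenceParser(rest).add_character(ch)  # head finished: hand off

    def get_allowed_characters(self) -> str:
        allowed = ""
        for parser in self.parsers:
            allowed += parser.get_allowed_characters()
            if not parser.can_end():
                break
        return allowed

    def can_end(self) -> bool:
        return all(parser.can_end() for parser in self.parsers)


class UnionParser:
    """Keeps every alternative that accepts the character; may end if any may."""
    def __init__(self, parsers: List):
        self.parsers = parsers

    def add_character(self, ch: str):
        alive = [p.add_character(ch) for p in self.parsers if ch in p.get_allowed_characters()]
        if not alive:
            raise ValueError(f"character {ch!r} not allowed here")
        return alive[0] if len(alive) == 1 else UnionParser(alive)

    def can_end(self) -> bool:
        return any(p.can_end() for p in self.parsers)


def make_root() -> UnionParser:
    json_part = StringParser('{"answer": 4}').add_character("{")  # '{' came with the trigger
    return UnionParser([FreeTextParser(TRIGGER, False),
                        SequenceParser([FreeTextParser(TRIGGER, True), json_part])])


def feed(parser, text: str):
    for ch in text:
        parser = parser.add_character(ch)
    return parser


mid = feed(make_root(), "2+2 is 4." + TRIGGER)
print(make_root().can_end(), feed(make_root(), "No answer.").can_end())  # True True
print(mid.can_end(), feed(mid, '"answer": 4}').can_end())  # False True
```

The model may stop at any point before it emits the trigger. Once it does, only the JSON branch survives, and it must be completed.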