noamgat / lm-format-enforcer

Enforce the output format (JSON Schema, Regex etc) of a language model
MIT License

Bug in generating objects in union of types. #53

Closed LorrinWWW closed 6 months ago

LorrinWWW commented 6 months ago

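Here is a minimal example:

# `pipe` is assumed to be an already-created transformers text-generation pipeline.
from lmformatenforcer import JsonSchemaParser
from lmformatenforcer.integrations.transformers import build_transformers_prefix_allowed_tokens_fn

schemas = {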
 'title': 'Functions',
 'type': 'array',
 'items': {'anyOf': [{'$ref': '#/definitions/apple'},
   {'$ref': '#/definitions/banana'}]},
 'definitions': {
     'apple': {
         'type': 'object',
         'properties': {'name': {'enum': ['apple'], 'type': 'string'}},
         'required': ['name']
     },
     'banana': {
         'type': 'object',
         'properties': {'name': {'enum': ['apple'], 'type': 'string'}},
         'required': ['name']
     }
 }
}

parser = JsonSchemaParser(schemas)
prefix_function = build_transformers_prefix_allowed_tokens_fn(pipe.tokenizer, parser)

prompt = """Please generate a list of fruits."""

# Call the pipeline with the prefix function
output_dict = pipe(prompt, max_new_tokens=100, prefix_allowed_tokens_fn=prefix_function)

It generates:

Please generate a list of fruits.

[
  {
    "name" 

name": "apple"

 "apple"

 },
  {
    "name" 

 name": "apple"

which is not valid JSON.
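For anyone reproducing this, a minimal sketch of checking a completion programmatically rather than by eye. The helper name `completion_is_valid_json` is hypothetical, and it assumes the pipeline echoes the prompt at the start of the generated text, as in the output above. The `broken` string is an abbreviated copy of that output.

```python
import json


def completion_is_valid_json(generated_text: str, prompt: str) -> bool:
    """Return True if the text after the prompt parses as JSON."""
    # Strip the echoed prompt (assumption: the pipeline returns prompt + completion).
    if generated_text.startswith(prompt):
        completion = generated_text[len(prompt):]
    else:
        completion = generated_text
    try:
        json.loads(completion)
    except json.JSONDecodeError:
        return False
    return True


prompt = "Please generate a list of fruits."
# Abbreviated copy of the garbled output reported above.
broken = prompt + '\n\n[\n  {\n    "name" \n\nname": "apple"\n\n "apple"\n\n },\n'
# What a schema-conforming completion could look like.
expected = prompt + '\n\n[{"name": "apple"}, {"name": "apple"}]'

print(completion_is_valid_json(broken, prompt))    # False
print(completion_is_valid_json(expected, prompt))  # True
```

Note that a completion cut off by `max_new_tokens` would also fail this check, so a failure alone does not prove the enforcer is at fault.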

LorrinWWW commented 6 months ago

@noamgat Any advice on working around this issue? Thank you so much!

LorrinWWW commented 6 months ago

This does not seem to be specific to arrays. Here is another example:

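# `pipe` is assumed to be an existing transformers text-generation pipeline.
from typing import Union

from pydantic import BaseModel
from lmformatenforcer import JsonSchemaParser
from lmformatenforcer.integrations.transformers import build_transformers_prefix_allowed_tokens_fn

class Apple(BaseModel):
    name: str
    size: int

class Banana(BaseModel):
    name: str
    number: int

class Fruit(BaseModel):
    # The union of two object types is what triggers the bug.
    fruit: Union[Apple, Banana]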

parser = JsonSchemaParser(Fruit.schema())
prefix_function = build_transformers_prefix_allowed_tokens_fn(pipe.tokenizer, parser)

# Call the pipeline with the prefix function
output_dict = pipe("Please generate a fruit:", max_new_tokens=100, prefix_allowed_tokens_fn=prefix_function)

This generates:

Please generate a fruit:

{
  "fruit": {
    "name"

 name": "apple"

 "},  "

,"number"

 size

with an error log:

ERROR:root:Unknown LMFormatEnforcer Problem. Prefix: '

{
  "fruit": {
    "name"

 name": "apple"

 "},  "

 ,"number"

 size'
Terminating the parser. Please open an issue at 
https://github.com/noamgat/lm-format-enforcer/issues with the prefix and CharacterLevelParser parameters
Traceback (most recent call last):
  File "/home/jue@together.xyz/miniconda3/envs/nebula-fav2/lib/python3.10/site-packages/lmformatenforcer/tokenenforcer.py", line 81, in _compute_allowed_tokens
    self._collect_allowed_tokens(state.parser, self.tokenizer_tree.root, allowed_tokens, shortcut_key)
  File "/home/jue@together.xyz/miniconda3/envs/nebula-fav2/lib/python3.10/site-packages/lmformatenforcer/tokenenforcer.py", line 122, in _collect_allowed_tokens
    self._collect_allowed_tokens(next_parser, next_tree_node, allowed_tokens, None)
  File "/home/jue@together.xyz/miniconda3/envs/nebula-fav2/lib/python3.10/site-packages/lmformatenforcer/tokenenforcer.py", line 120, in _collect_allowed_tokens
    next_parser = parser.add_character(character )
  File "/home/jue@together.xyz/miniconda3/envs/nebula-fav2/lib/python3.10/site-packages/lmformatenforcer/jsonschemaparser.py", line 69, in add_character
    updated_parser.object_stack[receiving_idx] = updated_parser.object_stack[receiving_idx].add_character(new_character)
  File "/home/jue@together.xyz/miniconda3/envs/nebula-fav2/lib/python3.10/site-packages/lmformatenforcer/characterlevelparser.py", line 89, in add_character
    next_parsers = [parser.add_character(new_character) for parser in relevant_parsers]
  File "/home/jue@together.xyz/miniconda3/envs/nebula-fav2/lib/python3.10/site-packages/lmformatenforcer/characterlevelparser.py", line 89, in <listcomp>
    next_parsers = [parser.add_character(new_character) for parser in relevant_parsers]
  File "/home/jue@together.xyz/miniconda3/envs/nebula-fav2/lib/python3.10/site-packages/lmformatenforcer/jsonschemaparser.py", line 270, in add_character
    value_schema = self.schema_object.properties[self.current_key]
KeyError: 'size'
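
The `KeyError` at the bottom of the traceback is consistent with a key that belongs to one union member being looked up in the other member's properties. A minimal illustration, using plain dicts as hypothetical stand-ins for the parsed schema objects (the library's real objects are not dicts, and `lookup` is a name invented here):

```python
# Hypothetical stand-ins for the `properties` of each union member.
apple_properties = {"name": {"type": "string"}, "size": {"type": "integer"}}
banana_properties = {"name": {"type": "string"}, "number": {"type": "integer"}}


def lookup(properties: dict, current_key: str) -> str:
    """Mimic the failing lookup, reporting the error instead of raising it."""
    try:
        return f"found {properties[current_key]}"
    except KeyError as err:
        return f"KeyError: {err}"


# "size" is the last key that appears in the garbled output above.
print("Apple: ", lookup(apple_properties, "size"))
print("Banana:", lookup(banana_properties, "size"))
```

The lookup succeeds for Apple and fails for Banana, which matches a `size` key being handed to a parser that only knows `name` and `number`.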

LorrinWWW commented 6 months ago

It seems that when a union character parser is updated, it appends one key string parser for each of its child parsers (https://github.com/noamgat/lm-format-enforcer/blob/1ee9d39de93e630e2cd5c041ebac3c8352853296/lmformatenforcer/jsonschemaparser.py#L279C39-L279C39), so the object_stack ends up containing multiple key string parsers. Each of those key string parsers is then triggered, which generates multiple keys and eventually corrupts the parser state.

@noamgat Do you think this is a fundamental issue?
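
To make the suspected mechanism concrete, here is a toy model of it. This is not the library's code: `ToyObjectParser`, `start_key_shared`, and `start_key_isolated` are hypothetical names, and the isolated variant is only one possible remedy, not necessarily how the library handles it.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyObjectParser:
    """Toy stand-in for one alternative of a union (not the library's class)."""
    label: str

    def start_key(self, stack: List[str]) -> None:
        # Side effect on a stack the parser does not own: push a key-string parser.
        stack.append(f"key-parser({self.label})")


def start_key_shared(alternatives: List[ToyObjectParser], stack: List[str]) -> List[str]:
    """Suspected buggy pattern: every alternative pushes onto the same stack."""
    for alt in alternatives:
        alt.start_key(stack)
    return stack


def start_key_isolated(alternatives: List[ToyObjectParser], stack: List[str]) -> List[List[str]]:
    """One possible remedy: each alternative works on its own copy of the stack."""
    branches = []
    for alt in alternatives:
        branch = list(stack)
        alt.start_key(branch)
        branches.append(branch)
    return branches


alts = [ToyObjectParser("apple"), ToyObjectParser("banana")]
shared = start_key_shared(alts, ["object"])
isolated = start_key_isolated(alts, ["object"])

print(shared)    # one stack holding two key parsers
print(isolated)  # two stacks holding one key parser each
```

In the shared case a single stack ends up with two key parsers that would both consume the next characters, which would explain the doubled `"name"` in the outputs above.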

noamgat commented 6 months ago

Thanks for the report! I was able to create a failing unit test based on it, and will investigate and fix it.

noamgat commented 6 months ago

I pushed a fix for the problem. Can you test the branch via

pip install git+https://github.com/noamgat/lm-format-enforcer.git@feature/union_typed_arrays

and check whether you get better results?

LorrinWWW commented 6 months ago

@noamgat Awesome! It works well on my side. Thank you so much for your quick response!

noamgat commented 6 months ago

Released in 0.8.2
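
For readers landing here later, a small sketch for checking whether an installed copy is at least the release named above. It assumes the distribution name is `lm-format-enforcer`; `has_union_fix` is a hypothetical helper that compares only the leading numeric parts of the version, which is a simplification.

```python
from importlib.metadata import PackageNotFoundError, version


def _leading_int(piece: str) -> int:
    """Parse the leading digits of a version component ("2rc1" -> 2)."""
    digits = ""
    for ch in piece:
        if not ch.isdigit():
            break
        digits += ch
    return int(digits) if digits else 0


def has_union_fix(version_string: str) -> bool:
    """True if the version is at least 0.8.2, the release named in this thread."""
    parts = [_leading_int(p) for p in version_string.split(".")[:3]]
    parts += [0] * (3 - len(parts))  # pad short versions such as "0.8"
    return tuple(parts) >= (0, 8, 2)


try:
    installed = version("lm-format-enforcer")
except PackageNotFoundError:
    print("lm-format-enforcer is not installed")
else:
    print(installed, "includes the fix:", has_union_fix(installed))
```

An install from a git branch reports whatever version that branch declares, so this check is only meaningful for released packages.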