@noamgat Any advice on working around this issue? Thank you so much!
This should be unrelated to arrays. Here is another example:
schemas = {
    'title': 'Functions',
    'type': 'array',
    'items': {'anyOf': [{'$ref': '#/definitions/apple'},
                        {'$ref': '#/definitions/banana'}]},
    'definitions': {
        'apple': {
            'type': 'object',
            'properties': {'name': {'enum': ['apple'], 'type': 'string'}},
            'required': ['name']
        },
        'banana': {
            'type': 'object',
            'properties': {'name': {'enum': ['banana'], 'type': 'string'}},
            'required': ['name']
        }
    }
}
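(For context, a schema dict like this can be fed to the enforcer directly; a minimal sketch, assuming lm-format-enforcer is installed:)

from lmformatenforcer import JsonSchemaParser

# Build a character-level parser straight from the schema dict above
parser = JsonSchemaParser(schemas)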
Here is a minimal example:

from pydantic import BaseModel
from lmformatenforcer import JsonSchemaParser
from lmformatenforcer.integrations.transformers import build_transformers_prefix_allowed_tokens_fn

class Apple(BaseModel):
    name: str
    size: int

class Banana(BaseModel):
    name: str
    number: int

class Fruit(BaseModel):
    fruit: Apple | Banana

# pipe: a transformers text-generation pipeline (defined earlier, omitted here)
parser = JsonSchemaParser(Fruit.schema())
prefix_function = build_transformers_prefix_allowed_tokens_fn(pipe.tokenizer, parser)
# Call the pipeline with the prefix function
output_dict = pipe("Please generate a fruit:", max_new_tokens=100, prefix_allowed_tokens_fn=prefix_function)
It generates:
Please generate a fruit:
{
"fruit": {
"name"
name": "apple"
"}, "
,"number"
size
which does not look like valid JSON, with this error log:
ERROR:root:Unknown LMFormatEnforcer Problem. Prefix: '
{
"fruit": {
"name"
name": "apple"
"}, "
,"number"
size'
Terminating the parser. Please open an issue at
https://github.com/noamgat/lm-format-enforcer/issues with the prefix and CharacterLevelParser parameters
Traceback (most recent call last):
File "/home/jue@together.xyz/miniconda3/envs/nebula-fav2/lib/python3.10/site-packages/lmformatenforcer/tokenenforcer.py", line 81, in _compute_allowed_tokens
self._collect_allowed_tokens(state.parser, self.tokenizer_tree.root, allowed_tokens, shortcut_key)
File "/home/jue@together.xyz/miniconda3/envs/nebula-fav2/lib/python3.10/site-packages/lmformatenforcer/tokenenforcer.py", line 122, in _collect_allowed_tokens
self._collect_allowed_tokens(next_parser, next_tree_node, allowed_tokens, None)
File "/home/jue@together.xyz/miniconda3/envs/nebula-fav2/lib/python3.10/site-packages/lmformatenforcer/tokenenforcer.py", line 120, in _collect_allowed_tokens
next_parser = parser.add_character(character )
File "/home/jue@together.xyz/miniconda3/envs/nebula-fav2/lib/python3.10/site-packages/lmformatenforcer/jsonschemaparser.py", line 69, in add_character
updated_parser.object_stack[receiving_idx] = updated_parser.object_stack[receiving_idx].add_character(new_character)
File "/home/jue@together.xyz/miniconda3/envs/nebula-fav2/lib/python3.10/site-packages/lmformatenforcer/characterlevelparser.py", line 89, in add_character
next_parsers = [parser.add_character(new_character) for parser in relevant_parsers]
File "/home/jue@together.xyz/miniconda3/envs/nebula-fav2/lib/python3.10/site-packages/lmformatenforcer/characterlevelparser.py", line 89, in <listcomp>
next_parsers = [parser.add_character(new_character) for parser in relevant_parsers]
File "/home/jue@together.xyz/miniconda3/envs/nebula-fav2/lib/python3.10/site-packages/lmformatenforcer/jsonschemaparser.py", line 270, in add_character
value_schema = self.schema_object.properties[self.current_key]
KeyError: 'size'
It seems that when updating a union character parser, it appends a key string parser for each of its child parsers (https://github.com/noamgat/lm-format-enforcer/blob/1ee9d39de93e630e2cd5c041ebac3c8352853296/lmformatenforcer/jsonschemaparser.py#L279C39-L279C39). As a result, there are multiple key string parsers on the object_stack, each of which gets triggered, so multiple keys are generated and the parser state eventually becomes corrupted.
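A failing check along these lines reproduces the problem without a model in the loop (a rough sketch reusing the Fruit model from the example above; the test name and the exact completion string are illustrative):

from lmformatenforcer import JsonSchemaParser

def test_union_object_completion():
    # Feed a valid completion character by character; with duplicated
    # key string parsers on the object_stack, this fails partway through
    # (e.g. with the KeyError: 'size' shown above).
    parser = JsonSchemaParser(Fruit.schema())
    for char in '{"fruit": {"name": "apple", "size": 3}}':
        parser = parser.add_character(char)
    assert parser.can_end()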
@noamgat Do you think this is a fundamental issue?
Thanks for the report! I was able to create a failing unit test based on it, and will investigate and solve.
I pushed a fix for the problem. Can you test the branch via
pip install git+https://github.com/noamgat/lm-format-enforcer.git@feature/union_typed_arrays
And check if you get better results?
@noamgat Awesome! It works well on my side. Thank you so much for your quick response!
Released in 0.8.2
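For anyone landing here later, installing the released version is enough:

pip install "lm-format-enforcer>=0.8.2"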