Open Saibo-creator opened 4 months ago
I ran the script below and got output that is still valid JSON but looks very strange.
import torch
import numpy as np
import mlx.core as mx
from transformers import AutoTokenizer
from mlx_lm import load, stream_generate
from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor
# Load the 4-bit Phi-3.5 model through mlx_lm. The second value returned by
# load() is discarded; the Hugging Face tokenizer created below is the one
# handed to both the grammar constraint and stream_generate.
model, _ = load("mlx-community/Phi-3.5-mini-instruct-4bit")
tokenizer = AutoTokenizer.from_pretrained("mlx-community/Phi-3.5-mini-instruct-4bit")
tokenizer.pad_token = tokenizer.eos_token

# Read the JSON grammar (explicit encoding so the read does not depend on the
# platform default) and build the logits processor that enforces it, starting
# from the "root" rule.
with open("examples/grammars/json.ebnf", "r", encoding="utf-8") as f:
    grammar_str = f.read()
grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)
grammar_processor = GrammarConstrainedLogitsProcessor(grammar)
def logits_processor(input_ids: mx.array, logits: mx.array) -> mx.array:
    """Apply the grammar constraint to MLX logits via a round trip through torch.

    ``grammar_processor`` is called with torch tensors, so both MLX arrays are
    converted MLX -> NumPy -> torch (on the "mps" device), passed through the
    processor, and the result is converted back torch -> NumPy -> MLX.

    Args:
        input_ids: Token ids seen so far, as an MLX array. A leading axis is
            added before the call (NOTE(review): presumably the batch
            dimension the processor expects -- confirm).
        logits: Logits for the next token, as an MLX array.

    Returns:
        The logits returned by ``grammar_processor``, as an MLX array.
    """
    torch_input_ids = torch.tensor(np.array(input_ids[None, :]), device="mps")
    torch_logits = torch.tensor(np.array(logits), device="mps")
    torch_processed_logits = grammar_processor(torch_input_ids, torch_logits)
    # Move back to the CPU before converting to NumPy, then wrap for MLX.
    return mx.array(torch_processed_logits.cpu().numpy())
prefix1 = "This is a valid json string for http request:"
prefix2 = "This is a valid json string for shopping cart:"

# Generate one grammar-constrained completion per prompt, shopping cart first.
for prefix in [prefix2, prefix1]:
    generation_stream = stream_generate(
        model,
        tokenizer,
        prompt=prefix,
        max_tokens=500,
        repetition_penalty=1.1,
        logits_processor=logits_processor,
    )
    # Echo the prompt in green (ANSI escape codes) before streaming the output.
    print("\033[92m" + "Prompt:" + prefix + "\033[0m")
    for token in generation_stream:
        print(token, end="", flush=True)
    print()
    # Reset the processor once this prompt's stream is exhausted, so the next
    # prompt does not start from the previous generation's grammar state.
    grammar_processor.reset()
Console output:
transformers-CFG % python3 debug.py
Fetching 11 files: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 45680.54it/s]
Prompt:This is a valid json string for shopping cart:
{"items":[
{"item":"milk","quantity":2,"price":{"$type":"NumberInt", "value":3},
"item":"bread","quantity":1,"price":{"$type":"NumberInt", "value":2},
"item":"eggs","quantity":1,"price":{"$type":"NumberInt", "value":5}},
{"item":"flour","quantity":2,"price":{"$type":"NumberInt", "value":1}}]
}
Prompt:This is a valid json string for http request:
{"name":"John","age":30,"city":"New York"}
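When I pasted the shopping-cart output into an online JSON beautifier/validator, I got the result below. Note that the first object in the raw output repeats the "item", "quantity" and "price" keys three times, so the beautifier kept only the last occurrence of each (eggs):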
{
"items": [
{
"item": "eggs",
"quantity": 1,
"price": {
"$type": "NumberInt",
"value": 5
}
},
{
"item": "flour",
"quantity": 2,
"price": {
"$type": "NumberInt",
"value": 1
}
}
]
}
I think it is very strange that the model didn't generate the following JSON instead:
{"items":[
{"item":"milk","quantity":2,"price":{"$type":"NumberInt", "value":3}},
{"item":"bread","quantity":1,"price":{"$type":"NumberInt", "value":2}},
{"item":"eggs","quantity":1,"price":{"$type":"NumberInt", "value":5}},
{"item":"flour","quantity":2,"price":{"$type":"NumberInt", "value":1}}]
}
Maybe there is a bug here...
Reproduce
Error message