noamgat / lm-format-enforcer

Enforce the output format (JSON Schema, Regex etc) of a language model
MIT License
1.42k stars, 65 forks

Invalid JSONs with spurious comma generated occasionally #80

Closed: pinoloricato closed this issue 7 months ago

pinoloricato commented 7 months ago

Thanks for working on this project.

I have lm-format-enforcer 0.9.0 installed. The issue in the title occurs with multiple versions of llama-cpp-python, up to and including the latest (0.2.52). The following script should reproduce it:

import json
from typing import List

from huggingface_hub import hf_hub_download
from llama_cpp import Llama, LogitsProcessorList
from lmformatenforcer import JsonSchemaParser
from lmformatenforcer.integrations.llamacpp import (
    build_llamacpp_logits_processor, build_token_enforcer_tokenizer_data)
from pydantic import BaseModel

lm = Llama(
    model_path=hf_hub_download(
        repo_id="TheBloke/tinyllama-1.1b-chat-v1.0-GGUF",
        filename="tinyllama-1.1b-chat-v1.0.Q2_K.gguf",
    ),
    verbose=False,
)
tokenizer_data = build_token_enforcer_tokenizer_data(lm)

class FlightRoute(BaseModel):
    airports: List[str]
    cost_of_flight: int

schema = FlightRoute.model_json_schema()
logits_processors = LogitsProcessorList(
    [build_llamacpp_logits_processor(tokenizer_data, JsonSchemaParser(schema))]
)

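# Sample repeatedly with an empty prompt; an invalid output shows up only occasionally.
for i in range(20):
    response = lm("", logits_processor=logits_processors, max_tokens=None)
    # A non-streaming call returns a dict; the check also narrows the type.
    if isinstance(response, dict):
        output = response["choices"][0]["text"]
        print(i, output)
        # Raises json.JSONDecodeError as soon as an output is not valid JSON.
        json.loads(output)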

The loop eventually generates an output like this:

 {   
        "airports": [
           ,
           "name",
           "city",
           "country",
           "coordinates",
           "latitude",
           "longitude"]
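
The failure can also be checked without loading a model. The sketch below is only an illustration: `bad_output` is a shortened, hypothetical copy of the output above, closed with `]}` so that the leading comma is its only defect, and `is_valid_json` is a helper name made up for this example.

```python
import json

# Hypothetical, shortened copy of the malformed output; the only defect
# is the comma directly after the opening bracket of the array.
bad_output = '{"airports": [, "name", "city"]}'

# A well-formed object matching the FlightRoute fields, for comparison.
good_output = '{"airports": ["name", "city"], "cost_of_flight": 100}'


def is_valid_json(text: str) -> bool:
    """Return True if `text` parses as JSON, False otherwise."""
    try:
        json.loads(text)
        return True
    except json.JSONDecodeError:
        return False


print("bad_output valid: ", is_valid_json(bad_output))
print("good_output valid:", is_valid_json(good_output))
```

Wrapping `json.loads` this way would let the reproduction loop count invalid outputs instead of stopping at the first one.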

noamgat commented 7 months ago

Thanks for the report! Reproduced and fixed.