noamgat / lm-format-enforcer

Enforce the output format (JSON Schema, Regex etc) of a language model
MIT License
1.42k stars, 65 forks

Support regexes inside json schema using "pattern" field #32

Closed: jpeig closed this issue 8 months ago

jpeig commented 10 months ago

I'm wondering what your plans are for the future. Are there plans to integrate JSON Schema and regex further, e.g. by supporting the "pattern" attribute inside the schema?

https://json-schema.org/understanding-json-schema/reference/string
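For reference, in JSON Schema the "pattern" keyword constrains a string value with a regular expression (ECMA-262 dialect), and the match is not implicitly anchored: the value is valid if the regex matches anywhere in it. The sketch below approximates that check with Python's re.search, which differs from ECMA-262 in some details. The schema and the matches_pattern helper are illustrative assumptions, not part of LMFE, and this thread does not say how LMFE treats anchoring.

```python
import re

# An illustrative schema using the "pattern" keyword on string properties.
schema = {
    "type": "object",
    "properties": {
        "phone": {"type": "string", "pattern": r"^\([0-9]{3}\)[0-9]{3}-[0-9]{4}$"},
        "tag": {"type": "string", "pattern": "^(bug|feature|question)$"},
    },
    "required": ["phone", "tag"],
}


def matches_pattern(value: str, pattern: str) -> bool:
    """Approximate JSON Schema "pattern" semantics: an unanchored search."""
    return re.search(pattern, value) is not None


phone_pattern = schema["properties"]["phone"]["pattern"]
tag_pattern = schema["properties"]["tag"]["pattern"]

print(matches_pattern("(818)525-3462", phone_pattern))  # True
print(matches_pattern("818.525.3462", phone_pattern))   # False
print(matches_pattern("feature", tag_pattern))          # True
# Without ^...$ anchors, a pattern only needs to match a substring:
print(matches_pattern("abc123xyz", "[0-9]+"))            # True
```

Explicit ^ and $ anchors are the usual way to require the whole value to match.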

noamgat commented 10 months ago

Hi! The core of the library is working, I plan to support more features that have user demand. Regex inside json string is a great idea! Can you edit this issue to be a feature request for it, so people can vote on it?

jpeig commented 10 months ago

I can't seem to edit it, @noamgat. I think you have to allow others to do this.

elonen commented 9 months ago

👍 This would be super useful for URL fields, selecting one of allowed tags, phone numbers etc.

(Sorry, I'm not sure how you wanted enhancement voting to happen, @noamgat; I couldn't find a dedicated button for it in the GUI.)

jloganolson commented 9 months ago

+1 for this feature. I would love to constrain the strings inside certain fields. Are there any workarounds in the meantime, beyond the most primitive one (a narrow regex step followed by a 'convert to JSON' step)?
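One way to read the primitive workaround: constrain each field separately (for example, one regex-constrained generation per value), then assemble the JSON in ordinary code so that a serializer handles quoting and escaping. In this sketch the field values are hard-coded stand-ins for model output, and assemble_contact_json is a hypothetical helper, not an LMFE function.

```python
import json
import re

PHONE_RE = re.compile(r"\([0-9]{3}\)[0-9]{3}-[0-9]{4}")


def assemble_contact_json(name: str, phone: str) -> str:
    """Validate regex-constrained values, then serialize them as JSON."""
    if PHONE_RE.fullmatch(phone) is None:
        raise ValueError(f"phone does not match pattern: {phone!r}")
    # json.dumps takes care of quoting and escaping the string values.
    return json.dumps({"name": name, "phone": phone})


# Stand-ins for values produced by separate regex-constrained generations.
print(assemble_contact_json("Frank Balone", "(324)512-8920"))
# {"name": "Frank Balone", "phone": "(324)512-8920"}
```

The cost is extra generation calls and prompt plumbing, which is what native "pattern" support inside the JSON schema avoids.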

noamgat commented 9 months ago

With the advancements in LMFE over the past few weeks, it is relatively easy to add regex value support, with the exception of escaping the " character inside a regex that applies to a JSON string field. There is also the concept of regex pattern keys (JSON Schema's patternProperties), which would be harder. I'll see if I can get something working in an hour of work. (My main current project is porting the core of LMFE to C++ for a big performance gain.)
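To illustrate the quoting issue: the regex constrains the decoded string value, but the model emits JSON-encoded text, where a " inside the value must appear as \". An enforcer therefore has to check the pattern against the value after JSON unescaping, not against the raw characters it generates. This stdlib-only sketch just contrasts the two representations; it is not how LMFE implements it.

```python
import json
import re

value = 'He said "hi"'
pattern = r'^He said "[a-z]+"$'

# The characters that actually appear in the JSON output, including the surrounding quotes.
encoded = json.dumps(value)
print(encoded)  # "He said \"hi\""

# The pattern matches the decoded value...
print(re.search(pattern, value) is not None)          # True
# ...but not the escaped text between the surrounding JSON quotes.
print(re.search(pattern, encoded[1:-1]) is not None)  # False
# Decoding the JSON string recovers the value the pattern should be checked against.
print(json.loads(encoded) == value)                   # True
```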

noamgat commented 9 months ago

I added support for it in a side branch: https://github.com/noamgat/lm-format-enforcer/tree/feature/json_string_pattern_regex

You can install it by running

pip install git+https://github.com/noamgat/lm-format-enforcer.git@feature/json_string_pattern_regex

Would you like to be beta testers and check whether it works properly?

jloganolson commented 9 months ago

Awesome! Yes, I'll try it out and report back Monday.

jloganolson commented 9 months ago

I tried it out, and I don't know whether the output failure is due to user error or the code. Here is a summary, followed by the full code you can run.

I tried the following pydantic model with and without a simple phone number regex pattern:

class ContactInfo(BaseModel):
    name: str
    phone: str = Field(pattern=r"\([0-9]{3}\)[0-9]{3}-[0-9]{4}")
    #phone: str

The prompt:

f"""<s>[INST] Could you convert the following contact information into JSON:
Frank Balone
324.512.8920

Please adhere to the following JSON schema:
{ContactInfo.schema_json()} [/INST]"""

Without the regex pattern, I get:

 {
"name": "Frank Balone",
"phone": "324.512.8920"
}

With the regex pattern, I get:

 {
"name": "Frank Balone",
"phone": "3

(No errors were raised; the output simply stops there.)
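For context on why "3 is a suspicious result: assuming the enforcer treats the pattern as a full match of the value, the first character of the phone string must be a literal (, so a correctly working enforcer should never allow 3 there. The toy sketch below models the pattern as a fixed sequence of allowed-character sets and reports which characters may come next after a given prefix. It is a hypothetical illustration of character-level enforcement, not LMFE's CharacterLevelParser API.

```python
import string

DIGITS = set(string.digits)

# Hypothetical toy model of r"\([0-9]{3}\)[0-9]{3}-[0-9]{4}" as a full match:
# one set of allowed characters per position.
PHONE_SLOTS = [{"("}] + [DIGITS] * 3 + [{")"}] + [DIGITS] * 3 + [{"-"}] + [DIGITS] * 4


def allowed_next(prefix: str, slots=PHONE_SLOTS) -> set:
    """Return the characters that may follow `prefix`, or an empty set."""
    if len(prefix) > len(slots):
        return set()
    for ch, allowed in zip(prefix, slots):
        if ch not in allowed:
            return set()  # prefix is already invalid
    if len(prefix) == len(slots):
        return {'"'}  # value complete: only the closing JSON quote may follow
    return set(slots[len(prefix)])


print(sorted(allowed_next("")))       # ['(']
print(allowed_next("3"))              # set()
print(sorted(allowed_next("(324")))   # [')']
print(allowed_next("(324)512-8920"))  # {'"'}
```

A real enforcer works on a compiled regex and on tokens rather than single characters, but the principle is the same: only continuations that keep the output a valid prefix are allowed.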

Here is a full code snippet to run (just add your model path):

import json

from IPython.display import Markdown, display
from pydantic import BaseModel, Field
from vllm import LLM, SamplingParams

from lmformatenforcer import JsonSchemaParser
from lmformatenforcer.integrations.vllm import (
    build_vllm_logits_processor,
    build_vllm_token_enforcer_tokenizer_data,
)

# Replace with your own model path (here: a Mistral-7B-Instruct AWQ checkpoint).
model_path = "path/to/mistral-7b-instruct-awq"
llm = LLM(model=model_path)
tokenizer_data = build_vllm_token_enforcer_tokenizer_data(llm)


class ContactInfo(BaseModel):
    name: str
    # phone: str  # variant without the regex constraint
    phone: str = Field(pattern=r"\([0-9]{3}\)[0-9]{3}-[0-9]{4}")


# Field(pattern=...) is Pydantic v2 syntax, so use the v2 schema API.
schema = ContactInfo.model_json_schema()
print(schema)
parser = JsonSchemaParser(schema)

# The logits processor masks tokens that would violate the JSON schema.
logits_processor = build_vllm_logits_processor(tokenizer_data, parser)
sampling_params = SamplingParams(max_tokens=1024, logits_processors=[logits_processor])

instruction = f"""Could you convert the following contact information into JSON:
Frank Balone
324.512.8920

Please adhere to the following JSON schema:
{json.dumps(schema)}"""
prompt = f"<s>[INST] {instruction} [/INST]"
results = llm.generate(prompt, sampling_params=sampling_params)

display(Markdown(f"```\n{results[0].outputs[0].text}\n```"))

noamgat commented 8 months ago

I think the problem is with the regex: I don't think you need the \ backslashes.

Here is my result (with llama-2-7b-chat, but same code except for the fixed regex):

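(screenshot: https://github.com/noamgat/lm-format-enforcer/assets/1331304/2b51be91-9b6c-4102-9bbd-190b9e68b723)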

jloganolson commented 8 months ago

If the regex is being used (assuming I'm not bungling it), the area code should be in parentheses, e.g. (818)525-3462.
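The question about the backslashes can be checked directly with Python's re module (LMFE's own regex handling may differ in details, so this is only indicative). With the backslashes, \( and \) match literal parentheses, which gives the (818)525-3462 format; without them, the parentheses form a capturing group and are not part of the matched text.

```python
import re

with_escapes = r"\([0-9]{3}\)[0-9]{3}-[0-9]{4}"
without_escapes = r"([0-9]{3})[0-9]{3}-[0-9]{4}"

print(bool(re.fullmatch(with_escapes, "(818)525-3462")))     # True
print(bool(re.fullmatch(with_escapes, "818525-3462")))       # False
print(bool(re.fullmatch(without_escapes, "(818)525-3462")))  # False
print(bool(re.fullmatch(without_escapes, "818525-3462")))    # True
```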

(Sent by email in reply to Noam Gat's comment of Mon, Jan 8, 2024, 1:59 PM.)

noamgat commented 8 months ago

There were two minor bugs that prevented it from working correctly. They are fixed in the 0.8.2 release (just released). Please try the latest version and reopen the issue if problems persist.