epfl-dlab / transformers-CFG

🤗 A specialized library for integrating context-free grammars (CFG) in EBNF with the Hugging Face Transformers
http://saibo-creator.xyz:7860/
MIT License
86 stars 15 forks source link

Feature Request: allow same `GrammarConstrainedLogitsProcessor` to be reused across multiple generations #49

Open Saibo-creator opened 4 months ago

Saibo-creator commented 4 months ago

Reproduce

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor

# Reproduction script: builds ONE GrammarConstrainedLogitsProcessor and passes
# that same instance to model.generate() for two different prompts in a row.
if __name__ == "__main__":

    # `device` is only reported by the print below; it is not passed to the
    # model, which is loaded with device_map="auto".
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model_id = "mistralai/Mistral-7B-v0.1"

    # Load model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Reuse the EOS token as the padding token.
    tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(model_id, load_in_4bit=True, device_map="auto")
    # Mirror the tokenizer setting on the generation config: pad id := EOS id.
    model.generation_config.pad_token_id = model.generation_config.eos_token_id

    # NOTE(review): this is a regular (non-raw) triple-quoted string, so Python
    # processes the backslash escapes it recognizes (\" and \t, \n) before the
    # grammar parser ever sees the text; \[ and \] are not recognized escapes
    # and are kept literally. Confirm this is the intended grammar text.
    grammar_str = """
    # Grammar for subset of JSON
    # String doesn't support unicode and escape yet
    # If you don't need to generate unicode and escape, you can use this grammar
    # We are working to support unicode and escape

    root   ::= object

    object ::= "{" ws ( string ":" ws value ("," ws string ":" ws value)* )? "}"

    value  ::= object | array | string | number | ("true" | "false" | "null") ws

    array  ::= "[" ws ( value ("," ws value)* )? "]" ws

    string ::= "\"" [ \t!#-\[\]-~]* "\"" ws

    number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws

    ws ::= ([ \t\n] ws)?
    """
    # Compile the grammar with "root" as the start rule, then wrap it in a
    # logits processor. Both objects are created once, outside the loop below.
    grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)
    grammar_processor = GrammarConstrainedLogitsProcessor(grammar)

    # Generate
    prefix1 = "This is a valid json string for http request:"
    prefix2 = "This is a valid json string for shopping cart:"

    for prefix in [prefix1, prefix2]:
        # Tokenize a one-element batch without adding special tokens.
        input_ids = tokenizer(
            [prefix], add_special_tokens=False, return_tensors="pt", padding=True
        )["input_ids"]

        # NOTE(review): the same grammar_processor instance is handed to every
        # generate() call and is neither reset nor re-created between
        # iterations -- presumably what triggers the reported RuntimeError on
        # reuse; confirm against the library.
        output = model.generate(
            input_ids,
            do_sample=False,
            max_new_tokens=60,
            logits_processor=[grammar_processor],
            repetition_penalty=1.1,
            num_return_sequences=1,
        )
        # decode output
        generations = tokenizer.batch_decode(output, skip_special_tokens=True)
        print(generations)

        # The bare string literal below is an expression statement, not a
        # comment; it appears to record sample output for the two prefixes.
        """
        'This is a valid json string for http request:{ "request": { "method": "GET", "headers": [], "content": "Content","type": "application" }}
        'This is a valid json string for shopping cart:This is a valid json string for shopping cart:{ "name": "MyCart", "price": 0, "value": 1 }
        """

Error message

Traceback (most recent call last):
  File "/home/saibo/Dev/SGCD-new/scripts/reproduce_tcfg_bug1.py", line 54, in <module>
    output = model.generate(
  File "/home/saibo/.virtualenvs/python3.10/SGCD-new/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/saibo/.virtualenvs/python3.10/SGCD-new/lib/python3.10/site-packages/transformers/generation/utils.py", line 1736, in generate
    result = self._sample(
  File "/home/saibo/.virtualenvs/python3.10/SGCD-new/lib/python3.10/site-packages/transformers/generation/utils.py", line 2388, in _sample
    next_token_scores = logits_processor(input_ids, next_token_logits)
  File "/home/saibo/.virtualenvs/python3.10/SGCD-new/lib/python3.10/site-packages/transformers/generation/logits_process.py", line 98, in __call__
    scores = processor(input_ids, scores)
  File "/home/saibo/.virtualenvs/python3.10/SGCD-new/lib/python3.10/site-packages/transformers_cfg/generation/logits_process.py", line 106, in __call__
    return self.process_logits(input_ids, scores)
  File "/home/saibo/.virtualenvs/python3.10/SGCD-new/lib/python3.10/site-packages/transformers_cfg/generation/logits_process.py", line 93, in process_logits
    self.batch_accept_states = self.grammar_constraint.consume_token_ids(
  File "/home/saibo/.virtualenvs/python3.10/SGCD-new/lib/python3.10/site-packages/transformers_cfg/token_grammar_recognizer.py", line 211, in consume_token_ids
    raise RuntimeError(
RuntimeError: Input ID's length is inconsistent with the current state of the GrammarConstrainedLogitsProcessor. If you want to process another input sequence, please instantiate a new GrammarConstrainedLogitsProcessor.
nathanrchn commented 4 weeks ago

Very strange. I ran this and got a very strange but still valid JSON.

import torch
import numpy as np
import mlx.core as mx
from transformers import AutoTokenizer
from mlx_lm import load, stream_generate
from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor

# Load the model through mlx_lm; the second value returned by load() is
# discarded because the tokenizer is loaded separately just below.
model, _ = load("mlx-community/Phi-3.5-mini-instruct-4bit")

# Load the tokenizer for the same checkpoint via AutoTokenizer and reuse the
# EOS token as the padding token.
tokenizer = AutoTokenizer.from_pretrained("mlx-community/Phi-3.5-mini-instruct-4bit")
tokenizer.pad_token = tokenizer.eos_token   

# Read the grammar text from a file (path is relative to the working directory).
with open("examples/grammars/json.ebnf", "r") as f:
    grammar_str = f.read()

# Compile the grammar with "root" as the start rule and wrap it in a single
# logits processor that logits_processor() and the generation loop both use.
grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)
grammar_processor = GrammarConstrainedLogitsProcessor(grammar)

def logits_processor(input_ids: mx.array, logits: mx.array) -> mx.array:
    """Apply the grammar constraint to MLX logits via a round trip through torch.

    Both arguments are copied into NumPy arrays, turned into torch tensors on
    the "mps" device, passed to the module-level ``grammar_processor``, and the
    result is converted back into an ``mx.array``.

    Args:
        input_ids: Token ids supplied by the generation loop. A leading axis is
            added with ``[None, :]`` before the call (presumably the ids are
            1-D and the processor expects a batch axis -- confirm).
        logits: Scores forwarded to the processor without reshaping.

    Returns:
        The scores returned by ``grammar_processor``, moved to CPU and wrapped
        in an ``mx.array``.
    """
    ids_with_leading_axis = np.array(input_ids[None, :])
    constrained = grammar_processor(
        torch.tensor(ids_with_leading_axis, device="mps"),
        torch.tensor(np.array(logits), device="mps"),
    )
    return mx.array(constrained.cpu().numpy())

prefix1 = "This is a valid json string for http request:"
prefix2 = "This is a valid json string for shopping cart:"

# Run one constrained generation per prompt (shopping-cart prompt first),
# reusing the single module-level grammar_processor through logits_processor.
for prefix in [prefix2, prefix1]:
    generation_stream = stream_generate(
        model,
        tokenizer,
        prompt=prefix,
        max_tokens=500,
        repetition_penalty=1.1,
        logits_processor=logits_processor
    )

    # Echo the prompt wrapped in ANSI escape sequences (\033[92m ... \033[0m).
    print("\033[92m" + "Prompt:" + prefix + "\033[0m")

    # Print each streamed item as it arrives, with no separator and an
    # immediate flush.
    for token in generation_stream:
        print(token, end="", flush=True)

    print()
    # Reset the shared processor once the prompt is finished, before the next
    # iteration reuses it (presumably this clears its per-sequence state --
    # confirm against the library).
    grammar_processor.reset()

Console output:

transformers-CFG % python3 debug.py
Fetching 11 files: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 45680.54it/s]
Prompt:This is a valid json string for shopping cart:
{"items":[
   {"item":"milk","quantity":2,"price":{"$type":"NumberInt", "value":3},
   "item":"bread","quantity":1,"price":{"$type":"NumberInt", "value":2},
   "item":"eggs","quantity":1,"price":{"$type":"NumberInt", "value":5}},
   {"item":"flour","quantity":2,"price":{"$type":"NumberInt", "value":1}}]
}
Prompt:This is a valid json string for http request:
{"name":"John","age":30,"city":"New York"}

When I used a JSON beautifier website to validate the output, I got:

{
  "items": [
    {
      "item": "eggs",
      "quantity": 1,
      "price": {
        "$type": "NumberInt",
        "value": 5
      }
    },
    {
      "item": "flour",
      "quantity": 2,
      "price": {
        "$type": "NumberInt",
        "value": 1
      }
    }
  ]
}

I think it is very strange that the model didn't generate the following JSON:

{"items":[
   {"item":"milk","quantity":2,"price":{"$type":"NumberInt", "value":3}},
   {"item":"bread","quantity":1,"price":{"$type":"NumberInt", "value":2}},
   {"item":"eggs","quantity":1,"price":{"$type":"NumberInt", "value":5}},
   {"item":"flour","quantity":2,"price":{"$type":"NumberInt", "value":1}}]
}

Maybe there is a bug here...