epfl-dlab / transformers-CFG

🤗 A specialized library for integrating context-free grammars (CFG) in EBNF with the Hugging Face Transformers library
http://saibo-creator.xyz:7860/
MIT License

Generation terminates on escape sequences with llama-3 models #63

Open MatthewChang opened 4 days ago

MatthewChang commented 4 days ago

First of all, thanks for making and maintaining such a useful repo!

One issue I've found: with llama-3 models (specifically llama3-8b), generation stops as soon as it reaches a newline character "\n" or another escape sequence in the grammar.

When I run this snippet:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor

if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # model_id = "mistralai/Mistral-7B-v0.1"
    model_id = "meta-llama/Meta-Llama-3-8B"

    # Load model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(model_id).to(
        device
    )  # Load model to defined device
    model.generation_config.pad_token_id = model.generation_config.eos_token_id

    # Load grammar
    with open("examples/grammars/json_arr.ebnf", "r") as file:
        grammar_str = file.read()
    grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)
    grammar_processor = GrammarConstrainedLogitsProcessor(grammar)

    # Generate
    prefix1 = "This is a valid json array for student records:"
    prefix2 = "This is a valid json array for shopping cart:"
    input_ids = tokenizer(
        [prefix1, prefix2], add_special_tokens=False, return_tensors="pt", padding=True
    )["input_ids"].to(
        device
    )  # Move input_ids to the same device as model

    output = model.generate(
        input_ids,
        do_sample=False,
        max_new_tokens=60,
        logits_processor=[grammar_processor],
        repetition_penalty=1.0,
        num_return_sequences=1,
    )
    # decode output
    generations = tokenizer.batch_decode(output, skip_special_tokens=True)
    print(generations)

This is just examples/generate_json_array.py with Mistral-7B swapped out for Llama-3-8B.

The generations I get are ['This is a valid json array for student records:[\n', 'This is a valid json array for shopping cart:[\n']

These outputs are incomplete arrays and should not be accepted by the grammar. I can also reproduce this with a very simple grammar: grammar_str = 'root ::= "first\\nsecond"'

This will generate "first\n".

Similarly, replacing \n with \t or \r makes generation return "first\t" and "first\r", respectively. This does not happen with Mistral-7B or Llama-2-7B.
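For context on the grammar side: in the Python literal 'root ::= "first\\nsecond"', the EBNF text contains a backslash followed by "n", and the grammar parser has to turn that two-character escape into the single byte 0x0A before it can be matched against token bytes. The sketch below shows that translation for the escapes mentioned in this report. The helper unescape_literal is hypothetical and is not the transformers-CFG parser; it only illustrates which bytes the grammar should end up expecting.

```python
# Hypothetical helper, not part of transformers-CFG: translate the escape
# sequences used in this report into the raw bytes a grammar must match.
ESCAPES = {"n": b"\n", "t": b"\t", "r": b"\r", "\\": b"\\", '"': b'"'}


def unescape_literal(body: str) -> bytes:
    """Convert the inside of an EBNF string literal (without quotes) to bytes."""
    out = bytearray()
    i = 0
    while i < len(body):
        ch = body[i]
        if ch == "\\" and i + 1 < len(body) and body[i + 1] in ESCAPES:
            out += ESCAPES[body[i + 1]]  # two source chars -> one byte
            i += 2
        else:
            out += ch.encode("utf-8")
            i += 1
    return bytes(out)


grammar_str = 'root ::= "first\\nsecond"'  # same grammar as in the report
literal = grammar_str.split('"')[1]          # first\nsecond with a real backslash
print(len(literal))               # 13: the escape is still two characters here
print(unescape_literal(literal))  # b'first\nsecond': 12 bytes, one newline byte
```

If the grammar side produces b"first\nsecond" correctly, the mismatch is more likely on the tokenizer side, which is where the Mistral/Llama-2 versus Llama-3 difference would come in.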

I can reproduce this on main in a clean conda environment. Package versions after installing the requirements are listed below. Thanks for your help!

appdirs==1.4.4
black==21.4b2
certifi==2024.7.4
cffi==1.15.1
cfgv==3.4.0
charset-normalizer==3.3.2
contourpy==1.1.0
cycler==0.11.0
Cython==0.29.36
distlib==0.3.8
easydict==1.10
filelock==3.15.4
fonttools==4.41.0
fsspec==2024.6.1
future==0.18.3
huggingface-hub==0.23.4
identify==2.5.36
idna==3.7
imageio==2.31.1
Jinja2==3.1.4
jsonpointer==2.4
kiwisolver==1.4.4
lazy_loader==0.3
line_profiler==4.1.3
lvis @ git+https://github.com/lvis-dataset/lvis-api.git@da5f65d16237637d848a51713556c48ca521bc18
MarkupSafe==2.1.5
matplotlib==3.7.2
mpmath==1.3.0
networkx==3.1
nodeenv==1.9.1
numpy==1.26.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.5.82
nvidia-nvtx-cu12==12.1.105
opencv-python==4.8.0.74
packaging==23.1
platformdirs==4.2.2
pre-commit==3.7.1
protobuf==5.27.2
pycparser==2.21
pydot==1.4.2
pyparsing==3.0.9
python-dateutil==2.8.2
PyWavelets==1.4.1
PyYAML==6.0.1
regex==2024.5.15
requests==2.32.3
safetensors==0.4.3
scikit-image==0.21.0
scipy==1.11.1
sentencepiece==0.2.0
Shapely==1.7.1
six==1.16.0
sympy==1.12.1
tifffile==2023.7.10
tokenizers==0.19.1
toml==0.10.2
torch==2.3.1
tornado==6.3.2
transformers==4.42.3
triton==2.3.1
typing_extensions==4.12.2

MatthewChang commented 3 days ago

Some additional info: generation terminates because the branch at https://github.com/epfl-dlab/transformers-CFG/blob/5f3772588bd2424eb27451bd6e400b7286630044/transformers_cfg/token_grammar_recognizer.py#L102 is taken (i.e., the stacks are empty). My guess is that some difference in the llama-3 tokenization is causing this. I can try to fix it myself if you have any pointers on how to track down the issue.
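To illustrate why empty stacks end generation, here is a toy model under loudly simplified assumptions: the grammar is a single byte literal, each "stack" is reduced to a position in that literal, and a step with no surviving stacks allows only EOS, matching the reporter's observation that taking the empty-stacks branch terminates generation. ToyRecognizer, consume, allowed and EOS are hypothetical names, and this is not the library's implementation. Purely as an illustration of one possible failure mode, it also shows that if a newline token's bytes were recovered as the UTF-8 encoding of "Ċ" (the string byte-level BPE vocabularies use for a newline) instead of b"\n", no stack would survive.

```python
from typing import Dict, List

EOS = "<eos>"  # hypothetical end-of-sequence label


class ToyRecognizer:
    """Hypothetical recognizer for a grammar that is one byte literal.

    Each "stack" is reduced to an index into the literal. This is a
    simplification for illustration, not the transformers-CFG data structure.
    """

    def __init__(self, literal: bytes):
        self.literal = literal

    def consume(self, stacks: List[int], token_bytes: bytes) -> List[int]:
        # Keep only the stacks whose upcoming bytes match the whole token.
        return [
            pos + len(token_bytes)
            for pos in stacks
            if self.literal.startswith(token_bytes, pos)
        ]

    def allowed(self, stacks: List[int], vocab: Dict[str, bytes]) -> List[str]:
        if not stacks:
            # No parse survives: the toy then allows only EOS, ending generation.
            return [EOS]
        names = [name for name, tok in vocab.items() if self.consume(stacks, tok)]
        if any(pos == len(self.literal) for pos in stacks):
            names.append(EOS)  # literal fully matched, so ending is legal too
        return names


rec = ToyRecognizer(b"first\nsecond")
vocab = {"first": b"first", "newline": b"\n", "second": b"second"}

stacks = rec.consume([0], b"first")             # [5]
ok = rec.consume(stacks, b"\n")                 # [6]: newline seen as one byte
bad = rec.consume(stacks, "Ċ".encode("utf-8"))  # []: b'\xc4\x8a' is not b'\n'
print(rec.allowed(ok, vocab))   # ['second']
print(rec.allowed(bad, vocab))  # ['<eos>']: generation would stop here
```

In this toy, output "first\n" followed by termination is exactly what a mismatch between the bytes used to allow a token and the bytes used to consume it would produce; whether that is what happens in transformers-CFG is an open assumption.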

Saibo-creator commented 2 days ago

Hey @MatthewChang, thanks for pointing out this behavior with llama-3 and providing detailed information. Since llama-3 uses a new tokenizer, something may be broken. I'll look into this and keep you updated :)
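To make the tokenizer difference concrete: Llama-3's Hugging Face tokenizer is a byte-level BPE whose vocabulary stores raw bytes through the GPT-2 byte-to-unicode table, so a newline appears as the token string "Ċ", whereas Mistral-7B and Llama-2 use SentencePiece with byte-fallback tokens such as "<0x0A>" and "▁" for spaces. Code that maps token strings to bytes under one convention will misread the other. The sketch below rebuilds the GPT-2 table and decodes both kinds of token strings back to bytes. The function names are hypothetical, this is not transformers-CFG code, and whether this difference is the actual cause of #63 is an assumption.

```python
from typing import Dict


def bytes_to_unicode() -> Dict[int, str]:
    """GPT-2's reversible byte -> printable-character table used by byte-level BPE."""
    # Printable bytes keep their own code point.
    bs = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    cs = bs[:]
    n = 0
    # Remaining bytes (controls, space, ...) are shifted to code points 256 and up.
    for b in range(256):
        if b not in bs:
            bs.append(b)
            cs.append(256 + n)
            n += 1
    return dict(zip(bs, (chr(c) for c in cs)))


BYTE_DECODER = {ch: b for b, ch in bytes_to_unicode().items()}


def byte_level_token_to_bytes(token: str) -> bytes:
    """Decode a byte-level BPE token string (Llama-3 style) to raw bytes."""
    return bytes(BYTE_DECODER[ch] for ch in token)


def sentencepiece_token_to_bytes(token: str) -> bytes:
    """Decode a SentencePiece token string (Llama-2 / Mistral style) to raw bytes."""
    if token.startswith("<0x") and token.endswith(">") and len(token) == 6:
        return bytes([int(token[3:5], 16)])  # byte-fallback token, e.g. <0x0A>
    return token.replace("\u2581", " ").encode("utf-8")  # '▁' marks a space


print(byte_level_token_to_bytes("Ċ"))          # b'\n'
print(sentencepiece_token_to_bytes("<0x0A>"))  # b'\n'
# Reading a byte-level token with the SentencePiece convention gives other bytes:
print(sentencepiece_token_to_bytes("Ċ"))       # b'\xc4\x8a'
```

With the real tokenizers, tokenizer.convert_ids_to_tokens(tokenizer.encode("\n", add_special_tokens=False)) prints the raw token strings, which shows which of the two conventions a given model uses.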