Open MatthewChang opened 4 days ago
Some additional info. Generation is terminating because the branch here https://github.com/epfl-dlab/transformers-CFG/blob/5f3772588bd2424eb27451bd6e400b7286630044/transformers_cfg/token_grammar_recognizer.py#L102 is getting taken (i.e. stacks are empty). I'm guessing there is some difference with the tokenization in the llama-3 models that is causing this. I can try to fix it myself if you have any pointers about how I might track down the issue.
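One way to start checking the tokenization guess, as a sketch only: it assumes (not confirmed in this thread) that the llama-3 tokenizer exposes its vocabulary in GPT-2-style byte-level form, where every raw byte is remapped to a printable character before it appears in a token string. Under that assumption, a newline never shows up as a literal "\n" inside a token, so any code that matches grammar bytes against token strings has to undo the mapping first. The function below re-implements the well-known GPT-2 byte-to-unicode table for illustration; it is not code from transformers-CFG.

```python
def bytes_to_unicode() -> dict[int, str]:
    """Rebuild the GPT-2 style byte -> printable-character table."""
    # Bytes that are already printable keep their own code point.
    printable = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    byte_values = printable[:]
    code_points = printable[:]
    shift = 0
    # Every remaining byte (control characters, space, ...) is moved to 256 + k.
    for b in range(256):
        if b not in printable:
            byte_values.append(b)
            code_points.append(256 + shift)
            shift += 1
    return dict(zip(byte_values, (chr(c) for c in code_points)))


mapping = bytes_to_unicode()
for name, ch in [("\\n", "\n"), ("\\t", "\t"), ("\\r", "\r"), ("space", " ")]:
    mapped = mapping[ord(ch)]
    print(f"{name:>5} -> {mapped!r} (U+{ord(mapped):04X})")
```

If the token strings seen by the recognizer contain these remapped characters rather than the raw control bytes, that would be consistent with the stacks becoming empty right after an escape sequence is consumed.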
Hey @MatthewChang, thanks for pointing out this behavior with llama-3 and providing detailed information. Since llama-3 uses a new tokenizer, something may be broken. I'll look into this and keep you updated :)
First of all, thanks for making and maintaining such a useful repo!
One issue I'm finding is that generation stops when the grammar includes a newline character "\n" or another escape sequence when using llama-3 models (specifically llama3-8b).
When running this snippet, which is just examples/generate_json_array.py with mistral-7b swapped out for llama-3-8b, the generation I get is
['This is a valid json array for student records:[\n', 'This is a valid json array for shopping cart:[\n']
The generations here should not be accepted by the grammar. I can reproduce this with a very simple grammar
grammar_str = 'root ::= "first\\nsecond"'
This will generate "first\n"
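To make the expected behavior concrete, here is a minimal sketch that classifies an output against the single terminal in this grammar. The helper names (literal_of, classify) and the small escape table are hypothetical and written only for this illustration; they are not part of transformers-CFG and do not reflect how its parser works.

```python
import re

grammar_str = 'root ::= "first\\nsecond"'

# Simplified escape table for this illustration only.
ESCAPES = {"n": "\n", "t": "\t", "r": "\r", '"': '"', "\\": "\\"}


def literal_of(grammar: str) -> str:
    """Extract and unescape the one quoted terminal of a single-rule grammar."""
    raw = re.search(r'"((?:[^"\\]|\\.)*)"', grammar).group(1)
    return re.sub(r"\\(.)", lambda m: ESCAPES.get(m.group(1), m.group(1)), raw)


def classify(output: str, target: str) -> str:
    """Return 'complete', 'prefix', or 'invalid' for an output string."""
    if output == target:
        return "complete"
    if target.startswith(output):
        return "prefix"
    return "invalid"


target = literal_of(grammar_str)
print(repr(target))
print(classify("first\n", target))
```

The observed output is only a proper prefix of the one string the grammar allows, so a correct constraint should keep generation going until "second" has been produced.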
Similarly, replacing the \n with \t or \r will cause the generation to return first\t and first\r respectively. This does not happen with Mistral-7b or llama2-7b.

I can reproduce this on main with a clean conda environment. Package versions after installing requirements below. Thanks for your help!
appdirs==1.4.4
black==21.4b2
certifi==2024.7.4
cffi==1.15.1
cfgv==3.4.0
charset-normalizer==3.3.2
contourpy==1.1.0
cycler==0.11.0
Cython==0.29.36
distlib==0.3.8
easydict==1.10
filelock==3.15.4
fonttools==4.41.0
fsspec==2024.6.1
future==0.18.3
huggingface-hub==0.23.4
identify==2.5.36
idna==3.7
imageio==2.31.1
Jinja2==3.1.4
jsonpointer==2.4
kiwisolver==1.4.4
lazy_loader==0.3
line_profiler==4.1.3
lvis @ git+https://github.com/lvis-dataset/lvis-api.git@da5f65d16237637d848a51713556c48ca521bc18
MarkupSafe==2.1.5
matplotlib==3.7.2
mpmath==1.3.0
networkx==3.1
nodeenv==1.9.1
numpy==1.26.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.5.82
nvidia-nvtx-cu12==12.1.105
opencv-python==4.8.0.74
packaging==23.1
platformdirs==4.2.2
pre-commit==3.7.1
protobuf==5.27.2
pycparser==2.21
pydot==1.4.2
pyparsing==3.0.9
python-dateutil==2.8.2
PyWavelets==1.4.1
PyYAML==6.0.1
regex==2024.5.15
requests==2.32.3
safetensors==0.4.3
scikit-image==0.21.0
scipy==1.11.1
sentencepiece==0.2.0
Shapely==1.7.1
six==1.16.0
sympy==1.12.1
tifffile==2023.7.10
tokenizers==0.19.1
toml==0.10.2
torch==2.3.1
tornado==6.3.2
transformers==4.42.3
triton==2.3.1
typing_extensions==4.12.2