dottxt-ai / outlines

Structured Text Generation
https://dottxt-ai.github.io/outlines/
Apache License 2.0

`outlines.grammars.json` seems not working #1264

Open syuoni opened 1 week ago

syuoni commented 1 week ago

Describe the issue as clearly as possible:

Generating with `outlines.grammars.json` does not produce JSON-formatted output.

Steps/code to reproduce the bug:

import outlines
from outlines.samplers import greedy

hf_model_dir = "openai-community/gpt2-medium"
model = outlines.models.transformers(hf_model_dir, device="cuda")

prompt = "What is 1+1? The answer in json format: "
generator = outlines.generate.text(model, sampler=greedy())
answer = generator(prompt, max_tokens=10)
print(answer)
# '\xa0"1+1" is a number that'

grammar = outlines.grammars.json
generator = outlines.generate.cfg(model, grammar, sampler=greedy())
answer = generator(prompt, max_tokens=10)
print(answer)
# '0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000'

Expected result:

The second answer should be valid JSON.

Error message:

No response

Outlines/Python version information:

Version information

``` 0.1.2.dev4+g5f39ded.d20241112 Python 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] absl-py==2.1.0 accelerate==1.1.1 aiohttp @ file:///rapids/aiohttp-3.9.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=c26959ca7b75ff768e2776d8055bf9582a6267e24556bb7f7bd29e677932be72 aiosignal @ file:///rapids/aiosignal-1.3.1-py3-none-any.whl#sha256=f8376fb07dd1e86a584e4fcdec80b36b7f81aac666ebc724e2c090300dd83b17 airportsdata==20241001 annotated-types==0.7.0 apex @ file:///opt/pytorch/apex argon2-cffi==23.1.0 argon2-cffi-bindings==21.2.0 asciitree @ file:///rapids/asciitree-0.3.3-py3-none-any.whl#sha256=b908101e1b4c7103c16c2ca292d9b32bb9761e018b491abf66949b066c75712d asttokens==2.4.1 astunparse==1.6.3 async-timeout @ file:///rapids/async_timeout-4.0.3-py3-none-any.whl#sha256=7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028 attrs==23.2.0 audioread==3.0.1 beautifulsoup4==4.12.3 bleach==6.1.0 blis==0.7.11 cachetools==5.3.3 catalogue==2.0.10 certifi==2024.7.4 cffi==1.16.0 charset-normalizer==3.3.2 click==8.1.7 cloudpathlib==0.18.1 cloudpickle @ file:///rapids/cloudpickle-3.0.0-py3-none-any.whl#sha256=246ee7d0c295602a036e86369c77fecda4ab17b506496730f2f576d9016fd9c7 cmake==3.30.0 comm==0.2.2 confection==0.1.5 contourpy==1.2.1 cuda-python @ file:///rapids/cuda_python-12.5.0-cp310-cp310-linux_x86_64.whl#sha256=3f19392c9f012c2f871bfb936390f1b044d6e845dbfab2bdf809fb5893693d51 cudf @ file:///rapids/cudf-24.4.0-cp310-cp310-linux_x86_64.whl#sha256=629e36e486a35b83732428260c62fb9743022b8da905bcd1c4b94e40c4743a9f cugraph @ file:///rapids/cugraph-24.4.0-cp310-cp310-linux_x86_64.whl#sha256=affd3b54f36b5eb7bdfc24df7f7e6294c30fd836649f0c58e10ff0a757e84472 cugraph-dgl @ file:///rapids/cugraph_dgl-24.4.0-py3-none-any.whl#sha256=7c4f905c5666296b23893ed5db42326419d31aad1af78dcaeae3dc40c43568f5 cugraph-equivariant @ file:///rapids/cugraph_equivariant-24.4.0-py3-none-any.whl#sha256=e810e5dd464b3efd0fc2530581c9249da1f2e9555d3f792508c98d035244cf96 cugraph-pyg @ 
file:///rapids/cugraph_pyg-24.4.0-py3-none-any.whl#sha256=e145c65b4da809d742621675656433fc44c4e0dec8bca76c48f85fd49171d839 cugraph-service-client @ file:///rapids/cugraph_service_client-24.4.0-py3-none-any.whl#sha256=5ba2194d8550710d04d2f9e4a131088a8309bbf02b859c40d0bc8ffe394118bb cugraph-service-server @ file:///rapids/cugraph_service_server-24.4.0-py3-none-any.whl#sha256=14b80b5f64c37c6e6db5f929b4cc61a5973290937c8f1b28b14e1301becc29e1 cuml @ file:///rapids/cuml-24.4.0-cp310-cp310-linux_x86_64.whl#sha256=0ec1b46fb48e799f0d115cf9507a027bf77572d50f85aa2936d3ca5d953e4ef3 cupy-cuda12x @ file:///rapids/cupy_cuda12x-13.0.0-cp310-cp310-linux_x86_64.whl#sha256=3989978df4bfd8e79c5ce065209afe772b223855c4e0855bfa2b665ee01a3613 cycler==0.12.1 cymem==2.0.8 Cython==3.0.10 dask @ file:///rapids/dask-2024.1.1-py3-none-any.whl#sha256=860ce2797905095beff0187c214840b80c77d752dcb9098a8283e3655a762bf5 dask-cuda @ file:///rapids/dask_cuda-24.4.0-py3-none-any.whl#sha256=0850e358eabbcf2acc4b6f786c2fe348bb0192b9bcff103fccff9f3cba1326a4 dask-cudf @ file:///rapids/dask_cudf-24.4.0-py3-none-any.whl#sha256=047e97f158ab3dc4ba1d5815eb9515ca1bd526739fa70b0baed7912315d34f1d dask-expr @ file:///rapids/dask_expr-0.4.0-py3-none-any.whl#sha256=a2e37fa0fa52afee7ee4822062103bd30820eeedb80771dc7acbeaa6fa2cb92f datasets==3.1.0 debugpy==1.8.2 decorator==5.1.1 defusedxml==0.7.1 dill==0.3.8 diskcache==5.6.3 distributed @ file:///rapids/distributed-2024.1.1-py3-none-any.whl#sha256=cf05d3b38e1700339b3e36395729ab62110e723efefaecc21a8260fdc7555cf9 dm-tree==0.1.8 einops==0.8.0 entrypoints @ file:///rapids/entrypoints-0.4-py3-none-any.whl#sha256=f174b5ff827504fd3cd97cc3f8649f3693f51538c7e4bdf3ef002c8429d42f9f exceptiongroup==1.2.1 execnet==2.1.1 executing==2.0.1 expecttest==0.1.3 fasteners @ file:///rapids/fasteners-0.19-py3-none-any.whl#sha256=758819cb5d94cdedf4e836988b74de396ceacb8e2794d21f82d131fd9ee77237 fastjsonschema==2.20.0 fastrlock @ 
file:///rapids/fastrlock-0.8.2-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl#sha256=08315bde19d0c2e6b06593d5a418be3dc8f9b1ee721afa96867b9853fceb45cf filelock==3.15.4 flash-attn==2.4.2 fonttools==4.53.1 frozenlist @ file:///rapids/frozenlist-1.4.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=a9b2de4cf0cdd5bd2dee4c4f63a653c61d2408055ab77b151c1957f221cabf2a fsspec @ file:///rapids/fsspec-2024.5.0-py3-none-any.whl#sha256=e0fdbc446d67e182f49a70b82cf7889028a63588fde6b222521f10937b2b670c gast==0.6.0 google-auth==2.32.0 google-auth-oauthlib==0.4.6 grpcio @ file:///rapids/grpcio-1.62.1-cp310-cp310-linux_x86_64.whl#sha256=bd36830a7a76a97b798daabb982f0d86eff5eea3631e3ebbd6e5e9d83d825ccd huggingface-hub==0.26.2 hypothesis==5.35.1 idna==3.7 igraph==0.11.6 importlib_metadata @ file:///rapids/importlib_metadata-7.1.0-py3-none-any.whl#sha256=30962b96c0c223483ed6cc7280e7f0199feb01a0e40cfae4d4450fc6fab1f570 iniconfig==2.0.0 intel-openmp==2021.4.0 interegular==0.3.3 ipykernel==6.29.5 ipython==8.21.0 ipython-genutils==0.2.0 jedi==0.19.1 Jinja2==3.1.4 joblib==1.4.2 json5==0.9.25 jsonschema==4.23.0 jsonschema-specifications==2023.12.1 jupyter-tensorboard @ git+https://github.com/cliffwoolley/jupyter_tensorboard.git@ffa7e26138b82549453306e06b535a9ac36db17a jupyter_client==8.6.2 jupyter_core==5.7.2 jupyterlab==2.3.2 jupyterlab-server==1.2.0 jupyterlab_pygments==0.3.0 jupytext==1.16.2 kiwisolver==1.4.5 kvikio @ file:///rapids/kvikio-24.4.0-cp310-cp310-linux_x86_64.whl#sha256=399dadc4a4643ef68f7d89cf93bdae198b2443f7fc38de8f48ec678cf8d7c697 langcodes==3.4.0 language_data==1.2.0 lark==1.2.2 lazy_loader==0.4 librosa==0.10.1 lightning-thunder==0.2.0.dev0 lightning-utilities==0.11.3.post0 lintrunner==0.12.5 llvmlite @ file:///rapids/llvmlite-0.42.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=763f8d8717a9073b9e0246998de89929071d15b47f254c10eef2310b9aac033d locket @ 
file:///rapids/locket-1.0.0-py2.py3-none-any.whl#sha256=b6c819a722f7b6bd955b80781788e4a66a55628b858d347536b7e81325a3a5e3 looseversion==1.3.0 marisa-trie==1.2.0 Markdown==3.6 markdown-it-py==3.0.0 MarkupSafe==2.1.5 matplotlib==3.9.1 matplotlib-inline==0.1.7 mdit-py-plugins==0.4.1 mdurl==0.1.2 mistune==3.0.2 mkl==2021.1.1 mkl-devel==2021.1.1 mkl-include==2021.1.1 mock==5.1.0 mpmath==1.3.0 msgpack==1.0.8 multidict @ file:///rapids/multidict-6.0.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=21fd81c4ebdb4f214161be351eb5bcf385426bf023041da2fd9e60681f3cebae multiprocess==0.70.16 murmurhash==1.0.10 nbclient==0.10.0 nbconvert==7.16.4 nbformat==5.10.4 nest-asyncio==1.6.0 networkx==3.3 ninja==1.11.1.1 notebook==6.4.10 numba @ file:///rapids/numba-0.59.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl#sha256=525ef3f820931bdae95ee5379c670d5c97289c6520726bc6937a4a7d4230ba24 numcodecs @ file:///rapids/numcodecs-0.11.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=c27dfca402f69fbfa01c46fb572086e77f38121192160cc8ed1177dc30702c52 numpy==1.24.4 nvfuser==0.2.6a0+f73ff1b nvidia-cudnn-frontend @ file:///opt/pytorch/lightning-thunder/tmp/cudnn_frontend nvidia-dali-cuda120==1.39.0 nvidia-modelopt==0.13.0 nvidia-nvimgcodec-cu12==0.2.0.7 nvidia-pyindex==1.0.9 nvtx @ file:///rapids/nvtx-0.2.5-cp310-cp310-linux_x86_64.whl#sha256=3b2f6d37a391657246ab0d4385ce2346be53678e26a57832215f7eea5959814a nx-cugraph @ file:///rapids/nx_cugraph-24.4.0-py3-none-any.whl#sha256=d2771f2d59bcf03e6d6b25094191d13e0046cee3e974bc4fe7c1bd045101a998 oauthlib==3.2.2 onnx @ file:///opt/pytorch/pytorch/third_party/onnx opencv @ file:///opencv-4.7.0/modules/python/package opt-einsum==3.3.0 optree==0.12.1 -e git+https://github.com/dottxt-ai/outlines@5f39ded4b872d60de563658f1d7a2642de0e8c4e#egg=outlines outlines_core==0.1.14 packaging @ file:///rapids/packaging-24.0-py3-none-any.whl#sha256=2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5 pandas @ 
file:///rapids/pandas-2.2.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=c38ce92cb22a4bea4e3929429aa1067a454dcc9c335799af93ba9be21b6beb51 pandocfilters==1.5.1 parso==0.8.4 partd @ file:///rapids/partd-1.4.2-py3-none-any.whl#sha256=978e4ac767ec4ba5b86c6eaa52e5a2a3bc748a2ca839e8cc798f1cc6ce6efb0f pexpect==4.9.0 pillow==10.4.0 platformdirs==4.2.2 pluggy==1.5.0 ply @ file:///rapids/ply-3.11-py2.py3-none-any.whl#sha256=096f9b8350b65ebd2fd1346b12452efe5b9607f7482813ffca50c22722a807ce polygraphy==0.49.12 pooch==1.8.2 preshed==3.0.9 prometheus_client==0.20.0 prompt_toolkit==3.0.47 protobuf==4.24.4 psutil @ file:///rapids/psutil-5.9.8-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=d06016f7f8625a1825ba3732081d77c94589dca78b7a3fc072194851e88461a4 ptyprocess==0.7.0 pure-eval==0.2.2 pyarrow==18.0.0 pyasn1==0.6.0 pyasn1_modules==0.4.0 pybind11==2.13.1 pybind11_global==2.13.1 pycocotools @ git+https://github.com/nvidia/cocoapi.git@d99cbf3823588ef09a2721655f46e509ebafb3d7#subdirectory=PythonAPI pycountry==24.6.1 pycparser==2.22 pydantic==2.8.2 pydantic_core==2.20.1 Pygments==2.18.0 pylibcugraph @ file:///rapids/pylibcugraph-24.4.0-cp310-cp310-linux_x86_64.whl#sha256=80f66f9ed55f162973cc1f27e65b3137b5aa36ffcb88b93a81e2a709dc493f01 pylibcugraphops @ file:///rapids/pylibcugraphops-24.4.0-cp310-cp310-linux_x86_64.whl#sha256=cfc7d74921ccdea4d1143419942235d642fffdb593a6ea95520ec199d9b4b53f pylibraft @ file:///rapids/pylibraft-24.4.0-cp310-cp310-linux_x86_64.whl#sha256=8ce45ed554fbbe3b954178c3927b3a569cbd76ea1ae987235527debc28074d54 pylibwholegraph @ file:///rapids/pylibwholegraph-24.4.0-cp310-cp310-linux_x86_64.whl#sha256=c724232c446b8419fce0dcc36fd58cab2e08bb2d54dee1998af1cb2e3fcfee6f pynvjitlink @ file:///rapids/pynvjitlink-0.2.3-cp310-cp310-linux_x86_64.whl#sha256=63e0515584e89c55b31eaa3a598e39fd02e1c390407ad6ec4870fad61c8c491f pynvml @ 
file:///rapids/pynvml-11.4.1-py3-none-any.whl#sha256=d27be542cd9d06558de18e2deffc8022ccd7355bc7382255d477038e7e424c6c pyparsing==3.1.2 pytest==8.1.1 pytest-flakefinder==1.1.0 pytest-rerunfailures==14.0 pytest-shard==0.1.2 pytest-xdist==3.6.1 python-dateutil==2.9.0.post0 python-hostlist==1.23.0 pytorch-triton @ file:///tmp/dist/pytorch_triton-3.0.0%2B989adb9a2-cp310-cp310-linux_x86_64.whl#sha256=1008f1caa423f84898fde2602fb10f258db5e3f75de2596e518f5c5769a536af pytz @ file:///rapids/pytz-2024.1-py2.py3-none-any.whl#sha256=328171f4e3623139da4983451950b28e95ac706e13f3f2630a879749e7a8b319 PyYAML==6.0.1 pyzmq==26.0.3 raft-dask @ file:///rapids/raft_dask-24.4.0-cp310-cp310-linux_x86_64.whl#sha256=1d6d2869146e689d2502dbeb8ebad5e71cdf1a0e49e0f43d1e0d186b7c5399cd rapids-dask-dependency @ file:///rapids/rapids_dask_dependency-24.4.0a0-py3-none-any.whl#sha256=133645adcb050a33b39168091a0e22112af2e1c8b7e2bf79eb5e84be85138352 referencing==0.35.1 regex==2024.5.15 requests==2.32.3 requests-oauthlib==2.0.0 rich==13.7.1 rmm @ file:///rapids/rmm-24.4.0-cp310-cp310-linux_x86_64.whl#sha256=600bf0eaa5abe064c901c505b62770d79b0fb8c718b4cfac8750821846b3948f rpds-py==0.19.0 rsa==4.9 safetensors==0.4.5 scikit-learn==1.5.1 scipy @ file:///rapids/scipy-1.13.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=f26264b282b9da0952a024ae34710c2aff7d27480ee91a2e82b7b7073c24722f Send2Trash==1.8.3 shellingham==1.5.4 six==1.16.0 smart-open==7.0.4 sortedcontainers==2.4.0 soundfile==0.12.1 soupsieve==2.5 soxr==0.3.7 spacy==3.7.5 spacy-legacy==3.0.12 spacy-loggers==1.0.5 srsly==2.4.8 stack-data==0.6.3 sympy==1.13.0 tabulate==0.9.0 tbb==2021.13.0 tblib @ file:///rapids/tblib-3.0.0-py3-none-any.whl#sha256=80a6c77e59b55e83911e1e607c649836a69c103963c5f28a46cbeef44acf8129 tensorboard==2.9.0 tensorboard-data-server==0.6.1 tensorboard-plugin-wit==1.8.1 tensorrt @ 
file:///workspace/TensorRT-10.2.0.19/python/tensorrt-10.2.0-cp310-none-linux_x86_64.whl#sha256=ddc27586286d416fbb9495cd457430865ae49fcb3fca3bb54db254c75882e12f terminado==0.18.1 texttable==1.7.0 thinc==8.2.5 threadpoolctl==3.5.0 thriftpy2 @ file:///rapids/thriftpy2-0.5.0-cp310-cp310-linux_x86_64.whl#sha256=0194b7a74814889911afc20c71f70f4a1bb21758d8ade66b9daa8feb633a8b4a tinycss2==1.3.0 tokenizers==0.20.3 tomli==2.0.1 toolz @ file:///rapids/toolz-0.12.1-py3-none-any.whl#sha256=d22731364c07d72eea0a0ad45bafb2c2937ab6fd38a3507bf55eae8744aa7d85 torch @ file:///opt/transfer/torch-2.4.0a0%2B3bcc3cddb5.nv24.7-cp310-cp310-linux_x86_64.whl#sha256=4654d8190e1a94e16149e61c144cea52157e469f02f2687e00b50adddf08872b torch-tensorrt @ file:///opt/pytorch/torch_tensorrt/dist/torch_tensorrt-2.5.0a0-cp310-cp310-linux_x86_64.whl#sha256=aa536ce888dd0f0d851518a8fa385123c15bfe8d4689b6a3e67addd56f9cceaa torchvision @ file:///opt/pytorch/vision tornado @ file:///rapids/tornado-6.4-cp38-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=f0251554cdd50b4b44362f73ad5ba7126fc5b2c2895cc62b14a1c2d7ea32f212 tqdm==4.66.4 traitlets==5.9.0 transformer-engine @ git+https://github.com/NVIDIA/TransformerEngine.git@37280ecd5e9c6087d18fbe2e668f2ec7761ada3d transformers==4.46.2 treelite @ file:///rapids/treelite-4.1.2-py3-none-linux_x86_64.whl#sha256=0098db86da49955b8b48e8a61603aa781065d6424a28644f51152e7f44513109 typer==0.12.3 types-dataclasses==0.6.6 typing_extensions @ file:///rapids/typing_extensions-4.12.0-py3-none-any.whl#sha256=b349c66bea9016ac22978d800cfff206d5f9816951f12a7d0ec5578b0a819594 tzdata @ file:///rapids/tzdata-2024.1-py2.py3-none-any.whl#sha256=9068bc196136463f5245e51efda838afa15aaeca9903f49050dfa2679db4d252 ucx-py @ file:///rapids/ucx_py-0.37.0-cp310-cp310-linux_x86_64.whl#sha256=2a67bf0d1a593ba5321124971f8fcccbdd101062bffd9f4b64e52ca5be7f5125 urllib3 @ 
file:///rapids/urllib3-2.0.7-py3-none-any.whl#sha256=fdb6d215c776278489906c2f8916e6e7d4f5a9b602ccbcfdf7f016fc8da0596e wasabi==1.1.3 wcwidth==0.2.13 weasel==0.4.1 webencodings==0.5.1 Werkzeug==3.0.3 wrapt==1.16.0 xdoctest==1.0.2 xgboost @ file:///rapids/xgboost-2.0.3-py3-none-linux_x86_64.whl#sha256=b0e69279b25f839da24687a139c5c0936b465f59d45f13cc8ac3573bf05f5d6d xxhash==3.5.0 yarl @ file:///rapids/yarl-1.9.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=357495293086c5b6d34ca9616a43d329317feab7917518bc97a08f9e55648455 zarr @ file:///rapids/zarr-2.18.2-py3-none-any.whl#sha256=a638754902f97efa99b406083fdc807a0e2ccf12a949117389d2a4ba9b05df38 zict @ file:///rapids/zict-3.0.0-py2.py3-none-any.whl#sha256=5796e36bd0e0cc8cf0fbc1ace6a68912611c1dbd74750a3f3026b9b9d6a327ae zipp @ file:///rapids/zipp-3.19.0-py3-none-any.whl#sha256=96dc6ad62f1441bcaccef23b274ec471518daf4fbbc580341204936a5a3dddec ```

Context for the issue:

No response

caleb-kaiser commented 6 days ago

@syuoni I've run into similar problems when using smaller models with the JSON grammar + greedy sampling. Parsing your output with the JSON grammar defined in https://github.com/dottxt-ai/outlines/blob/main/outlines/grammars/json.lark shows that the string is technically valid under that grammar, since a run of zeros is accepted as a single number:

from lark import Lark

grammar = r"""
?start: value

?value: object
       | array
       | ESCAPED_STRING
       | SIGNED_NUMBER      -> number
       | "true"             -> true
       | "false"            -> false
       | "null"             -> null

array  : "[" [value ("," value)*] ["]"]
object : "{" [pair ("," pair)*] ["}"]
pair   : ESCAPED_STRING ":" value

%import common.ESCAPED_STRING
%import common.SIGNED_NUMBER
%import common.WS

%ignore WS
"""

parser = Lark(grammar, start="start", debug=True)

parser.parse("000000000000000000000000000000")
# Tree('number', [Token('SIGNED_NUMBER', '0000000000000000000000000000')])

Smaller models, in my experience, are more prone to this kind of output, particularly when we only sample the most likely token.
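
To make the "technically valid" point concrete without loading a model, here is a minimal standard-library sketch. It assumes that the `SIGNED_NUMBER` terminal from Lark's common grammar can be approximated by the first regex below (my approximation, not copied from Lark), and compares it with the number syntax of the JSON specification, which forbids leading zeros. The names `LARK_STYLE_NUMBER`, `STRICT_JSON_NUMBER`, and `is_strict_json` are made up for illustration.

```python
import json
import re

# Approximation (an assumption, not Lark's own source) of the SIGNED_NUMBER
# terminal: the integer part is just one or more digits, so leading zeros pass.
LARK_STYLE_NUMBER = re.compile(r"[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?")

# Number syntax from the JSON specification (RFC 8259): no leading zeros.
STRICT_JSON_NUMBER = re.compile(r"-?(0|[1-9]\d*)(\.\d+)?([eE][+-]?\d+)?")


def is_strict_json(text: str) -> bool:
    """Return True if Python's json module accepts the text."""
    try:
        json.loads(text)
    except json.JSONDecodeError:
        return False
    return True


zeros = "0" * 30
print(bool(LARK_STYLE_NUMBER.fullmatch(zeros)))   # accepted by the loose pattern
print(bool(STRICT_JSON_NUMBER.fullmatch(zeros)))  # rejected by the strict pattern
print(is_strict_json(zeros))                      # rejected by json.loads
print(is_strict_json("0"))                        # a single zero is fine
```

If that approximation holds, the grammar is doing what it was written to do: the constrained generator never leaves the set of strings the grammar accepts, but that set is slightly larger than strict JSON.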

Have you tried the generate.json() method instead? I just tested the following (rolling back to outlines 0.1.3 due to a bug in 0.1.4) with your prompt and it seems to work:

import outlines
from pydantic import BaseModel

# `model` and `prompt` are the ones defined in the original report above.
class Answer(BaseModel):
    answer: int

generator = outlines.generate.json(model, Answer)
generator(prompt)
# Answer(answer=1)
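
For anyone who wants to check raw outputs by hand, the shape that the `Answer` model asks for can be sketched with the standard library alone. This is only an illustration of the constraint, not how outlines enforces it; `matches_answer_schema` is a hypothetical helper name.

```python
import json


def matches_answer_schema(text: str) -> bool:
    """Hypothetical helper: True if text is a JSON object whose only key,
    "answer", holds an integer."""
    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        return False
    if not isinstance(data, dict) or set(data) != {"answer"}:
        return False
    value = data["answer"]
    # bool is a subclass of int in Python, so exclude it explicitly.
    return isinstance(value, int) and not isinstance(value, bool)


print(matches_answer_schema('{"answer": 1}'))
print(matches_answer_schema("0" * 30))
```

A schema like this leaves the model far less room than the open-ended JSON grammar: a bare run of digits is no longer an acceptable complete output.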