outlines-dev / outlines

Structured Text Generation
https://outlines-dev.github.io/outlines/
Apache License 2.0

`outlines.generate.choice` generates tokens other than provided choices - special tokens being added to tokenizer incorrectly? #893

Closed: aaronsnoswell closed this issue 3 weeks ago

aaronsnoswell commented 1 month ago

Describe the issue as clearly as possible:

With some models, outlines.generate.choice generates answers that aren't among the choices provided to it. This only occurs for some models, and when it does, I also see a warning from HF transformers:

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
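
As a quick sanity check of what that warning implies, one can compare the tokenizer's vocabulary size against the model's embedding matrix. A minimal sketch, using plain HF transformers rather than the outlines wrapper:

from transformers import AutoModelForCausalLM, AutoTokenizer

# If len(tokenizer) exceeds the number of embedding rows, special tokens were
# added to the tokenizer without corresponding trained embeddings in the model.
tok = AutoTokenizer.from_pretrained("EleutherAI/pythia-1b-deduped")
hf_model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-1b-deduped")
print("tokenizer vocab size:", len(tok))
print("embedding rows:", hf_model.get_input_embeddings().weight.shape[0])
print("special tokens map:", tok.special_tokens_map)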

An MWE is attached below:

Steps/code to reproduce the bug:

import torch
import outlines
from outlines import samplers

rng = torch.Generator(device="cuda")
rng.manual_seed(1337)

# Generated outputs match the provided choices
#model_path = "distilbert/distilgpt2"

# Generated outputs are not in the set of choices
# Also get a warning: "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained."
#model_path = "meta-llama/Meta-Llama-3-8B"
model_path = "EleutherAI/pythia-1b-deduped"

model = outlines.models.transformers(model_path, device="cuda")
model.model.half()

sampler = samplers.multinomial(1)
generator = outlines.generate.choice(model, ["-1", "0", "1"], sampler)

prompt = """Give me an integer ranging from -1 to 1 inclusive..."""

for i in range(10):
    answer = generator(prompt, rng=rng)
    print(answer)

Expected result:

# Something like the following
-1
0
0
-1
1
1
0
1
-1
1

Error message:

# The actual generated output varies by model, but e.g. with `EleutherAI/pythia-1b-deduped`, I get:
+/
.
.
+/
/
/
+/
+/
/
+/

Outlines/Python version information:

(brix) C:\Development>python -c "from outlines import _version; print(_version.version)" 0.0.41

(brix) C:\Development>python -c "import sys; print('Python', sys.version)" Python 3.12.2 | packaged by Anaconda, Inc. | (main, Feb 27 2024, 17:28:07) [MSC v.1916 64 bit (AMD64)]

(brix) C:\Development>pip freeze
accelerate==0.30.1
aiohttp==3.9.5
aiosignal==1.3.1
anaconda-anon-usage @ file:///C:/b/abs_c3w_h1zzjg/croot/anaconda-anon-usage_1710965204622/work
annotated-types==0.6.0
anyio==4.3.0
archspec @ file:///croot/archspec_1709217642129/work
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asttokens==2.4.1
async-lru==2.0.4
attrs==23.2.0
autoflake==2.3.1
Babel==2.15.0
beautifulsoup4==4.12.3
black==24.4.2
bleach==6.1.0
blinker==1.8.2
boltons @ file:///C:/Users/dev-admin/perseverance-python-buildout/croot/boltons_1699480450092/work
Brotli @ file:///C:/Users/dev-admin/perseverance-python-buildout/croot/brotli-split_1699473013692/work
certifi @ file:///C:/b/abs_35d7n66oz9/croot/certifi_1707229248467/work/certifi
cffi @ file:///C:/b/abs_924gv1kxzj/croot/cffi_1700254355075/work
charset-normalizer @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work
click==8.1.7
cloudpickle==3.0.0
colorama @ file:///C:/Users/dev-admin/perseverance-python-buildout/croot/colorama_1699472650914/work
comm==0.2.2
conda @ file:///C:/b/abs_1e6dlkntna/croot/conda_1710772093015/work
conda-content-trust @ file:///C:/Users/dev-admin/perseverance-python-buildout/croot/conda-content-trust_1699553484152/work
conda-libmamba-solver @ file:///croot/conda-libmamba-solver_1706733287605/work/src
conda-package-handling @ file:///C:/Users/dev-admin/perseverance-python-buildout/croot/conda-package-handling_1699480603217/work
conda_package_streaming @ file:///C:/Users/dev-admin/perseverance-python-buildout/croot/conda-package-streaming_1699475879769/work
cryptography @ file:///C:/b/abs_f5n93r0tun/croot/cryptography_1710350404202/work
datasets==2.19.1
debugpy==1.8.1
decorator==5.1.1
defusedxml==0.7.1
dill==0.3.8
diskcache==5.6.3
distro @ file:///C:/Users/dev-admin/perseverance-python-buildout/croot/distro_1701796812765/work
dnspython==2.6.1
email_validator==2.1.1
executing==2.0.1
Faker==25.1.0
fakeredis==2.23.0
fastapi==0.111.0
fastapi-cli==0.0.3
fastjsonschema==2.19.1
filelock==3.14.0
Flask==3.0.3
Flask-Cors==4.0.1
fqdn==1.5.1
frozenlist==1.4.1
fsspec==2024.3.1
h11==0.14.0
httpcore==1.0.5
httptools==0.6.1
httpx==0.27.0
huggingface-hub==0.23.0
idna @ file:///C:/Users/dev-admin/perseverance-python-buildout/croot/idna_1699473483982/work
iniconfig==2.0.0
intel-openmp==2021.4.0
interegular==0.3.3
ipykernel==6.29.4
ipython==8.24.0
ipywidgets==8.1.2
isoduration==20.11.0
isort==5.13.2
itsdangerous==2.2.0
jedi==0.19.1
Jinja2==3.1.4
joblib==1.4.2
json5==0.9.25
jsonpatch @ file:///C:/b/abs_d3zr1enxou/croot/jsonpatch_1710807549298/work
jsonpointer==2.1
jsonschema==4.22.0
jsonschema-specifications==2023.12.1
jupyter-events==0.10.0
jupyter-lsp==2.2.5
jupyter_client==8.6.1
jupyter_core==5.7.2
jupyter_server==2.14.0
jupyter_server_terminals==0.5.3
jupyterlab==4.2.0
jupyterlab_pygments==0.3.0
jupyterlab_server==2.27.1
jupyterlab_widgets==3.0.10
lark==1.1.9
libmambapy @ file:///C:/b/abs_7dmjutgtwb/croot/mamba-split_1712091963973/work/libmambapy
llvmlite==0.42.0
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib-inline==0.1.7
mdurl==0.1.2
menuinst @ file:///C:/b/abs_099kybla52/croot/menuinst_1706732987063/work
mistune==3.0.2
mkl==2021.4.0
mpmath==1.3.0
multidict==6.0.5
multiprocess==0.70.16
mypy-extensions==1.0.0
nbclient==0.10.0
nbconvert==7.16.4
nbformat==5.10.4
nest-asyncio==1.6.0
networkx==3.2.1
nltk==3.8.1
notebook_shim==0.2.4
numba==0.59.1
numpy==1.26.4
openai==1.28.1
orjson==3.10.3
outlines==0.0.41
overrides==7.7.0
packaging @ file:///C:/b/abs_cc1h2xfosn/croot/packaging_1710807447479/work
pandas==2.2.2
pandocfilters==1.5.1
parso==0.8.4
pathspec==0.12.1
pillow==10.2.0
platformdirs @ file:///C:/Users/dev-admin/perseverance-python-buildout/croot/platformdirs_1701797392447/work
pluggy==1.5.0
prometheus_client==0.20.0
prompt-toolkit==3.0.43
psutil==5.9.8
pure-eval==0.2.2
pyarrow==16.0.0
pyarrow-hotfix==0.6
pycosat @ file:///C:/Users/dev-admin/perseverance-python-buildout/croot/pycosat_1699482932804/work
pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work
pydantic==2.7.1
pydantic_core==2.18.2
pyflakes==3.2.0
Pygments==2.18.0
PySocks @ file:///C:/Users/dev-admin/perseverance-python-buildout/croot/pysocks_1699473336188/work
pytest==8.2.0
python-dateutil==2.9.0.post0
python-dotenv==1.0.1
python-json-logger==2.0.7
python-multipart==0.0.9
pytz==2024.1
pywin32==306
pywinpty==2.0.13
PyYAML==6.0.1
pyzmq==26.0.3
ranking_challenge==1.0.3
redis==5.0.4
referencing==0.35.1
regex==2024.5.10
requests @ file:///C:/b/abs_474vaa3x9e/croot/requests_1707355619957/work
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rich==13.7.1
rpds-py==0.18.1
ruamel.yaml @ file:///C:/Users/dev-admin/perseverance-python-buildout/croot/ruamel.yaml_1699483184324/work
safetensors==0.4.3
scikit-learn==1.4.2
scipy==1.13.0
Send2Trash==1.8.3
setuptools==68.2.2
shellingham==1.5.4
six==1.16.0
sniffio==1.3.1
sortedcontainers==2.4.0
soupsieve==2.5
stack-data==0.6.3
starlette==0.37.2
sympy==1.12
tbb==2021.11.0
terminado==0.18.1
threadpoolctl==3.5.0
tinycss2==1.3.0
tokenizers==0.19.1
torch==2.3.0+cu118
torchaudio==2.3.0+cu118
torchvision==0.18.0+cu118
tornado==6.4
tqdm @ file:///C:/Users/dev-admin/perseverance-python-buildout/croot/tqdm_1701808178601/work
traitlets==5.14.3
transformers==4.40.2
truststore @ file:///C:/Users/dev-admin/perseverance-python-buildout/croot/truststore_1701881385424/work
typer==0.12.3
types-python-dateutil==2.9.0.20240316
typing_extensions==4.11.0
tzdata==2024.1
ujson==5.9.0
uri-template==1.3.0
urllib3 @ file:///C:/b/abs_4etpfrkumr/croot/urllib3_1707770616184/work
uvicorn==0.29.0
watchfiles==0.21.0
wcwidth==0.2.13
webcolors==1.13
webencodings==0.5.1
websocket-client==1.8.0
websockets==12.0
Werkzeug==3.0.3
wheel==0.41.2
widgetsnbextension==4.0.10
win-inet-pton @ file:///C:/Users/dev-admin/perseverance-python-buildout/croot/win_inet_pton_1699472992992/work
xxhash==3.4.1
yarl==1.9.4
zstandard==0.19.0

Context for the issue:

No response

aaronsnoswell commented 1 month ago

At the suggestion of folks in the Discord, I tried cloning the main branch and using that instead of my pip-installed outlines.

Can confirm the bug still occurs there.

aaronsnoswell commented 1 month ago

(as in, this commit: https://github.com/outlines-dev/outlines/commit/78852b0169e7c4c6f3eaf6b2b2e6209e41edf98c)

isamu-isozaki commented 1 month ago

I tried your code in the main branch using

pip uninstall outlines
pip install git+https://github.com/outlines-dev/outlines.git@main

and I got

-1
0
0
1
-1
-1
1
1
1
-1

isamu-isozaki commented 1 month ago
!python -c "from outlines import _version; print(_version.version)"
0.0.43.dev11+g78852b0
!python -c "import sys; print('Python', sys.version)"
Python 3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:40:08) [MSC v.1938 64 bit (AMD64)]
!pip freeze
accelerate==0.29.3
aiohttp==3.9.5
aiosignal==1.3.1
annotated-types==0.6.0
anyio==4.3.0
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asttokens==2.4.1
async-lru==2.0.4
async-timeout==4.0.3
attrs==23.2.0
auto_gptq==0.7.1
Babel==2.14.0
beautifulsoup4==4.12.3
bleach==6.1.0
Brotli @ file:///D:/bld/brotli-split_1695989908365/work
certifi @ file:///home/conda/feedstock_root/build_artifacts/certifi_1707022139797/work/certifi
cffi==1.16.0
charset-normalizer @ file:///home/conda/feedstock_root/build_artifacts/charset-normalizer_1698833585322/work
click==8.1.7
cloudpickle==3.0.0
colorama==0.4.6
comm==0.2.2
cramjam==2.8.3
dataclasses-json==0.6.4
datasets==2.19.0
debugpy==1.8.1
decorator==5.1.1
defusedxml==0.7.1
dill==0.3.8
diskcache==5.6.3
distro==1.9.0
exceptiongroup==1.2.1
executing==2.0.1
exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.0.20/exllamav2-0.0.20+cu118-cp310-cp310-win_amd64.whl#sha256=5545d7ff9e31c0e7fb8667b36ac55c28c89c396438b9b7be287777ad33b9a157
fastapi==0.110.2
fastjsonschema==2.19.1
fastparquet==2024.2.0
filelock @ file:///home/conda/feedstock_root/build_artifacts/filelock_1712686151958/work
fqdn==1.5.1
frozenlist==1.4.1
fsspec==2024.3.1
gekko==1.1.1
greenlet==3.0.3
h11==0.14.0
httpcore==1.0.5
httpx==0.27.0
huggingface-hub==0.23.0
idna @ file:///home/conda/feedstock_root/build_artifacts/idna_1713279365350/work
intel-openmp==2021.4.0
interegular==0.3.3
ipykernel==6.29.4
ipython==8.24.0
ipywidgets==8.1.2
isoduration==20.11.0
jedi==0.19.1
Jinja2 @ file:///home/conda/feedstock_root/build_artifacts/jinja2_1704966972576/work
joblib==1.4.2
json5==0.9.25
jsonpatch==1.33
jsonpointer==2.4
jsonschema==4.21.1
jsonschema-specifications==2023.12.1
jupyter==1.0.0
jupyter-console==6.6.3
jupyter-events==0.10.0
jupyter-lsp==2.2.5
jupyter_client==8.6.1
jupyter_core==5.7.2
jupyter_server==2.14.0
jupyter_server_terminals==0.5.3
jupyterlab==4.1.8
jupyterlab_pygments==0.3.0
jupyterlab_server==2.27.1
jupyterlab_widgets==3.0.10
langchain==0.1.16
langchain-community==0.0.34
langchain-core==0.1.46
langchain-text-splitters==0.0.1
langsmith==0.1.51
lark==1.1.9
llvmlite==0.42.0
MarkupSafe @ file:///D:/bld/markupsafe_1706900062361/work
marshmallow==3.21.1
matplotlib-inline==0.1.7
mistune==3.0.2
mkl==2021.4.0
mpmath @ file:///home/conda/feedstock_root/build_artifacts/mpmath_1678228039184/work
multidict==6.0.5
multiprocess==0.70.16
mypy-extensions==1.0.0
nbclient==0.10.0
nbconvert==7.16.3
nbformat==5.10.4
nest-asyncio==1.6.0
networkx @ file:///home/conda/feedstock_root/build_artifacts/networkx_1712540363324/work
ninja==1.11.1.1
notebook==7.1.3
notebook_shim==0.2.4
numba==0.59.1
numpy @ file:///D:/bld/numpy_1707225570061/work/dist/numpy-1.26.4-cp310-cp310-win_amd64.whl#sha256=6761da75b1528684e6bf4dabdbdded9d1eb4d0e9b299482c7ce152cfb3155106
openai==1.23.6
orjson==3.10.1
outlines @ git+https://github.com/outlines-dev/outlines.git@78852b0169e7c4c6f3eaf6b2b2e6209e41edf98c
overrides==7.7.0
packaging==23.2
pandas==2.2.2
pandocfilters==1.5.1
parso==0.8.4
peft==0.10.0
pillow @ file:///D:/bld/pillow_1712154657455/work
platformdirs==4.2.1
prometheus_client==0.20.0
prompt-toolkit==3.0.43
psutil==5.9.8
pure-eval==0.2.2
pyairports==2.1.1
pyarrow==16.0.0
pyarrow-hotfix==0.6
pycountry==23.12.11
pycparser==2.22
pydantic==2.7.1
pydantic_core==2.18.2
Pygments==2.17.2
PySocks @ file:///D:/bld/pysocks_1661604991356/work
python-dateutil==2.9.0.post0
python-dotenv==1.0.1
python-json-logger==2.0.7
pytz==2024.1
pywin32==306
pywinpty==2.0.13
PyYAML @ file:///D:/bld/pyyaml_1695373629531/work
pyzmq==26.0.2
qtconsole==5.5.1
QtPy==2.4.1
referencing==0.35.0
regex==2024.4.16
requests @ file:///home/conda/feedstock_root/build_artifacts/requests_1684774241324/work
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rouge==1.0.1
rpds-py==0.18.0
safetensors==0.4.3
scipy==1.13.0
Send2Trash==1.8.3
sentencepiece==0.2.0
six==1.16.0
sniffio==1.3.1
soupsieve==2.5
SQLAlchemy==2.0.29
sseclient==0.0.27
stack-data==0.6.3
starlette==0.37.2
sympy @ file:///home/conda/feedstock_root/build_artifacts/sympy_1684180539862/work
tbb==2021.12.0
tenacity==8.2.3
terminado==0.18.1
tiktoken==0.6.0
tinycss2==1.3.0
tokenizers==0.19.1
tomli==2.0.1
torch==2.3.0
torchaudio==2.3.0
torchvision==0.18.0
tornado==6.4
tqdm==4.66.2
traitlets==5.14.3
transformers @ git+https://github.com/huggingface/transformers@e0c3cee17085914bbe505c159beeb8ae39bc37dd
types-python-dateutil==2.9.0.20240316
typing-inspect==0.9.0
typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1712329955671/work
tzdata==2024.1
uri-template==1.3.0
urllib3 @ file:///home/conda/feedstock_root/build_artifacts/urllib3_1708239446578/work
uvicorn==0.29.0
wcwidth==0.2.13
webcolors==1.13
webencodings==0.5.1
websocket-client==1.8.0
websockets==12.0
widgetsnbextension==4.0.10
win-inet-pton @ file:///D:/bld/win_inet_pton_1667051142467/work
xxhash==3.4.1
yarl==1.9.4

GurvanR commented 1 month ago

Hello, I have the same issue of wrong token generation. I'm using the vLLM-based serve, and to that end I installed it with pip install outlines[serve].

It works well with OPT models, but here is what I get with other models (in parentheses are the tokens generated; the expected tokens are 'A', 'B', 'C' or 'D').

Note that all these models work well with vLLM itself.

So my question is: how can I apply the trick of installing main via the git+ command to the vLLM serve setup?
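
(One untested possibility: pip supports PEP 508 direct references with extras, so installing the main branch together with the serve extras might look like the lines below.)

pip uninstall outlines
pip install "outlines[serve] @ git+https://github.com/outlines-dev/outlines.git@main"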

Thank you all!

wjn0 commented 1 month ago

I'm still seeing this on commit 7863f8e8bbaeb71c9d2434636a2d63bfe6dd7d39 with hf-internal-testing/tiny-random-LlamaForCausalLM.

Possibly relevant warning (though the issue is not resolved by manually setting pad_token_id to e.g. eos_token_id):

UserWarning: `pad_token_id` should be positive but got -1. This will cause errors when batch generating, if there is padding. Please set `pas_token_id` explicitly by `model.generation_config.pad_token_id=PAD_TOKEN_ID` to avoid errors in generation, and ensure your `input_ids` input does not have negative values.
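
For reference, "manually setting pad_token_id" was along these lines (a sketch; it assumes the outlines wrapper exposes the underlying HF model as .model, as in the script below):

# Point pad_token_id at eos_token_id, as the warning suggests. This silences
# the warning but did not fix the out-of-choice generations.
model.model.generation_config.pad_token_id = model.model.generation_config.eos_token_id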

Repro script (modified from above):

import torch
import outlines
from outlines import samplers

rng = torch.Generator()
rng.manual_seed(1337)

# Generated outputs match the provided choices
#model_path = "distilbert/distilgpt2"

# Generated outputs are not in the set of choices
# Also get a warning: "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained."
#model_path = "meta-llama/Meta-Llama-3-8B"
#model_path = "EleutherAI/pythia-1b-deduped"
model_path = "hf-internal-testing/tiny-random-LlamaForCausalLM"

model = outlines.models.transformers(model_path)
# model.model.half()

sampler = samplers.multinomial(1)
generator = outlines.generate.choice(model, [",", "\n"], sampler)

prompt = """Give me an integer ranging from -1 to 1 inclusive..."""

for i in range(10):
    answer = generator(prompt, rng=rng)
    print(answer)

Can further confirm that the above test script, unmodified, works fine with v0.0.39 (regardless of whether pad_token_id is manually set or not).
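
For anyone else bisecting, the last known-good release can be pinned with:

pip install outlines==0.0.39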

rlouf commented 1 month ago

Pinging @lapp0

lapp0 commented 1 month ago

I cannot reproduce; here's what I get:

-1
0
0
1
-1
-1
1
1
1
-1

Could you please check what output you get for just outlines.generate.text?

Code

import torch
import outlines
from outlines import samplers

rng = torch.Generator(device="cuda")
rng.manual_seed(1337)

model_path = "EleutherAI/pythia-1b-deduped"

model = outlines.models.transformers(model_path, device="cuda")
model.model.half()

sampler = samplers.multinomial(1)
generator = outlines.generate.text(model)

prompt = """Some numbers: -1, 0, 1, -1, 0, 1, -1, 0, 1,"""

answer = generator(prompt, rng=rng, max_tokens=30)
print(answer)

My outlines.generate.text() Output:

 -1. But we have enough energy sets to cover the moon!" He trotted down the long hallway, his footsteps echoing.
 Est
1

br3no commented 1 month ago

I strongly believe this is an issue with the state-machine cache that was fixed with this PR: https://github.com/outlines-dev/outlines/pull/911
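
If that's the cause, clearing the on-disk cache before rebuilding the generator should work around it in the meantime. A minimal sketch, assuming the clear_cache helper in outlines.caching is present in the installed version:

import outlines.caching

# Drop any cached compiled FSMs so the choice generator is rebuilt from scratch.
outlines.caching.clear_cache()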

@brandonwillard, what do you think?

aaronsnoswell commented 1 month ago

Would it help for me to pull PR #911 and test at my end?


brandonwillard commented 1 month ago

Would it help for me to pull PR #911 and test at my end?

It's already merged into main, so you can check that out and try it.