outlines-dev / outlines

Structured Text Generation
https://outlines-dev.github.io/outlines/
Apache License 2.0

Vectorized calls to OpenAI models are failing #746

Open · gserapio opened this issue 4 months ago

gserapio commented 4 months ago

Describe the issue as clearly as possible:

One of the nice features of using Outlines with Hugging Face (transformers) models is the ability to vectorize calls: a whole list of prompts can be passed in a single call instead of calling the model serially.

Passing a list of prompts to an OpenAI model via outlines.generate.choice() results in a TypeError. The responses from the OpenAI API seem to be converted into a NumPy array before being encoded by tiktoken, which then raises the error. Perhaps that is the cause?

The issue doesn't occur when a single prompt is passed to OpenAI as a string.
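A minimal sketch of the suspected failure mode, assuming the error comes from a regular-expression search over the text before tokenization (as the `_special_token_regex(disallowed_special).search(text)` line in the traceback below suggests). It uses the standard-library `re` module and a made-up special-token pattern as stand-ins for tiktoken's internals; `check_text` and `SPECIAL_TOKEN_PATTERN` are hypothetical names. A plain str is searched normally, but a NumPy array of strings exposes the buffer protocol, so the regex engine treats it as a bytes-like object and rejects the str pattern:

```python
import re

import numpy as np

# Hypothetical stand-in for the special-token check that runs before encoding.
SPECIAL_TOKEN_PATTERN = re.compile(r"<\|endoftext\|>")


def check_text(text):
    """Search `text` for a special token, as the encoder does before tokenizing."""
    return SPECIAL_TOKEN_PATTERN.search(text)


# A single prompt's response is a plain str: the search simply finds nothing.
print(check_text("Blue"))  # None

# A batch of responses stored in a NumPy array is treated as bytes-like,
# so searching it with a str pattern raises TypeError.
responses = np.array(["Blue", "Red", "Blue"])
try:
    check_text(responses)
except TypeError as err:
    print(f"TypeError: {err}")
```

This would explain why the single-string call works and the list call does not: only the batched path hands an array, rather than a str, to the encoder.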

My reprex uses the latest version of Outlines, 0.0.36, but I've also encountered this issue in previous versions.

Steps/code to reproduce the bug:

import outlines
import os

# prompts, choices
prompts = [
    "What is the closest color to Indigo? ",
    "What is the closest color to red? ",
    "What is the closest color to green? "
]
choices = ["Blue", "Red", "Green"]

# create model instance
model = outlines.models.openai(
    "gpt-3.5-turbo",
    api_key=os.getenv("OPENAI_API_KEY")
)

# set sampler
sampler = outlines.samplers.multinomial()

generator = outlines.generate.choice(model, choices, sampler=sampler)

# serial call works
model_answers = generator("What is the closest color to Indigo? ")
print(model_answers)

# vectorized call results in TypeError
model_answers = generator(prompts)
print(model_answers)

Expected result:

Blue
['Blue', 'Red', 'Green']

Error message:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[2], line 2
      1 # vectorized call results in TypeError
----> 2 model_answers = generator(prompts)
      3 print(model_answers)

File ~/opt/anaconda3/envs/myenv/lib/python3.10/site-packages/outlines/generate/choice.py:34, in choice_openai.<locals>.generate_choice(prompt, max_tokens)
     33 def generate_choice(prompt: str, max_tokens: int = 1):
---> 34     return model.generate_choice(prompt, choices, max_tokens)

File ~/opt/anaconda3/envs/myenv/lib/python3.10/site-packages/outlines/models/openai.py:220, in OpenAI.generate_choice(self, prompt, choices, max_tokens, system_prompt)
    217 self.prompt_tokens += prompt_tokens
    218 self.completion_tokens += completion_tokens
--> 220 encoded_response = self.tokenizer.encode(response)
    222 if encoded_response in encoded_choices_left:
    223     decoded.append(response)

File ~/opt/anaconda3/envs/myenv/lib/python3.10/site-packages/tiktoken/core.py:116, in Encoding.encode(self, text, allowed_special, disallowed_special)
    114     if not isinstance(disallowed_special, frozenset):
    115         disallowed_special = frozenset(disallowed_special)
--> 116     if match := _special_token_regex(disallowed_special).search(text):
    117         raise_disallowed_special_token(match.group())
    119 # https://github.com/PyO3/pyo3/pull/3632

TypeError: cannot use a string pattern on a bytes-like object

Outlines/Python version information:


``` 0.0.36 Python 3.10.13 (main, Sep 11 2023, 08:16:02) [Clang 14.0.6 ] accelerate==0.26.1 aiohttp==3.9.1 aiosignal==1.3.1 annotated-types==0.6.0 anyio @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_a17a7759g2/croot/anyio_1706220182417/work appnope @ file:///Users/ktietz/ci_310/appnope_1643965056645/work argon2-cffi @ file:///opt/conda/conda-bld/argon2-cffi_1645000214183/work argon2-cffi-bindings @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/croot-wbf5edig/argon2-cffi-bindings_1644845754377/work arrow==1.3.0 asttokens @ file:///opt/conda/conda-bld/asttokens_1646925590279/work astunparse==1.6.3 async-lru @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_02efro5ps8/croot/async-lru_1699554529181/work async-timeout==4.0.3 attrs @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_224434dqzl/croot/attrs_1695717839274/work Babel @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_00k1rl2pus/croot/babel_1671781944131/work beautifulsoup4 @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_fa78jvo_0n/croot/beautifulsoup4-split_1681493044306/work bleach @ file:///opt/conda/conda-bld/bleach_1641577558959/work Bottleneck @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_2bxpizxa3c/croot/bottleneck_1707864819812/work Brotli @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_38mvgltu8c/croots/recipe/brotli-split_1659616064542/work cachetools==5.3.2 certifi @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_3bzbkiv4h_/croot/certifi_1707229182618/work/certifi cffi @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_b4nang6w_y/croot/cffi_1700254307954/work chardet @ file:///Users/ktietz/ci_310/chardet_1643965356347/work charset-normalizer @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work click==8.1.7 cloudpickle==3.0.0 comm @ 
file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_b19kb7be6_/croot/comm_1671231124262/work contourpy @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_041uwyxdzo/croot/contourpy_1700583585236/work cryptography @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_b1p0q5vizk/croot/cryptography_1702070293829/work ctransformers==0.2.27 cycler @ file:///tmp/build/80754af9/cycler_1637851556182/work dacite==1.8.1 datasets==2.16.1 debugpy @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_563_nwtkoc/croot/debugpy_1690905063850/work decorator @ file:///opt/conda/conda-bld/decorator_1643638310831/work defusedxml @ file:///tmp/build/80754af9/defusedxml_1615228127516/work dill==0.3.7 diskcache==5.6.3 distro==1.9.0 evaluate==0.4.1 exceptiongroup @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_b2258scr33/croot/exceptiongroup_1706031391815/work executing @ file:///opt/conda/conda-bld/executing_1646925071911/work fairscale==0.4.13 fastapi==0.109.0 fastjsonschema @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_43a0jaiddu/croots/recipe/python-fastjsonschema_1661368628129/work filelock==3.13.1 fire==0.5.0 fonttools==4.25.0 fqdn==1.5.1 frozenlist==1.4.1 fsspec==2023.10.0 gguf==0.6.0 gptcache==0.1.43 guidance==0.1.10 h11==0.14.0 httpcore==1.0.2 httpx==0.26.0 huggingface-hub==0.20.2 idna @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_771olrhiqw/croot/idna_1666125579282/work interegular==0.3.3 ipykernel @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_98tee4lcge/croot/ipykernel_1691121640975/work ipython @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_a1tmxj9b4u/croot/ipython_1704833016119/work ipywidgets==8.1.1 isoduration==20.11.0 jedi @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/croot-f1t6hma6/jedi_1644315882177/work Jinja2 @ 
file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_9fjgzv9ant/croot/jinja2_1666908141308/work joblib==1.3.2 json5 @ file:///tmp/build/80754af9/json5_1624432770122/work jsonpointer==2.4 jsonschema @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_27o3go8sqa/croot/jsonschema_1699041627313/work jsonschema-specifications @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_d38pclgu95/croot/jsonschema-specifications_1699032390832/work jupyter-events @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_14ldd9s4d0/croot/jupyter_events_1699282481406/work jupyter-lsp @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_ae9br5v37x/croot/jupyter-lsp-meta_1699978259353/work jupyter-resource-usage==1.0.1 jupyter_client @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_58w2siozyz/croot/jupyter_client_1699455907045/work jupyter_core @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_782yoyc_98/croot/jupyter_core_1698937318631/work jupyter_server @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_281pz9vly5/croot/jupyter_server_1699466465530/work jupyter_server_terminals @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_e7ryd60iuw/croot/jupyter_server_terminals_1686870731283/work jupyterlab @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_e2r14w4wga/croot/jupyterlab_1706802597734/work jupyterlab-pygments @ file:///tmp/build/80754af9/jupyterlab_pygments_1601490720602/work jupyterlab-widgets==3.0.9 jupyterlab_server @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_315a64u22w/croot/jupyterlab_server_1699555438434/work kiwisolver @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_93o8te804v/croot/kiwisolver_1672387163224/work lark==1.1.9 -e git+https://github.com/facebookresearch/llama.git@ef351e9cd9496c579bf9f2bb036ef11bdc5ca3d2#egg=llama 
llama_cpp_python==0.2.53 llvmlite==0.42.0 lmql==0.7.3 MarkupSafe @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_a84ni4pci8/croot/markupsafe_1704206002077/work matplotlib @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_21m9ylm_7k/croot/matplotlib-suite_1698692123710/work matplotlib-inline @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_f6fdc0hldi/croots/recipe/matplotlib-inline_1662014472341/work mistune @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_17ya6k1sbs/croots/recipe/mistune_1661496228719/work mpmath==1.3.0 msal==1.26.0 multidict==6.0.4 multiprocess==0.70.15 munkres==1.1.4 nbclient @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_626hpwnurm/croot/nbclient_1698934218848/work nbconvert @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_caxv2emy33/croot/nbconvert_1699022756174/work nbformat @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_cbnf5nccgk/croot/nbformat_1694616744196/work nest-asyncio @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_6b_e0dr4lw/croot/nest-asyncio_1672387130036/work networkx==3.2.1 notebook @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_d3ves7gv_b/croot/notebook_1700582112788/work notebook_shim @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_d6_ze10f45/croot/notebook-shim_1699455897525/work numba==0.59.0 numexpr @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_45yefq0kt6/croot/numexpr_1696515289183/work numpy==1.24.4 openai==1.13.3 ordered-set==4.1.0 outlines==0.0.36 overrides @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_70s80guh9g/croot/overrides_1699371144462/work packaging @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_6dm6d4jd_t/croot/packaging_1693575176524/work pandas==2.1.4 pandocfilters @ 
file:///opt/conda/conda-bld/pandocfilters_1643405455980/work parso @ file:///opt/conda/conda-bld/parso_1641458642106/work pexpect @ file:///tmp/build/80754af9/pexpect_1605563209008/work pillow @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_e02b4k5qik/croot/pillow_1707233036487/work platformdirs @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_a8u4fy8k9o/croot/platformdirs_1692205661656/work prometheus-client @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_25sgeyk0j5/croots/recipe/prometheus_client_1659455103277/work prompt-toolkit @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_c63v4kqjzr/croot/prompt-toolkit_1704404354115/work protobuf==4.25.2 psutil @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_1310b568-21f4-4cb0-b0e3-2f3d31e39728k9coaga5/croots/recipe/psutil_1656431280844/work ptyprocess @ file:///tmp/build/80754af9/ptyprocess_1609355006118/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl pure-eval @ file:///opt/conda/conda-bld/pure_eval_1646925070566/work py-cpuinfo==9.0.0 pyarrow==14.0.2 pyarrow-hotfix==0.6 pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work pydantic==2.5.3 pydantic-settings==2.1.0 pydantic_core==2.14.6 pydot==2.0.0 pyformlang==1.0.4 Pygments @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_29bs9f_dh9/croot/pygments_1684279974747/work PyJWT==2.8.0 pyOpenSSL @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_b8whqav6qm/croot/pyopenssl_1690223428943/work pyparsing @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_3b_3vxnd07/croots/recipe/pyparsing_1661452540919/work PySocks @ file:///Users/ktietz/ci_310/pysocks_1643961536721/work python-dateutil @ file:///tmp/build/80754af9/python-dateutil_1626374649649/work python-dotenv==1.0.0 python-json-logger @ 
file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_c3baq2ko4j/croot/python-json-logger_1683823815343/work pytictoc==1.5.3 pytz @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_6btwyyj8a1/croot/pytz_1695131592184/work PyYAML @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_a8_sdgulmz/croot/pyyaml_1698096054705/work pyzmq @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_23n9bfwjq5/croot/pyzmq_1686601381911/work referencing @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_5cz64gsx70/croot/referencing_1699012046031/work regex==2023.12.25 requests @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_54zi68h2nb/croot/requests_1690400233316/work responses==0.18.0 rfc3339-validator @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_76ae5cu30h/croot/rfc3339-validator_1683077051957/work rfc3986-validator @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_d0l5zd97kt/croot/rfc3986-validator_1683058998431/work rpds-py @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_f8jkozoefm/croot/rpds-py_1698945944860/work rpy2==3.5.15 safetensors==0.4.1 scipy==1.12.0 seaborn @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_2ci8qzbdyk/croot/seaborn_1673479197351/work Send2Trash @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_5b31f0zzlv/croot/send2trash_1699371144121/work sentencepiece==0.1.99 six @ file:///tmp/build/80754af9/six_1644875935023/work sniffio @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_1573pknjrg/croot/sniffio_1705431298885/work soupsieve @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_9798xzs_03/croot/soupsieve_1696347567192/work sse-starlette==1.8.2 stack-data @ file:///opt/conda/conda-bld/stack_data_1646927590127/work starlette==0.35.1 starlette-context==0.3.6 sympy==1.12 termcolor==2.4.0 terminado @ 
file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_fcfvyc0an2/croot/terminado_1671751835701/work tiktoken==0.6.0 tinycss2 @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_fcw5_i306t/croot/tinycss2_1668168825117/work tokenizers==0.15.0 tomli @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_d0e5ffbf-5cf1-45be-8693-c5dff8108a2awhthtjlq/croots/recipe/tomli_1657175508477/work torch==2.1.2 tornado @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_3a5nrn2jeh/croot/tornado_1696936974091/work tqdm==4.66.1 traitlets @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_6301rd5qbe/croot/traitlets_1671143894285/work transformers==4.36.2 types-python-dateutil==2.8.19.20240106 typing_extensions @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_1fdywrbp_3/croot/typing_extensions_1690297474455/work tzdata @ file:///croot/python-tzdata_1690578112552/work tzlocal==5.2 uri-template==1.3.0 urllib3 @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_068obtb882/croot/urllib3_1698257558009/work uvicorn==0.25.0 wcwidth @ file:///Users/ktietz/demo/mc3/conda-bld/wcwidth_1629357192024/work webcolors==1.13 webencodings==0.5.1 websocket-client @ file:///Users/ktietz/ci_310/websocket-client_1643972661291/work wget==3.2 widgetsnbextension==4.0.9 xxhash==3.4.1 yarl==1.9.4 ```

Context for the issue:

The current workaround is to send prompts to the generator one at a time in a for loop, but it would be nice to get vectorized calls working so that OpenAI and transformers model objects can be used interchangeably in bulk inference pipelines.
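The serial workaround can be wrapped in a small helper so that calling code does not need to know whether a generator supports batched prompts. This is a sketch: `call_serially` is a hypothetical name, and it assumes the generator accepts a single str prompt, as the working serial call in the reproduction does.

```python
from typing import Callable, List, Union


def call_serially(
    generator: Callable[[str], str], prompts: Union[str, List[str]]
) -> Union[str, List[str]]:
    """Call `generator` once per prompt instead of passing the whole list at once."""
    if isinstance(prompts, str):
        # A single prompt already works, so pass it through unchanged.
        return generator(prompts)
    # One request per prompt: slower than a true batch, but avoids the TypeError.
    return [generator(prompt) for prompt in prompts]
```

With the generator and prompts from the reproduction, call_serially(generator, prompts) should return one answer per prompt, in order.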

gserapio commented 4 months ago

CC: @stephenfitz

rlouf commented 4 months ago

Thank you for opening an issue! Is there any way you could run this in a debugger and tell me what is passed to self.tokenizer.encode? I suspect a NumPy array, but I would like to be sure.
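Besides stepping through with pdb, one non-interactive way to answer this is to wrap the tokenizer's encode method and record what it receives. A sketch, assuming model.tokenizer is the object whose encode method appears in the traceback (the OpenAI model calls self.tokenizer.encode); `spy_on` is a hypothetical helper, and the wrapper only records arguments before delegating to the original method.

```python
from typing import Any, List, Tuple


def spy_on(obj: Any, method_name: str) -> List[Tuple[Any, ...]]:
    """Replace obj.<method_name> with a wrapper that records the positional args of each call."""
    calls: List[Tuple[Any, ...]] = []
    original = getattr(obj, method_name)  # bound method of the original object

    def wrapper(*args, **kwargs):
        calls.append(args)  # recorded before delegating, so calls that raise are kept too
        return original(*args, **kwargs)

    # Instance attribute shadows the class method for this object only.
    setattr(obj, method_name, wrapper)
    return calls
```

With the repro's objects, calling spy_on(model.tokenizer, "encode") before generator(prompts) and inspecting type(calls[-1][0]) after catching the TypeError should show what reached the encoder. Earlier entries may be the choices themselves, since they are encoded too.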

gserapio commented 4 months ago

Thanks for all of your work on this resource! Yes, it looks like response is a NumPy array when it is passed to self.tokenizer.encode. Here's the debugger output:

> ~/opt/anaconda3/envs/myenv/lib/python3.10/site-packages/outlines/models/openai.py(220)generate_choice()->None
-> encoded_response = self.tokenizer.encode(response)
(Pdb) l
215                     prompt, system_prompt, self.client, config
216                 )
217                 self.prompt_tokens += prompt_tokens
218                 self.completion_tokens += completion_tokens
219     
220  ->             encoded_response = self.tokenizer.encode(response)
221     
222                 if encoded_response in encoded_choices_left:
223                     decoded.append(response)
224                     break
225                 else:
(Pdb) p response
array(['Blue', 'Red', 'Blue'], dtype='<U4')
(Pdb) whatis response
<class 'numpy.ndarray'>
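The debugger output suggests that, for a list of prompts, response holds one answer per prompt in a NumPy array, and the whole array is handed to the encoder. One possible direction for a fix (a sketch, not the maintainers' patch) is to iterate over the batch and encode each answer as a plain str. The function below is a deliberately simplified, hypothetical version of the matching step that maps each full response to the choice with identical token IDs; the actual generate_choice logic in outlines/models/openai.py is more involved. `toy_encode` is a hypothetical stand-in for the tiktoken tokenizer.

```python
from typing import Callable, Dict, List, Optional, Sequence, Tuple

import numpy as np


def match_choices(
    responses,
    choices: Sequence[str],
    encode: Callable[[str], List[int]],
) -> List[Optional[str]]:
    """Map each response in a (possibly batched) result to the choice with identical tokens."""
    encoded_choices: Dict[Tuple[int, ...], str] = {
        tuple(encode(choice)): choice for choice in choices
    }
    matched: List[Optional[str]] = []
    # np.atleast_1d turns a single str into a 1-element array, so one code path
    # handles both the single-prompt and the list-of-prompts case.
    for response in np.atleast_1d(responses):
        text = str(response)  # encode one plain str at a time, never the whole array
        matched.append(encoded_choices.get(tuple(encode(text))))  # None if no match
    return matched


def toy_encode(text: str) -> List[int]:
    """Hypothetical stand-in for a tokenizer: one 'token' per character."""
    return [ord(ch) for ch in text]


print(match_choices(np.array(["Blue", "Red", "Blue"]), ["Blue", "Red", "Green"], toy_encode))
# ['Blue', 'Red', 'Blue']
```

Converting each element with str() is defensive: np.str_ already subclasses str, so the essential change is looping over the batch instead of encoding the array itself.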