dottxt-ai / outlines

Structured Text Generation
https://dottxt-ai.github.io/outlines/
Apache License 2.0

vllm max length breaking structured generation #1173

Closed. cpfiffer closed this issue 1 month ago.

cpfiffer commented 1 month ago

Describe the issue as clearly as possible:

Using vLLM, I ran into an issue where Outlines seems to terminate the output early:

{ "people": [{"name": "Emily Carter",

Suggestions/tips welcome!

Steps/code to reproduce the bug:

# Import the Outlines library
import outlines

model = outlines.models.vllm(
    "Sreenington/Phi-3-mini-4k-instruct-AWQ", 
    quantization="awq",
    max_model_len=1024
)

@outlines.prompt
def userinfo(data):
    """
    <|system|>
    Your role is to review the provided text
    and extract the names, addresses, and phone numbers.

    Please provide a JSON object with the following schema:

    {
        "people": [
            {
                "name": "Cameron",
                "address": "123 Main St, Anytown, USA",
                "phone": "555-1234"
            },
            {
                "name": "Dan",
                "address": "456 Main St, Anytown, USA",
                "phone": "555-5678"
            }
        ]
    }

    <|end|>
    <|user|>
    Here is the text to process: {{ data }}
    <|end|>
    """

data = """
Hey there. My name is Ada Lovelace. I've been having trouble
with a 40-60 feral hogs at my home 1090 Lovelace Manor, Computertown OH.

Could you send someone to help me? You can reach me at
(123) 456-7890.

All the best,
Ada
"""

from pydantic import BaseModel
from typing import List

class Person(BaseModel):
    name: str
    address: str
    phone: str

class UserInfo(BaseModel):
    people: List[Person]

# Run the model
structured_generator = outlines.generate.json(model, UserInfo)
result = structured_generator(
    userinfo(data), 
    seed=31, 
)
print(result)

Expected result:

{
    "people": [
        {
            "name": "Ada Lovelace",
            "address": "1090 Lovelace Manor, Computertown OH",
            "phone": "(123) 456-7890"
        }
    ]
}

Error message:

---------------------------------------------------------------------------
JSONDecodeError                           Traceback (most recent call last)
File ~/dottxt/blog/quarto/intro/.venv/lib/python3.12/site-packages/pydantic/main.py:1187, in BaseModel.parse_raw(cls, b, content_type, encoding, proto, allow_pickle)
   1186 try:
-> 1187     obj = parse.load_str_bytes(
   1188         b,
   1189         proto=proto,
   1190         content_type=content_type,
   1191         encoding=encoding,
   1192         allow_pickle=allow_pickle,
   1193     )
   1194 except (ValueError, TypeError) as exc:

File ~/dottxt/blog/quarto/intro/.venv/lib/python3.12/site-packages/pydantic/deprecated/parse.py:49, in load_str_bytes(b, content_type, encoding, proto, allow_pickle, json_loads)
     48         b = b.decode(encoding)
---> 49     return json_loads(b)  # type: ignore
     50 elif proto == Protocol.pickle:

File /usr/lib/python3.12/json/__init__.py:346, in loads(s, cls, object_hook, parse_float, parse_int, parse_constant, object_pairs_hook, **kw)
    343 if (cls is None and object_hook is None and
    344         parse_int is None and parse_float is None and
    345         parse_constant is None and object_pairs_hook is None and not kw):
--> 346     return _default_decoder.decode(s)
    347 if cls is None:

File /usr/lib/python3.12/json/decoder.py:337, in JSONDecoder.decode(self, s, _w)
    333 """Return the Python representation of ``s`` (a ``str`` instance
    334 containing a JSON document).
    335 
    336 """
--> 337 obj, end = self.raw_decode(s, idx=_w(s, 0).end())
    338 end = _w(s, end).end()

File /usr/lib/python3.12/json/decoder.py:353, in JSONDecoder.raw_decode(self, s, idx)
    352 try:
--> 353     obj, end = self.scan_once(s, idx)
    354 except StopIteration as err:

JSONDecodeError: Expecting property name enclosed in double quotes: line 1 column 38 (char 37)

During handling of the above exception, another exception occurred:

ValidationError                           Traceback (most recent call last)
Cell In[6], line 3
      1 # Run the model
      2 structured_generator = outlines.generate.json(model, UserInfo)
----> 3 result = structured_generator(
      4     userinfo(data), 
      5     seed=31, 
      6 )
      7 print(result)

File ~/dottxt/blog/quarto/intro/.venv/lib/python3.12/site-packages/outlines/generate/api.py:511, in SequenceGeneratorAdapter.__call__(self, prompts, max_tokens, stop_at, seed, **model_specific_params)
    499 generation_params = self.prepare_generation_parameters(
    500     max_tokens, stop_at, seed
    501 )
    503 completions = self.model.generate(
    504     prompts,
    505     generation_params,
   (...)
    508     **model_specific_params,
    509 )
--> 511 return format(completions)

File ~/dottxt/blog/quarto/intro/.venv/lib/python3.12/site-packages/outlines/generate/api.py:497, in SequenceGeneratorAdapter.__call__.<locals>.format(sequences)
    495     return [format(sequence) for sequence in sequences]
    496 else:
--> 497     return self.format_sequence(sequences)

File ~/dottxt/blog/quarto/intro/.venv/lib/python3.12/site-packages/outlines/generate/json.py:50, in json.<locals>.<lambda>(x)
     48     regex_str = build_regex_from_schema(schema, whitespace_pattern)
     49     generator = regex(model, regex_str, sampler)
---> 50     generator.format_sequence = lambda x: schema_object.parse_raw(x)
     51 elif callable(schema_object):
     52     schema = pyjson.dumps(get_schema_from_signature(schema_object))

File ~/dottxt/blog/quarto/intro/.venv/lib/python3.12/site-packages/pydantic/main.py:1214, in BaseModel.parse_raw(cls, b, content_type, encoding, proto, allow_pickle)
   1207     # ctx is missing here, but since we've added `input` to the error, we're not pretending it's the same
   1208     error: pydantic_core.InitErrorDetails = {
   1209         # The type: ignore on the next line is to ignore the requirement of LiteralString
   1210         'type': pydantic_core.PydanticCustomError(type_str, str(exc)),  # type: ignore
   1211         'loc': ('__root__',),
   1212         'input': b,
   1213     }
-> 1214     raise pydantic_core.ValidationError.from_exception_data(cls.__name__, [error])
   1215 return cls.model_validate(obj)

ValidationError: 1 validation error for UserInfo
__root__
  Expecting property name enclosed in double quotes: line 1 column 38 (char 37) [type=value_error.jsondecode, input_value='{ "people": [{"name": "Emily Carter",', input_type=str]

Outlines/Python version information:


Outlines 0.0.46
Python 3.12.4 (main, Jun  8 2024, 18:29:57) [GCC 11.4.0]

(.venv) λ ~/dottxt/blog/quarto/intro/ main* uv pip freeze
accelerate==0.34.2
aiohappyeyeballs==2.4.0
aiohttp==3.10.5
aiosignal==1.3.1
annotated-types==0.7.0
anyio==4.6.0
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asttokens==2.4.1
async-lru==2.0.4
attrs==24.2.0
babel==2.16.0
beautifulsoup4==4.12.3
bleach==6.1.0
certifi==2024.8.30
cffi==1.17.1
charset-normalizer==3.3.2
click==8.1.7
cloudpickle==3.0.0
comm==0.2.2
datasets==3.0.0
debugpy==1.8.5
decorator==5.1.1
defusedxml==0.7.1
dill==0.3.8
diskcache==5.6.3
distro==1.9.0
einops==0.8.0
executing==2.1.0
fastapi==0.115.0
fastjsonschema==2.20.0
filelock==3.16.1
fqdn==1.5.1
frozenlist==1.4.1
fsspec==2024.6.1
gguf==0.9.1
h11==0.14.0
httpcore==1.0.5
httptools==0.6.1
httpx==0.27.2
huggingface-hub==0.25.1
idna==3.10
importlib-metadata==8.5.0
interegular==0.3.3
ipykernel==6.29.5
ipython==8.27.0
ipywidgets==8.1.5
isoduration==20.11.0
jedi==0.19.1
jinja2==3.1.4
jiter==0.5.0
json5==0.9.25
jsonpointer==3.0.0
jsonschema==4.23.0
jsonschema-specifications==2023.12.1
jupyter==1.1.1
jupyter-client==8.6.3
jupyter-console==6.6.3
jupyter-core==5.7.2
jupyter-events==0.10.0
jupyter-lsp==2.2.5
jupyter-server==2.14.2
jupyter-server-terminals==0.5.3
jupyterlab==4.2.5
jupyterlab-pygments==0.3.0
jupyterlab-server==2.27.3
jupyterlab-widgets==3.0.13
lark==1.2.2
llama-cpp-python==0.2.90
llvmlite==0.43.0
lm-format-enforcer==0.10.6
markupsafe==2.1.5
matplotlib-inline==0.1.7
mistral-common==1.4.3
mistune==3.0.2
mpmath==1.3.0
msgpack==1.1.0
msgspec==0.18.6
multidict==6.1.0
multiprocess==0.70.16
nbclient==0.10.0
nbconvert==7.16.4
nbformat==5.10.4
nest-asyncio==1.6.0
networkx==3.3
notebook==7.2.2
notebook-shim==0.2.4
numba==0.60.0
numpy==1.26.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-ml-py==12.560.30
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.6.68
nvidia-nvtx-cu12==12.1.105
openai==1.47.1
outlines==0.0.46
overrides==7.7.0
packaging==24.1
pandas==2.2.3
pandocfilters==1.5.1
parso==0.8.4
partial-json-parser==0.2.1.1.post4
pexpect==4.9.0
pillow==10.4.0
platformdirs==4.3.6
prometheus-client==0.21.0
prometheus-fastapi-instrumentator==7.0.0
prompt-toolkit==3.0.47
protobuf==5.28.2
psutil==6.0.0
ptyprocess==0.7.0
pure-eval==0.2.3
py-cpuinfo==9.0.0
pyairports==2.1.1
pyarrow==17.0.0
pycountry==24.6.1
pycparser==2.22
pydantic==2.9.2
pydantic-core==2.23.4
pygments==2.18.0
python-dateutil==2.9.0.post0
python-dotenv==1.0.1
python-json-logger==2.0.7
pytz==2024.2
pyyaml==6.0.2
pyzmq==26.2.0
ray==2.36.0
referencing==0.35.1
regex==2024.9.11
requests==2.32.3
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rpds-py==0.20.0
safetensors==0.4.5
send2trash==1.8.3
sentencepiece==0.2.0
setuptools==75.1.0
six==1.16.0
sniffio==1.3.1
soupsieve==2.6
stack-data==0.6.3
starlette==0.38.6
sympy==1.13.3
terminado==0.18.1
tiktoken==0.7.0
tinycss2==1.3.0
tokenizers==0.19.1
torch==2.4.0
torchvision==0.19.0
tornado==6.4.1
tqdm==4.66.5
traitlets==5.14.3
transformers==4.44.2
triton==3.0.0
types-python-dateutil==2.9.0.20240906
typing-extensions==4.12.2
tzdata==2024.1
uri-template==1.3.0
urllib3==2.2.3
uvicorn==0.30.6
uvloop==0.20.0
vllm==0.6.1.post2
vllm-flash-attn==2.6.1
watchfiles==0.24.0
wcwidth==0.2.13
webcolors==24.8.0
webencodings==0.5.1
websocket-client==1.8.0
websockets==13.1
widgetsnbextension==4.0.13
xformers==0.0.27.post2
xxhash==3.5.0
yarl==1.11.1
zipp==3.20.2


lapp0 commented 1 month ago

Thanks for reporting! I did a quick investigation.

The trailing single quote isn't generated by the model; it's part of the error message's representation of the string.
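To confirm that the trailing quote comes from `repr()` and that the payload is simply cut off, here is a small stdlib-only check on the completion string copied from the traceback. It uses only Python's `json` module; nothing here calls Outlines or vLLM.

```python
import json

# The truncated completion from the ValidationError, as a plain Python string.
completion = '{ "people": [{"name": "Emily Carter",'

# repr() wraps the string in single quotes; that is where the trailing ' comes from.
print(repr(completion))

decode_error = None
try:
    json.loads(completion)
except json.JSONDecodeError as err:
    # The decoder hits the end of the string right after the comma,
    # where it expected the next property name.
    decode_error = err
    print(err.msg, "at char", err.pos, "column", err.colno)
```

The position reported is the end of the string, which is consistent with truncation rather than malformed output.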

Looking at the response metadata, its stop reason is "length" because SamplingParams.max_tokens == 16.

```
SamplingParams(n=1, best_of=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=1.0, top_p=1.0, top_k=-1, min_p=0.0, seed=31, use_beam_search=False, length_penalty=1.0, early_stopping=False, stop=[], stop_token_ids=[], include_stop_str_in_output=False, ignore_eos=False, max_tokens=16, min_tokens=0, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None)
```

In vLLM, the default max_tokens is 16: https://github.com/vllm-project/vllm/blob/main/vllm/sampling_params.py#L145
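A toy simulation of what a 16-token budget does to a JSON completion. The tokenizer is a crude regex split (an assumption, not Phi-3's real tokenizer), and `generate()` is a hypothetical stand-in for the engine: it reports finish reason `"length"` when the budget runs out and `"stop"` otherwise. Schema-constrained decoding restricts which tokens can be chosen, but it cannot finish the object once the budget is exhausted.

```python
import json
import re
from typing import Optional

# Hypothetical target completion: the expected result from the issue, serialized.
expected = {"people": [{"name": "Ada Lovelace",
                        "address": "1090 Lovelace Manor, Computertown OH",
                        "phone": "(123) 456-7890"}]}
target = json.dumps(expected)

# Toy tokenizer (assumption): words, whitespace runs, and single punctuation characters.
tokens = re.findall(r"\w+|\s+|[^\w\s]", target)


def generate(tokens, max_tokens: Optional[int]):
    """Emit tokens until done or until the max_tokens budget is spent."""
    if max_tokens is None or len(tokens) <= max_tokens:
        return "".join(tokens), "stop"
    return "".join(tokens[:max_tokens]), "length"


def is_valid_json(text):
    try:
        json.loads(text)
        return True
    except json.JSONDecodeError:
        return False


text, reason = generate(tokens, max_tokens=16)  # vLLM's default budget
print(reason, repr(text), is_valid_json(text))

text, reason = generate(tokens, max_tokens=None)  # no per-request cap
print(reason, json.loads(text) == expected)
```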

I'm wondering whether we should default to the tokenizer's maximum length when max_tokens is unspecified.
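The arithmetic behind that proposal, as a sketch: `default_max_tokens` is a hypothetical helper, and the 300-token prompt is an illustrative number, not a measurement from this issue. The 1024 matches the `max_model_len` used in the reproduction.

```python
from typing import Optional


def default_max_tokens(max_tokens: Optional[int], max_model_len: int, prompt_len: int) -> int:
    """Hypothetical helper: keep an explicit budget, otherwise use the room left in the context window."""
    if max_tokens is not None:
        return max_tokens
    return max(max_model_len - prompt_len, 0)


# max_model_len=1024 as in the repro; the prompt length is made up for illustration.
print(default_max_tokens(None, max_model_len=1024, prompt_len=300))
print(default_max_tokens(512, max_model_len=1024, prompt_len=300))
```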

cpfiffer commented 1 month ago

This fixed it! Thank you.

It's a weird default behavior. Is there an easy way to use the tokenizer's max length?

lapp0 commented 1 month ago

> It's a weird default behavior. Is there an easy way to use the tokenizer's max length?

vLLM prevents generation past model_max_length, so we can simply set max_tokens=None by default.

cpfiffer commented 1 month ago

It seems like it already defaults to max_tokens=None, which strikes me as odd.

https://github.com/dottxt-ai/outlines/blob/30531e58f8e15fd4f47bed5d5718209cbad4b3e0/outlines/generate/api.py#L490-L512

lapp0 commented 1 month ago

Right, but in the vLLM integration, sampling_params.max_tokens is left unchanged when max_tokens is None, so vLLM's default of 16 still applies.

https://github.com/dottxt-ai/outlines/blob/main/outlines/models/vllm.py#L96-L97
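The behavior described in the last few comments can be sketched with stand-ins. `SamplingParamsStub`, `apply_max_tokens_current` and `apply_max_tokens_proposed` are hypothetical names, not the actual vLLM or Outlines code; the stub only mirrors vLLM's default of `max_tokens=16`, and the "current" function mirrors the linked behavior as described in the thread.

```python
from dataclasses import dataclass, replace
from typing import Optional


@dataclass(frozen=True)
class SamplingParamsStub:
    # Hypothetical stand-in for vllm.SamplingParams; vLLM defaults max_tokens to 16.
    max_tokens: Optional[int] = 16


def apply_max_tokens_current(params: SamplingParamsStub, max_tokens: Optional[int]) -> SamplingParamsStub:
    # As described: only override when the caller passed a value,
    # so max_tokens=None silently keeps the engine default of 16.
    if max_tokens is not None:
        params = replace(params, max_tokens=max_tokens)
    return params


def apply_max_tokens_proposed(params: SamplingParamsStub, max_tokens: Optional[int]) -> SamplingParamsStub:
    # Proposed: always forward the value; None means "no per-request cap",
    # and vLLM then stops at the model's maximum length.
    return replace(params, max_tokens=max_tokens)


print(apply_max_tokens_current(SamplingParamsStub(), None))
print(apply_max_tokens_proposed(SamplingParamsStub(), None))
```

One caveat with the unconditional override: if a caller can also pass their own `SamplingParams`, forwarding `None` would clobber a `max_tokens` they set there, so a real patch would need to distinguish "unspecified" from "explicitly None".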