dottxt-ai / outlines

Structured Text Generation
https://dottxt-ai.github.io/outlines/
Apache License 2.0

Outlines slows down inference of AQLM models #748

Open remiconnesson opened 6 months ago

remiconnesson commented 6 months ago

Describe the issue as clearly as possible:

The reproduction file loads Mixtral quantized with AQLM (this runs on a T4 on Colab here: https://colab.research.google.com/drive/119c0tcRWKScoatfIYfKrEdEo8rHe_vE7).

When using outlines to force structure, generation is noticeably slower than when we don't.

The JSON example at the end never returns.

Steps/code to reproduce the bug:

"""experiment_JSON_outlines_AQLM_MIXTRAL.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/119c0tcRWKScoatfIYfKrEdEo8rHe_vE7
"""

# !pip install aqlm[gpu,cpu]
# !pip install git+https://github.com/huggingface/accelerate.git@main
# !pip install outlines
# !pip install datasets

import outlines

MODEL = "ISTA-DASLab/Mixtral-8x7B-Instruct-v0_1-AQLM-2Bit-1x16-hf"

# Guard against reloading the large checkpoint on notebook re-runs.
if "model" not in globals():
    model = outlines.models.transformers(
        model_name=MODEL,
        model_kwargs=dict(trust_remote_code=True, torch_dtype="auto", device_map="cuda"),
    )

prompt = "What is the IP address of the Google DNS servers? "

generator = outlines.generate.text(model)
unstructured = generator(prompt, max_tokens=30)
print(unstructured)

generator = outlines.generate.regex(
    model,
    r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)",
)
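
# Sanity check (not part of the original notebook): the pattern itself accepts
# a dotted-quad IP under Python's re module, so the slowdown is not a malformed regex.
import re

assert re.fullmatch(
    r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)",
    "8.8.8.8",
)
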
structured = generator(prompt, max_tokens=30)
print(structured)
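
# A minimal timing sketch (not part of the original notebook) to make the gap
# between the two generators concrete; numbers are illustrative, not measured.
import time

for label, gen in [
    ("unstructured", outlines.generate.text(model)),
    ("structured", generator),
]:
    t0 = time.perf_counter()
    gen(prompt, max_tokens=30)
    print(f"{label}: {time.perf_counter() - t0:.2f}s")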

prompt = "<s>result of 9 + 9 = 18</s><s>result of 1 + 2 = "
answer = outlines.generate.format(model, int)(prompt)
print(answer)
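
# outlines.generate.format constrains decoding so that the generated text
# parses as the given Python type (int above, float below).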

prompt = "sqrt(2)="
generator = outlines.generate.format(model, float)
answer = generator(prompt, max_tokens=10)
print(answer)


schema = '''{
    "title": "Character",
    "type": "object",
    "properties": {
        "name": {
            "title": "Name",
            "maxLength": 10,
            "type": "string"
        },
        "age": {
            "title": "Age",
            "type": "integer"
        },
        "armor": {"$ref": "#/definitions/Armor"},
        "weapon": {"$ref": "#/definitions/Weapon"},
        "strength": {
            "title": "Strength",
            "type": "integer"
        }
    },
    "required": ["name", "age", "armor", "weapon", "strength"],
    "definitions": {
        "Armor": {
            "title": "Armor",
            "description": "An enumeration.",
            "enum": ["leather", "chainmail", "plate"],
            "type": "string"
        },
        "Weapon": {
            "title": "Weapon",
            "description": "An enumeration.",
            "enum": ["sword", "axe", "mace", "spear", "bow", "crossbow"],
            "type": "string"
        }
    }
}'''

generator = outlines.generate.json(model, schema)
character = generator("Give me a character description")
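
# The same schema can equivalently be expressed as a Pydantic model, which
# outlines.generate.json also accepts; this sketch (not part of the original
# repro) mirrors the JSON schema above.
from enum import Enum

from pydantic import BaseModel, constr

class Armor(str, Enum):
    leather = "leather"
    chainmail = "chainmail"
    plate = "plate"

class Weapon(str, Enum):
    sword = "sword"
    axe = "axe"
    mace = "mace"
    spear = "spear"
    bow = "bow"
    crossbow = "crossbow"

class Character(BaseModel):
    name: constr(max_length=10)
    age: int
    armor: Armor
    weapon: Weapon
    strength: int

generator = outlines.generate.json(model, Character)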

Expected result:

I've read that

> "Outlines does not slow down inference, but you can incur a small compilation cost at the beginning"

so this is most likely a bug in this case.

The expected result would be roughly the same inference time with and without structured generation.

Error message:

No response

Outlines/Python version information:

0.0.36 Python 3.10.12 (main, Jun 11 2023, 05:26:28) [GCC 11.4.0] accelerate @ git+https://github.com/huggingface/accelerate.git@92d3240bb5ef50ca9aab5d2d38e09ce6bbfc30c0 aiohttp==3.9.3 aiosignal==1.3.1 annotated-types==0.6.0 anyio==4.0.0 aqlm==1.1.2 argon2-cffi==23.1.0 argon2-cffi-bindings==21.2.0 arrow==1.3.0 asttokens==2.4.1 async-lru==2.0.4 async-timeout==4.0.3 attrs==23.1.0 Babel==2.13.1 beautifulsoup4==4.12.2 bitsandbytes==0.43.0 bleach==6.1.0 blinker==1.4 certifi==2022.12.7 cffi==1.16.0 charset-normalizer==2.1.1 cloudpickle==3.0.0 comm==0.2.0 cryptography==3.4.8 datasets==2.18.0 dbus-python==1.2.18 debugpy==1.8.0 decorator==5.1.1 defusedxml==0.7.1 dill==0.3.8 diskcache==5.6.3 distro==1.7.0 docstring-parser==0.15 entrypoints==0.4 exceptiongroup==1.1.3 executing==2.0.1 fastjsonschema==2.18.1 filelock==3.9.0 fqdn==1.5.1 frozenlist==1.4.1 fsspec==2024.2.0 httplib2==0.20.2 huggingface-hub==0.21.4 idna==3.4 importlib-metadata==4.6.4 interegular==0.3.3 ipykernel==6.26.0 ipython==8.17.2 ipython-genutils==0.2.0 ipywidgets==8.1.1 isoduration==20.11.0 jedi==0.19.1 jeepney==0.7.1 Jinja2==3.1.2 joblib==1.3.2 json5==0.9.14 jsonpointer==2.4 jsonschema==4.19.2 jsonschema-specifications==2023.7.1 jupyter-archive==3.4.0 jupyter-contrib-core==0.4.2 jupyter-contrib-nbextensions==0.7.0 jupyter-events==0.9.0 jupyter-highlight-selected-word==0.2.0 jupyter-lsp==2.2.0 jupyter-nbextensions-configurator==0.6.3 jupyter_client==7.4.9 jupyter_core==5.5.0 jupyter_server==2.10.0 jupyter_server_terminals==0.4.4 jupyterlab==4.0.8 jupyterlab-pygments==0.2.2 jupyterlab-widgets==3.0.9 jupyterlab_server==2.25.0 keyring==23.5.0 lark==1.1.9 launchpadlib==1.10.16 lazr.restfulclient==0.14.4 lazr.uri==1.0.6 llvmlite==0.42.0 lm-format-enforcer==0.9.3 lxml==4.9.3 markdown-it-py==3.0.0 MarkupSafe==2.1.2 matplotlib-inline==0.1.6 mdurl==0.1.2 mistune==3.0.2 more-itertools==8.10.0 mpmath==1.3.0 multidict==6.0.5 multiprocess==0.70.16 nbclassic==1.0.0 nbclient==0.9.0 nbconvert==7.11.0 nbformat==5.9.2 nest-asyncio==1.5.8 networkx==3.0 ninja==1.11.1.1 notebook==6.5.5 notebook_shim==0.2.3 numba==0.59.0 numpy==1.24.1 nvidia-cublas-cu12==12.1.3.1 nvidia-cuda-cupti-cu12==12.1.105 nvidia-cuda-nvrtc-cu12==12.1.105 nvidia-cuda-runtime-cu12==12.1.105 nvidia-cudnn-cu12==8.9.2.26 nvidia-cufft-cu12==11.0.2.54 nvidia-curand-cu12==10.3.2.106 nvidia-cusolver-cu12==11.4.5.107 nvidia-cusparse-cu12==12.1.0.106 nvidia-nccl-cu12==2.19.3 nvidia-nvjitlink-cu12==12.4.99 nvidia-nvtx-cu12==12.1.105 oauthlib==3.2.0 outlines==0.0.36 overrides==7.4.0 packaging==23.2 pandas==2.2.1 pandocfilters==1.5.0 parso==0.8.3 peft @ git+https://github.com/huggingface/peft@6008f272a565f56c146c5d9fd78d00cb24392d7b pexpect==4.8.0 Pillow==9.3.0 platformdirs==3.11.0 prometheus-client==0.18.0 prompt-toolkit==3.0.39 psutil==5.9.6 ptyprocess==0.7.0 pure-eval==0.2.2 pyarrow==15.0.1 pyarrow-hotfix==0.6 pycparser==2.21 pydantic==2.6.4 pydantic_core==2.16.3 Pygments==2.16.1 PyGObject==3.42.1 PyJWT==2.3.0 pyparsing==2.4.7 python-apt==2.4.0+ubuntu2 python-dateutil==2.8.2 python-json-logger==2.0.7 pytz==2024.1 PyYAML==6.0.1 pyzmq==24.0.1 referencing==0.30.2 regex==2023.12.25 requests==2.31.0 rfc3339-validator==0.1.4 rfc3986-validator==0.1.1 rich==13.7.1 rpds-py==0.12.0 safetensors==0.4.2 scipy==1.12.0 SecretStorage==3.3.1 Send2Trash==1.8.2 shtab==1.7.1 six==1.16.0 sniffio==1.3.0 soupsieve==2.5 stack-data==0.6.3 sympy==1.12 terminado==0.17.1 tinycss2==1.2.1 tokenizers==0.15.2 tomli==2.0.1 torch==2.2.1 torchaudio==2.1.0+cu118 torchvision==0.16.0+cu118 tornado==6.3.3 tqdm==4.66.2 
traitlets==5.13.0 transformers @ git+https://github.com/huggingface/transformers.git@48fbab73303d7b20a4bb6e68548df784ef30a708 triton==2.2.0 trl @ git+https://github.com/huggingface/trl.git@304e208f778a5442c30cdda500348226cdc97d90 types-python-dateutil==2.8.19.14 typing_extensions==4.10.0 tyro==0.7.3 tzdata==2024.1 uri-template==1.3.0 urllib3==1.26.13 wadllib==1.3.6 wcwidth==0.2.9 webcolors==1.13 webencodings==0.5.1 websocket-client==1.6.4 widgetsnbextension==4.0.9 xxhash==3.4.1 yarl==1.9.4 zipp==1.0.0

Context for the issue:

AQLM opens up running Mixtral on a 16GB GPU (i.e. free Colab); being able to force structured output out of Mixtral would be very helpful for individuals and organizations with limited GPU resources.

remiconnesson commented 6 months ago

AQLM support in vLLM is close to ready:

https://github.com/vllm-project/vllm/pull/3287

Once it's merged, I'll check whether Mixtral-AQLM + vLLM + outlines is fast enough, and if it is, I'll close this issue :)
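
For reference, a sketch of what that path might look like once the PR lands, assuming outlines' vLLM integration (outlines.models.vllm, which loads the model through vllm.LLM); untested:

import outlines

# Load the same AQLM checkpoint through the vLLM backend (untested sketch;
# requires vllm-project/vllm#3287 to be merged and vllm to be installed).
model = outlines.models.vllm("ISTA-DASLab/Mixtral-8x7B-Instruct-v0_1-AQLM-2Bit-1x16-hf")
generator = outlines.generate.regex(
    model,
    r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)",
)
print(generator("What is the IP address of the Google DNS servers? "))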