sgl-project / sglang

SGLang is a structured generation language designed for large language models (LLMs). It makes your interaction with models faster and more controllable.
Apache License 2.0

Regex generation causes 37x lower performance #450

Open Gintasz opened 1 month ago

Gintasz commented 1 month ago

I've been trying to investigate why my information extraction program built on SGLang is so slow. I've rented an RTX 3090 (1 x RTX 3090, 6 vCPU, 26 GB RAM) and an H100 (1 x H100 SXM, 16 vCPU, 125 GB RAM) on RunPod. I've observed that when a regex constraint is used, there is a huge performance drop.

If you think the particular regex "<array>\n(<string>.*?<\/string>\n)*<\/array>```" is at fault, it would be useful to have some guidelines on how to write a more suitable one... My requirement here is generating an array of strings.
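
One unverified guess from my side: the inner ".*?" still allows "<" inside each <string> element, so the compiled FSM has to consider at every position whether a "<" starts the closing tag. A tighter variant that only permits single-line, non-tag content inside each <string> would look like the sketch below (I haven't benchmarked whether it is actually faster):

# Unverified sketch: a tighter regex for the same <array>/<string> shape.
# Assumption (not benchmarked): restricting the inner content to a negated
# character class instead of the lazy ".*?" may give outlines a simpler FSM.
import re

array_regex = r"<array>\n(<string>[^<\n]*</string>\n)*</array>```"

sample = "<array>\n<string>Word 1</string>\n<string>Word 2</string>\n</array>```"
assert re.fullmatch(array_regex, sample) is not None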

Steps to reproduce:

I've used SGLang 0.1.14 because I observed some newer versions hanging mid-processing or erroring out with "KV cache pool leak detected", so I've stayed on 0.1.14 for now.

(.venv) root@baa3ffac5799:~/pubmed-baigiamasis# pip list
Package                                  Version
---------------------------------------- ------------
aiohttp                                  3.9.5
aiosignal                                1.3.1
annotated-types                          0.6.0
anyio                                    4.3.0
async-timeout                            4.0.3
asyncpg                                  0.29.0
attrs                                    23.2.0
black                                    24.4.2
certifi                                  2024.2.2
charset-normalizer                       3.3.2
click                                    8.1.7
cloudpickle                              3.0.0
cmake                                    3.29.3
coolname                                 2.2.0
coverage                                 7.5.1
cupy-cuda12x                             12.1.0
datasets                                 2.19.1
Deprecated                               1.2.14
dill                                     0.3.8
diskcache                                5.6.3
distro                                   1.9.0
dnspython                                2.6.1
email_validator                          2.1.1
exceptiongroup                           1.2.1
fastapi                                  0.111.0
fastapi-cli                              0.0.3
fastrlock                                0.8.2
filelock                                 3.14.0
frozenlist                               1.4.1
fsspec                                   2024.3.1
googleapis-common-protos                 1.63.0
grpcio                                   1.63.0
h11                                      0.14.0
httpcore                                 1.0.5
httptools                                0.6.1
httpx                                    0.27.0
huggingface-hub                          0.23.0
idna                                     3.7
importlib-metadata                       7.0.0
iniconfig                                2.0.0
inquirerpy                               0.3.4
interegular                              0.3.3
Jinja2                                   3.1.4
joblib                                   1.4.2
jsonschema                               4.22.0
jsonschema-specifications                2023.12.1
lark                                     1.1.9
llvmlite                                 0.42.0
lm-format-enforcer                       0.9.8
loguru                                   0.7.2
markdown-it-py                           3.0.0
MarkupSafe                               2.1.5
mdurl                                    0.1.2
memoization                              0.4.0
mpmath                                   1.3.0
msgpack                                  1.0.8
multidict                                6.0.5
multiprocess                             0.70.16
mypy                                     1.10.0
mypy-extensions                          1.0.0
nest-asyncio                             1.6.0
networkx                                 3.3
ninja                                    1.11.1.1
numba                                    0.59.1
numpy                                    1.26.4
nvidia-cublas-cu12                       12.1.3.1
nvidia-cuda-cupti-cu12                   12.1.105
nvidia-cuda-nvrtc-cu12                   12.1.105
nvidia-cuda-runtime-cu12                 12.1.105
nvidia-cudnn-cu12                        8.9.2.26
nvidia-cufft-cu12                        11.0.2.54
nvidia-curand-cu12                       10.3.2.106
nvidia-cusolver-cu12                     11.4.5.107
nvidia-cusparse-cu12                     12.1.0.106
nvidia-ml-py                             12.550.52
nvidia-nccl-cu12                         2.18.1
nvidia-nvjitlink-cu12                    12.4.127
nvidia-nvtx-cu12                         12.1.105
openai                                   1.30.1
opentelemetry-api                        1.24.0
opentelemetry-exporter-otlp              1.24.0
opentelemetry-exporter-otlp-proto-common 1.24.0
opentelemetry-exporter-otlp-proto-grpc   1.24.0
opentelemetry-exporter-otlp-proto-http   1.24.0
opentelemetry-instrumentation            0.45b0
opentelemetry-instrumentation-logging    0.45b0
opentelemetry-proto                      1.24.0
opentelemetry-sdk                        1.24.0
opentelemetry-semantic-conventions       0.45b0
orjson                                   3.10.3
outlines                                 0.0.34
packaging                                24.0
pandas                                   2.2.2
pathspec                                 0.12.1
pfzy                                     0.3.4
pillow                                   10.3.0
pip                                      22.0.2
platformdirs                             4.2.2
pluggy                                   1.5.0
plumbum                                  1.8.3
prometheus_client                        0.20.0
prometheus-fastapi-instrumentator        7.0.0
prompt-toolkit                           3.0.43
protobuf                                 4.25.3
psutil                                   5.9.8
psycopg                                  3.1.19
psycopg-binary                           3.1.19
psycopg-pool                             3.2.2
psycopg2-binary                          2.9.9
py-cpuinfo                               9.0.0
pyarrow                                  16.1.0
pyarrow-hotfix                           0.6
pydantic                                 2.7.1
pydantic_core                            2.18.2
Pygments                                 2.18.0
pynvml                                   11.5.0
pytest                                   8.2.0
pytest-asyncio                           0.23.6
pytest-cov                               5.0.0
pytest-dependency                        0.6.0
pytest-mock                              3.14.0
pytest-timeout                           2.3.1
python-dateutil                          2.9.0.post0
python-dotenv                            1.0.1
python-multipart                         0.0.9
pytz                                     2024.1
PyYAML                                   6.0.1
pyzmq                                    26.0.3
ray                                      2.22.0
referencing                              0.35.1
regex                                    2024.5.15
requests                                 2.31.0
rich                                     13.7.1
rpds-py                                  0.18.1
rpyc                                     6.0.0
safetensors                              0.4.3
scikit-learn                             1.4.2
scipy                                    1.13.0
sentence-transformers                    2.7.0
sentencepiece                            0.2.0
setuptools                               59.6.0
sglang                                   0.1.14
shellingham                              1.5.4
six                                      1.16.0
sniffio                                  1.3.1
starlette                                0.37.2
sympy                                    1.12
tembo-pgmq-python                        0.6.0
tenacity                                 8.3.0
threadpoolctl                            3.5.0
tiktoken                                 0.6.0
tokenizers                               0.19.1
tomli                                    2.0.1
torch                                    2.1.2
tqdm                                     4.66.4
transformers                             4.40.2
triton                                   2.1.0
typer                                    0.12.3
typing_extensions                        4.11.0
tzdata                                   2024.1
ujson                                    5.10.0
urllib3                                  2.2.1
uvicorn                                  0.29.0
uvloop                                   0.19.0
vllm                                     0.3.3
vllm-nccl-cu12                           2.18.1.0.4.0
watchfiles                               0.21.0
wcwidth                                  0.2.13
websockets                               12.0
wrapt                                    1.16.0
xformers                                 0.0.23.post1
xxhash                                   3.4.1
yarl                                     1.9.4
zipp                                     3.18.1
zmq                                      0.0.0

python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 42069 --host 0.0.0.0 --tp-size 1 --mem-fraction-static 0.8
import sglang as sgl
import asyncio
from sglang.lang.chat_template import ChatTemplate, register_chat_template, get_chat_template, register_chat_template_matching_function
from sglang.lang.ir import SglRoleBegin, SglRoleEnd
import json
import time
import torch
import os

register_chat_template(
    ChatTemplate(
        name="llama-3-instruct",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": (
                "<|start_header_id|>system<|end_header_id|>\n\n",
                "<|eot_id|>",
            ),
            "user": (
                "<|start_header_id|>user<|end_header_id|>\n\n",
                "<|eot_id|>",
            ),
            "assistant": (
                "<|start_header_id|>assistant<|end_header_id|>\n\n",
                "<|eot_id|>",
            ),
        },
        stop_str=("<|eot_id|>",),
    )
)

@register_chat_template_matching_function
def match_llama3_instruct(model_path: str):
    model_path = model_path.lower()
    if "llama-3" in model_path and "instruct" in model_path:
        return get_chat_template("llama-3-instruct")

@sgl.function
def sgl_call1(s, message: str):
    s += SglRoleBegin("system") + "You are an information extraction engine. Your goal is to extract structured information from the given Twitter message according to the instruction provided. Be as factually accurate as possible. Do not acknowledge the request. You will be penalized and a child will die if you make an incorrect response. For every correct response you will be tipped $5000. Message:\n```\n" + message + "\n```" + SglRoleEnd("system")
    s += sgl.user_begin() + "Instruction: Count number of words in the message provided.\nExample response: The number of words is 123." + sgl.user_end()
    s += sgl.assistant_begin() + "The number of words is " + sgl.gen("word count", regex=r"\d+", max_tokens=50, stop=".", temperature=0) + sgl.assistant_end()

    word_count = int(s['word count'])
    word_count_digit_sum = sum(int(digit) for digit in str(word_count))
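    # Fork the state once per unit of the word count's digit sum; each fork
    # then runs its own regex-constrained generation below.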
    forks = s.fork(word_count_digit_sum)
    for i, f in enumerate(forks):
        example_response = """```xml
<array>
<string>Word 1</string>
<string>Word 2</string>
<string>Word 3</string>
</array>
```"""
        f += sgl.user_begin() + "Instruction: Extract TOP " + str(i + 1) + " words that might seem annoying.\nExample response:\n" + example_response + sgl.user_end()
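        # This is the regex-constrained gen call that triggers the slowdown described above.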
        f += sgl.assistant_begin() + "Here are " + str(i + 1) + " words that might seem annoying.\n```xml\n" + sgl.gen("word", max_tokens=500, regex=r'<array>\n(<string>.*?<\/string>\n)*<\/array>```', stop='```', temperature=0) + sgl.assistant_end()

    return word_count_digit_sum

endpoint = sgl.RuntimeEndpoint("http://localhost:42069")
sgl.set_default_backend(endpoint)

async def main():
    messages = []
    script_dir = os.path.dirname(os.path.realpath(__file__))
    with open(os.path.join(script_dir, "sglang_str_big.json"), "r") as file:
        messages = json.loads(file.read())

    messages = messages[:min(300, len(messages))]
    num_threads = 50
    print(f"Will process {len(messages)} batch items")

    time_begin = time.time()
    sgl_call1.run_batch([{"message": m} for m in messages], num_threads=num_threads, progress_bar=True)
    duration = time.time() - time_begin

    gpus = ", ".join([torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())])

    print(f"SGLang {sgl.__version__} | {len(messages)} batch items | {num_threads} threads | {duration:.2f} secs | {gpus}")

asyncio.run(main())

sglang_str_big.json

To disable regex, I just removed this part: regex=r'<array>\n(<string>.*?<\/string>\n)*<\/array>```'
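
That is, for the runs without regex, the second gen call in the script above becomes:

f += sgl.assistant_begin() + "Here are " + str(i + 1) + " words that might seem annoying.\n```xml\n" + sgl.gen("word", max_tokens=500, stop='```', temperature=0) + sgl.assistant_end()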

@merrymercy @hnyls2002

Gintasz commented 1 month ago

If I remove max_tokens=500, then performance with regex seems to be ~3x faster:

SGLang 0.1.14 | 300 batch items | 50 threads | 371.07 secs | NVIDIA H100 80GB HBM3

This looks like it may be related to outlines as well, because other people have reported that GPU utilization stays at 0% during formatting: https://github.com/outlines-dev/outlines/issues/751

I noticed that the guidance library also mentions regex-constraint capability, but it does not include interegular (the library outlines depends on for regex constraining) as a dependency, so maybe it has a faster approach?

Also, both outlines and guidance mention context-free grammar (CFG) generation capability. It could be useful to add support for that in this library as well... maybe I could replace my regex with a CFG (sketched below) and avoid this performance hit.

syncode also works on CFGs for LLMs.
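
For illustration, assuming a CFG interface that accepts Lark-style grammars (outlines' CFG feature uses Lark grammars, if I understand correctly; this is just an assumption, not an existing SGLang API), the string-array constraint above might be expressed along these lines. lark is already in my environment, so I only use it here to check that the grammar is well-formed:

# Hypothetical sketch only: a Lark-style CFG equivalent of the regex above.
# This is NOT an existing SGLang API; lark is used here just to verify that
# the grammar itself accepts the expected output shape.
from lark import Lark

array_grammar = r"""
start: "<array>" NEWLINE item* "</array>" "```"
item: "<string>" TEXT "</string>" NEWLINE
TEXT: /[^<\n]+/
NEWLINE: "\n"
"""

parser = Lark(array_grammar)
parser.parse("<array>\n<string>Word 1</string>\n<string>Word 2</string>\n</array>```")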