sgl-project / sglang

SGLang is a fast serving framework for large language models and vision language models.
https://sgl-project.github.io/

[Bug][minimal reproducible demo] High variability across batch inference runs #1729

Open FredericOdermatt opened 1 month ago

FredericOdermatt commented 1 month ago

Checklist

Describe the bug

Background

This bug might be related to #1316.

When the model is asked a block of questions that it should answer with "yes", followed by a block of questions that it should answer with "no", a degradation in quality can be observed on some runs when the same data is run many times.

Standard lmsysorg/sglang:v0.3.3.post1-cu121-srt

The same 40 yes-questions and 40 no-questions were asked 200 times, recording the average post-softmax "yes" probability per run. Blue: questions that should be answered "yes". Orange: questions that should be answered "no". (Please check the minimal reproducible sample here.)

[Figure: per-run average post-softmax "yes" probability, standard server. Blue: yes-questions. Orange: no-questions.]
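For reference, here is a minimal sketch of how such per-question "yes" probabilities could be collected against the OpenAI-compatible endpoint the server exposes. The model name and port match the launch commands below; the prompt wording, the restriction to the literal "Yes"/"No" tokens, and the `yes_probability` helper are assumptions for illustration, not the original benchmark script.

```python
# Hedged sketch: query SGLang's OpenAI-compatible endpoint and estimate the
# post-softmax probability of answering "Yes" from the returned top logprobs.
import math
from openai import OpenAI

client = OpenAI(base_url="http://localhost:30001/v1", api_key="EMPTY")

def yes_probability(question: str) -> float:
    resp = client.chat.completions.create(
        model="mistralai/Mixtral-8x22B-Instruct-v0.1",
        messages=[{"role": "user", "content": question + " Answer Yes or No."}],
        max_tokens=1,
        temperature=0.0,
        logprobs=True,
        top_logprobs=5,
    )
    # Top logprobs of the first generated token; renormalize over "yes"/"no".
    top = resp.choices[0].logprobs.content[0].top_logprobs
    probs = {t.token.strip().lower(): math.exp(t.logprob) for t in top}
    yes, no = probs.get("yes", 0.0), probs.get("no", 0.0)
    return yes / (yes + no) if (yes + no) > 0 else 0.0

print(yes_probability("Is the sky blue?"))
```

Averaging this value over the 40 yes-questions and 40 no-questions per run, and repeating the run many times, produces the kind of plot shown above.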

Restricted lmsysorg/sglang:v0.3.3.post1-cu121-srt

Adding the following flags and running 100 times:

--attention-backend triton --sampling-backend pytorch --disable-radix-cache --disable-regex-jump-forward --disable-cuda-graph --disable-cuda-graph-padding --disable-disk-cache --disable-custom-all-reduce --disable-mla

[Figure: per-run average post-softmax "yes" probability, restricted server. Blue: yes-questions. Orange: no-questions.]

Observations

Further notes

Reproduction

The current minimal reproducible example is here.

Normal server start

python3 -m sglang.launch_server --model-path mistralai/Mixtral-8x22B-Instruct-v0.1 --random-seed 42 --tp-size 8 --dp-size 1 --host 0.0.0.0 --port 30001

Restricted server start

python3 -m sglang.launch_server --model-path mistralai/Mixtral-8x22B-Instruct-v0.1 --attention-backend triton --sampling-backend pytorch --disable-radix-cache --disable-regex-jump-forward --disable-cuda-graph --disable-cuda-graph-padding --disable-disk-cache --disable-custom-all-reduce --disable-mla --random-seed 42 --tp-size 8 --dp-size 1 --host 0.0.0.0 --port 30001

Environment

Environment for problematic runs lmsysorg/sglang:v0.3.3.post1-cu121-srt

jonzhep commented 1 month ago

I can also confirm this; I also see it with FlashInfer on vLLM.

merrymercy commented 1 month ago

@FredericOdermatt @jonzhep This is very helpful. We will take a close look this week and hopefully fix it soon.

merrymercy commented 1 month ago

This is what I got when running your example commands (normal server start) on 8xH100 with the current main (87a7cfa080cec3f123618c1429).

[Figure: yes/no logits on 8xH100, current main.]

It basically reproduces what you reported, although not as severely as in your plots. I will start investigating. May I know which hardware you are using? You can get that by running python3 -m sglang.check_env

FredericOdermatt commented 1 month ago

I was running this on either 8x RTX A6000 or 4x A100. The plots above are from the RTX A6000s.

python3 -m sglang.check_env

```
Python: 3.10.15 (main, Sep 7 2024, 18:35:33) [GCC 9.4.0]
CUDA available: True
GPU 0,1,2,3,4,5,6,7: NVIDIA RTX A6000
GPU 0,1,2,3,4,5,6,7 Compute Capability: 8.6
CUDA_HOME: /usr/local/cuda
NVCC: Cuda compilation tools, release 12.1, V12.1.105
CUDA Driver Version: 535.183.06
PyTorch: 2.4.0+cu121
flashinfer: 0.1.6+cu121torch2.4
triton: 3.0.0
transformers: 4.45.2
requests: 2.32.3
tqdm: 4.66.5
numpy: 1.26.4
aiohttp: 3.10.10
fastapi: 0.115.0
hf_transfer: 0.1.8
huggingface_hub: 0.25.2
interegular: 0.3.3
packaging: 24.1
PIL: 10.4.0
psutil: 6.0.0
pydantic: 2.9.2
uvicorn: 0.31.1
uvloop: 0.20.0
zmq: 26.2.0
vllm: 0.5.5
multipart: 0.0.12
openai: 1.51.2
tiktoken: 0.8.0
anthropic: Module Not Found
litellm: Module Not Found
NVIDIA Topology:
      GPU0  GPU1  GPU2  GPU3  GPU4  GPU5  GPU6  GPU7  CPU Affinity     NUMA Affinity  GPU NUMA ID
GPU0   X    NV4   NODE  NODE  SYS   SYS   SYS   SYS   0-63,128-191     0              N/A
GPU1  NV4    X    NODE  NODE  SYS   SYS   SYS   SYS   0-63,128-191     0              N/A
GPU2  NODE  NODE   X    NV4   SYS   SYS   SYS   SYS   0-63,128-191     0              N/A
GPU3  NODE  NODE  NV4    X    SYS   SYS   SYS   SYS   0-63,128-191     0              N/A
GPU4  SYS   SYS   SYS   SYS    X    NV4   NODE  NODE  64-127,192-254   1              N/A
GPU5  SYS   SYS   SYS   SYS   NV4    X    NODE  NODE  64-127,192-254   1              N/A
GPU6  SYS   SYS   SYS   SYS   NODE  NODE   X    NV4   64-127,192-254   1              N/A
GPU7  SYS   SYS   SYS   SYS   NODE  NODE  NV4    X    64-127,192-254   1              N/A

Legend:
  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

ulimit soft: 1048576
```

merrymercy commented 1 month ago

This has been one of the biggest issues we've known about for a while. In short, I believe that dynamic batching introduces these variances because different batch sizes dispatch different kernels. We checked the engine implementation and did not find any noticeable bugs (e.g., incorrect caching). We will continue investigating and may introduce a "deterministic mode" as a short-term solution. This mode will use additional padding to increase determinism, although it will run more slowly.
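To illustrate the padding idea behind such a "deterministic mode" (this is a conceptual sketch, not actual SGLang code; the bucket sizes and the `pad_batch` helper are hypothetical), the batch dimension could be rounded up to fixed buckets so that the same kernel shapes are dispatched regardless of how many requests happen to be running:

```python
# Illustrative sketch: pad the batch to fixed "buckets" so kernel dispatch and
# reduction order stay the same run to run, at the cost of wasted compute.
import torch

BUCKET_SIZES = [1, 2, 4, 8, 16, 32, 64, 128]  # hypothetical bucket schedule

def pad_batch(hidden: torch.Tensor) -> torch.Tensor:
    """Pad the batch dimension up to the next bucket size with dummy rows."""
    bs = hidden.shape[0]
    target = next(b for b in BUCKET_SIZES if b >= bs)  # assumes bs <= 128
    if target == bs:
        return hidden
    pad = hidden.new_zeros((target - bs, *hidden.shape[1:]))  # dummy requests
    return torch.cat([hidden, pad], dim=0)

# Batches of size 3, 5, or 7 all execute with the size-8 shapes, so the
# exact number of concurrent requests no longer changes which kernels run.
x = torch.randn(5, 4096)
print(pad_batch(x).shape)  # torch.Size([8, 4096])
```

The trade-off named above is visible here: every padded row is wasted work, which is why such a mode would run more slowly.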

cermeng commented 3 weeks ago

> This has been one of the biggest issues we've known about for a while. In short, I believe that dynamic batching introduces these variances because different batch sizes dispatch different kernels. We checked the engine implementation and did not find any noticeable bugs (e.g., incorrect caching). We will continue investigating and may introduce a "deterministic mode" as a short-term solution. This mode will use additional padding to increase determinism, although it will run more slowly.

This is very helpful! I raised a similar issue in vLLM (https://github.com/vllm-project/vllm/issues/10074), and I think it has the same root cause.

BTW, I believe chunked prefill may increase the likelihood of this variance, as I have observed in my case with vLLM. vLLM's default strategy, which is first-come-first-served and prioritizes prefill requests, tends to mask the variance, since the batch size is more likely to be consistent between two prefill executions across separate runs.
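To make the batch-size argument concrete, here is a minimal, hedged illustration (standalone PyTorch, not vLLM or SGLang internals): the same input row is run through the same linear layer alone and as part of a larger batch, and the outputs are compared. Depending on the GPU and which GEMM kernels are selected for each shape, the difference may be nonzero; that is exactly the kind of shape-dependent variation that dynamic batching and chunked prefill can expose.

```python
# Minimal illustration (assumed setup): the same row can produce slightly
# different results depending on the batch it is computed in, because
# different batch sizes may dispatch different GEMM kernels/tilings.
import torch

torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"

layer = torch.nn.Linear(4096, 4096, bias=False).to(device).half()
x = torch.randn(32, 4096, device=device).half()

with torch.no_grad():
    out_alone = layer(x[:1])      # the row computed in a batch of size 1
    out_batched = layer(x)[:1]    # the same row computed in a batch of size 32

# The difference is often zero, but can be nonzero for some shapes/kernels;
# any nonzero difference compounds over many layers and decoding steps.
print((out_alone - out_batched).abs().max().item())
```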