flashinfer-ai / flashinfer

FlashInfer: Kernel Library for LLM Serving
https://flashinfer.ai
Apache License 2.0

Add group_size 7 and fix compat with Yi 1.5 34b #246

Closed · Qubitium closed this 2 months ago

Qubitium commented 2 months ago

Allow users to compile for group_size 7, making flashinfer compatible with the Yi 1.0/1.5 34B models.

Fixes https://github.com/flashinfer-ai/flashinfer/issues/181
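
For context, flashinfer specializes its kernels on the GQA group size, i.e. the ratio of query heads to KV heads. A small illustration of why Yi-34B needs 7 (head counts assumed from Yi-34B's config.json):

num_attention_heads = 56   # query heads (assumed from Yi-34B's config.json)
num_key_value_heads = 8    # KV heads (grouped-query attention)
group_size = num_attention_heads // num_key_value_heads
print(group_size)  # 7 -- not among the precompiled sizes 1, 4, 6, 8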

I did not modify the CI script to include group_size 7 by default, as it would make the already long compile time even longer. Users who want Yi-34B compat can either compile only group_size 7 via the env var FLASHINFER_GROUP_SIZES (fastest, maybe 10-15 minutes) or add 7 to the existing 1,4,6,8 for full compat with other models (super slow compile, measured in hours). A build sketch follows below.
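
A minimal build sketch, assuming a source checkout of flashinfer whose setup.py reads FLASHINFER_GROUP_SIZES at build time; the exact install command may differ on your setup:

import os
import subprocess

# The child pip process inherits os.environ, so the variable is visible
# to setup.py during the build.
os.environ["FLASHINFER_GROUP_SIZES"] = "7"            # Yi-34B only (fastest build)
# os.environ["FLASHINFER_GROUP_SIZES"] = "1,4,6,8,7"  # full compat (hours)
subprocess.run(["pip", "install", "--no-build-isolation", "-v", "."], check=True)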

Yi 1.5 is a great model, and this will help engines such as sglang that use flashinfer to deploy it.

Tests

xuzhenqi commented 2 months ago

@Qubitium Did you test the correctness of the Yi model? I tried this method before and found that the BatchPrefill kernel does not return correct outputs. I also fixed the BatchPrefill kernel in #223.

Qubitium commented 2 months ago

> @Qubitium Did you test the correctness of the Yi model? I tried this method before and found that the BatchPrefill kernel does not return correct outputs. I also fixed the BatchPrefill kernel in #223.

@xuzhenqi Going to do some expanded human eval using this PR later today and will let you know the results. Btw, #223 looks great and appears to be a more generic, mid-term solution to this group-size issue until dynamic group size is implemented.

Qubitium commented 2 months ago

@xuzhenqi You are right. There is output instability when we use temperature=0.7 with sglang:

Traceback (most recent call last):
  File "/root/miniconda3/lib/python3.11/site-packages/sglang/srt/managers/router/model.py", line 187, in exposed_step
    self.forward_step()
  File "/root/miniconda3/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/sglang/srt/managers/router/model.py", line 202, in forward_step
    self.forward_fill_batch(new_batch)
  File "/root/miniconda3/lib/python3.11/site-packages/sglang/srt/managers/router/model.py", line 424, in forward_fill_batch
    next_token_ids, _ = batch.sample(logits)
                        ^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/sglang/srt/managers/router/infer_batch.py", line 552, in sample
    sampled_index = torch.multinomial(probs_sort, num_samples=1)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: probability tensor contains either `inf`, `nan` or element < 0
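
For anyone debugging the same failure: torch.multinomial raises this error whenever the probability tensor contains inf, NaN, or negative values, which here points to garbage coming out of the attention kernel. A hypothetical check (assert_valid_probs is my name, not sglang's) that could be dropped in before the sampling call:

import torch

def assert_valid_probs(probs: torch.Tensor) -> None:
    # Surface which condition would trip torch.multinomial's error.
    if torch.isnan(probs).any():
        raise ValueError("probs contain NaN -- bad kernel output upstream")
    if torch.isinf(probs).any():
        raise ValueError("probs contain inf")
    if (probs < 0).any():
        raise ValueError("probs contain negative values")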

We will now compile and test your PR #223 instead.