sgl-project / sglang

SGLang is a fast serving framework for large language models and vision language models.
https://sgl-project.github.io/
Apache License 2.0

`run_batch()` RuntimeError: Trying to create tensor with negative dimension #55

Closed niklub closed 9 months ago

niklub commented 9 months ago

Hello, team!

Thanks for the excellent work. When running batch inference, I sometimes encounter a server-side error that completely interrupts the process:

new fill batch. #seq: 7. #cached_token: 1541. #new_token: 4. #remaining_req: 0. #running_req: 0
Exception in ModelRpcClient:
Traceback (most recent call last):
  File "/opt/conda/envs/env/lib/python3.11/site-packages/sglang/srt/managers/router/model_rpc.py", line 130, in exposed_step
    self.forward_step()
  File "/opt/conda/envs/env/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/env/lib/python3.11/site-packages/sglang/srt/managers/router/model_rpc.py", line 145, in forward_step
    self.forward_fill_batch(new_batch)
  File "/opt/conda/envs/env/lib/python3.11/site-packages/sglang/srt/managers/router/model_rpc.py", line 328, in forward_fill_batch
    logits, normalized_logprobs = self.model_runner.forward(
                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/env/lib/python3.11/site-packages/sglang/srt/managers/router/model_runner.py", line 453, in forward
    return self.forward_extend(**kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/env/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/env/lib/python3.11/site-packages/sglang/srt/managers/router/model_runner.py", line 359, in forward_extend
    return self.model.forward(input_ids, input_metadata.positions, input_metadata)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/env/lib/python3.11/site-packages/sglang/srt/models/llama2.py", line 270, in forward
    return self.logits_processor(
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/env/lib/python3.11/site-packages/sglang/srt/layers/logits_processor.py", line 55, in forward
    normalized_logprobs = compute_normalized_logprobs(
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/env/lib/python3.11/site-packages/sglang/srt/layers/logits_processor.py", line 67, in compute_normalized_logprobs
    logprobs = torch.zeros(
               ^^^^^^^^^^^^
RuntimeError: Trying to create tensor with negative dimension -3: [-3]

It has never happened to me with small batch sizes (1-10), but I constantly face it with bigger ones.
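For context, the last frame of the traceback shows `torch.zeros` being handed a size of -3, so some length was computed as a negative number before the allocation. The sketch below is only a hypothetical illustration of that failure mode in plain Python. The names `checked_size`, `total_len`, and `prefix_len` are made up for this example and are not taken from sglang's code; it simply validates a size derived by subtraction before it would be used.

```python
def checked_size(total_len: int, prefix_len: int) -> int:
    """Return total_len - prefix_len, rejecting negative results.

    Hypothetical helper for illustration only; not part of sglang.
    """
    size = total_len - prefix_len
    if size < 0:
        # Fail early with the inputs in the message instead of passing
        # a negative size on to an allocation call.
        raise ValueError(
            f"computed size {size} is negative "
            f"(total_len={total_len}, prefix_len={prefix_len})"
        )
    return size
```

A check like this would not fix the underlying bug, but it would report which lengths went wrong rather than failing inside the tensor allocation.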

The code to run batch inference fwiw:

import sglang as sgl
from tqdm import tqdm

@sgl.function
def answer(s, question):
    s += question + '\n'
    # Constrain the output to one of two single-character choices.
    s += sgl.gen("answer", choices=['Y', 'N'], temperature=0)

def driver_batching(questions):
    # Run one batch and collect the "answer" variable from each state.
    states = answer.run_batch(
        [{'question': q} for q in questions]
    )
    return [state['answer'] for state in states]

# MY_URL (server endpoint) and prompts (list of questions) are defined elsewhere.
sgl.set_default_backend(sgl.RuntimeEndpoint(MY_URL))

output = []
batch_size = 50
for i in tqdm(range(0, len(prompts), batch_size)):
    output.extend(driver_batching(prompts[i:i+batch_size]))

merrymercy commented 9 months ago

Thanks for reporting this bug. We will look into it soon.

In the meantime, you can probably try

sgl.gen("answer", regex=r"(Y|N)", temperature=0)

which is more robust and can work well for a few short choices.
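As a quick illustration of what that pattern accepts, here is a minimal sketch using Python's standard `re` module. It only checks strings against the same pattern; it does not show how sglang enforces the regex during decoding, and the helper name `is_valid_answer` is hypothetical.

```python
import re

# Same pattern as in the suggested sgl.gen(...) call.
CHOICE_PATTERN = re.compile(r"(Y|N)")

def is_valid_answer(text: str) -> bool:
    """Return True only if the whole string is exactly 'Y' or 'N'."""
    return CHOICE_PATTERN.fullmatch(text) is not None
```

A check like this can also be applied client-side to the returned `state['answer']` values to confirm that every answer is one of the two expected choices.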

merrymercy commented 9 months ago

This bug is fixed by #67. You can try the latest main branch or sglang[all]>=1.6.0