sgl-project / sglang

SGLang is a fast serving framework for large language models and vision language models.
https://sglang.readthedocs.io/en/latest/
Apache License 2.0

[Feature] Generation Inputs: input_embeds #745

Open AlekseyKorshuk opened 1 month ago

AlekseyKorshuk commented 1 month ago

Motivation

I propose to add input_embeds as an optional input to the generation params.

Why is this important

Nowadays there are many Vision Language Models (VLMs), and they all share a similar architecture: a vision tower, a projector, and an LLM. The vision tower plus projector simply prepares embeddings for the "image" tokens. So why not let model developers prepare the input_embeds for the LLM themselves? Many new models, such as PaliGemma and Florence, let the user work with bounding boxes and segmentation masks, which makes it quite complicated to add the corresponding processors and conversation templates to the codebase. By allowing the user to provide input_embeds instead of a list of messages or a text prompt, you reduce your own maintenance burden in the future. Another point is that VLM developers can focus on caching image embeddings while building on top of SGLang, enabling even higher throughput.

vLLM users requested this feature a long time ago, and the topic gained a lot of positive attention from the community.

This unique feature would make SGLang the main framework for all VLMs.

I am happy to help implement this if you can point me to the relevant parts of the codebase. Thank you for your time and consideration đŸ¤—

Proposed usages

response = client.chat.completions.create(
    model="default",
    input_embeds=[...],
    temperature=0.8,
    max_tokens=64,
)

backend.run(input_embeds=input_embeds)

from dataclasses import dataclass
from typing import Dict, List, Optional, Union


@dataclass
class GenerateReqInput:
    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[str], str]] = None
    # The token ids for text; one can either specify text or input_ids.
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The embeddings for input_ids; if specified, input_ids should also be provided
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
    # The image input. It can be a file name, a url, or base64 encoded string.
    # See also python/sglang/srt/utils.py:load_image.
    image_data: Optional[Union[List[str], str]] = None
    # The sampling_params.
    sampling_params: Union[List[Dict], Dict] = None
    # The request id.
    rid: Optional[Union[List[str], str]] = None
    # Whether to return logprobs.
    return_logprob: Optional[Union[List[bool], bool]] = None
    # The start location of the prompt for return_logprob.
    logprob_start_len: Optional[Union[List[int], int]] = None
    # The number of top logprobs to return.
    top_logprobs_num: Optional[Union[List[int], int]] = None
    # Whether to detokenize tokens in logprobs.
    return_text_in_logprobs: bool = False
    # Whether to stream output.
    stream: bool = False
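For completeness, a raw request to the native /generate endpoint could then carry the embeddings directly. The snippet below is only an illustration of the proposed behavior (the input_embeds field does not exist yet, and the tiny 4-dimensional vectors stand in for real hidden-size embeddings):

import requests

# Hypothetical payload: one sequence of 5 "tokens", each given as an embedding
# vector. A real model would use the LLM's hidden size (e.g. 4096 for Llama-2-7B).
input_embeds = [[0.1, -0.2, 0.3, 0.05] for _ in range(5)]

response = requests.post(
    "http://localhost:30000/generate",
    json={
        "input_embeds": input_embeds,  # proposed field, not supported today
        "sampling_params": {"temperature": 0.8, "max_new_tokens": 64},
    },
)
print(response.json())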

Related resources

joshpxyne commented 1 month ago

+1!

jsdir commented 1 month ago

+1

tunahfishy commented 1 month ago

!!!

ummagumm-a commented 1 month ago

having this feature would be nice, indeed

merrymercy commented 1 month ago

Great suggestions. Let's prioritize this one. I can share some ideas and pointers.

High-level Idea

Since many parts of the existing code rely on the concept of "input_ids: List[int]", it is not easy to change all of them without introducing many problematic if/else branches. One possible implementation idea is to create random fake "input_ids" so that most of the existing code keeps working. Then, during the actual forward pass, we feed input_embeds to the model instead of calling the embedding layer on the input_ids.

You can learn more about this idea by looking at how the existing Llava implementation directly feeds input_embeds into the underlying Llama: https://github.com/sgl-project/sglang/blob/0736b270202696b8f865e2915aadc36d3d51811b/python/sglang/srt/models/llava.py#L241-L243 https://github.com/sgl-project/sglang/blob/0736b270202696b8f865e2915aadc36d3d51811b/python/sglang/srt/models/llama2.py#L258-L261
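To make the idea concrete, here is a minimal, self-contained sketch (not SGLang's actual Llama code; names like ToyLM are made up) of how a model forward can bypass the embedding layer when embeddings are supplied:

from typing import Optional

import torch
from torch import nn


class ToyLM(nn.Module):
    """Toy model illustrating the "feed input_embeds instead of embedding input_ids" idea."""

    def __init__(self, vocab_size: int = 32000, hidden_size: int = 4096):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        # ... transformer layers and lm_head omitted for brevity ...

    def forward(
        self,
        input_ids: torch.LongTensor,                   # may be fake/placeholder ids
        input_embeds: Optional[torch.Tensor] = None,   # (num_tokens, hidden_size)
    ) -> torch.Tensor:
        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        else:
            # Skip the embedding lookup entirely: the caller (e.g. a VLM's
            # vision tower + projector) has already produced the embeddings.
            hidden_states = input_embeds
        # ... run the transformer layers on hidden_states ...
        return hidden_states

This mirrors what the Llava code linked above already does when it calls the underlying Llama with input_embeds.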

Implementation

The inference of a request starts with a GenerateReqInput from the HTTP server, and it then goes through several important classes: TokenizerManager, ModelTpServer, ModelRunner, Req, and InferBatch. To implement your change, we need to update all of these places (a rough sketch of steps 2 and 3 follows the list below).

  1. Implement your proposed changes to GenerateReqInput https://github.com/sgl-project/sglang/blob/3fdab91912fb271c20642e21c2055df0e23d514e/python/sglang/srt/managers/io_struct.py#L15
  2. Skip the input tokenization in TokenizerManager https://github.com/sgl-project/sglang/blob/0736b270202696b8f865e2915aadc36d3d51811b/python/sglang/srt/managers/tokenizer_manager.py#L142-L148
  3. When creating the Req, record the input_embeds. Maybe here is also a good place to generate the fake input_ids mentioned above. https://github.com/sgl-project/sglang/blob/0736b270202696b8f865e2915aadc36d3d51811b/python/sglang/srt/managers/controller/tp_worker.py#L263
  4. When preparing the inputs of a prefill batch, save input_embeds into InferBatch. In SGLang, "prefill" is also called "extend". https://github.com/sgl-project/sglang/blob/0736b270202696b8f865e2915aadc36d3d51811b/python/sglang/srt/managers/controller/infer_batch.py#L313
  5. When running the actual forward pass, feed input_embeds to the model. https://github.com/sgl-project/sglang/blob/0736b270202696b8f865e2915aadc36d3d51811b/python/sglang/srt/managers/controller/model_runner.py#L295-L309
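Here is the rough, untested sketch mentioned above, covering steps 2 and 3. The helper function and its name are hypothetical (the real logic would live inside TokenizerManager and the Req construction in tp_worker.py), but it shows the intended control flow:

from typing import List, Optional, Tuple


def resolve_inputs(
    text: Optional[str],
    input_ids: Optional[List[int]],
    input_embeds: Optional[List[List[float]]],
    tokenizer,
) -> Tuple[List[int], Optional[List[List[float]]]]:
    """Hypothetical helper: decide what the scheduler sees for one request.

    When input_embeds is given, skip tokenization and fabricate placeholder
    input_ids of the same length, so downstream code that only understands
    token ids (sequence lengths, positions, batching) keeps working. The
    embeddings are carried alongside the Req and consumed at the forward
    pass (step 5) instead of the embedding layer.
    """
    if input_embeds is not None:
        fake_input_ids = [0] * len(input_embeds)  # placeholder ids, never embedded
        return fake_input_ids, input_embeds
    if input_ids is not None:
        return input_ids, None
    return tokenizer.encode(text), None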

This is my rough idea. I haven't implemented it yet, so there may be some mistakes. I hope it is helpful.

Ying1123 commented 1 month ago

@AlekseyKorshuk any updates?

AlekseyKorshuk commented 1 month ago

Last week was quite busy for me, so unfortunately I have not started yet.