vllm-project / vllm

[Feature Request] Support input embedding in `LLM.generate()` #416

Open KimmiShi opened 1 year ago

KimmiShi commented 1 year ago

Hi, I am using an LLM as part of a multimodal model, so the model needs to pass an input embedding tensor directly to generate. It also needs access to the language model's embed_tokens member to first compute the embeddings, which are then processed and finally sent to generate, as shown in the following demo code:

        # Token embeddings for the full text prompt
        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)

        # Splice the projected vision features (language_model_inputs) into the
        # text embeddings at position self.offset
        prefix_embeds = inputs_embeds[:, :self.offset, :]
        postfix_embeds = inputs_embeds[:, self.offset:, :]
        inputs_embeds = torch.cat([prefix_embeds, language_model_inputs, postfix_embeds], dim=1)

        .....
        attention_mask = torch.cat([prefix_mask, vision_mask, postfix_mask], dim=-1)

        # Generate directly from the merged embeddings
        outputs = self.language_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            generation_config=generation_config,
            **generate_kwargs,
        )

I read the vLLM code, and it seems that I need to add two interfaces to vLLM: one is LLM.get_input_embeddings, and the other is LLM.generate(inputs_embeds=inputs_embeds, ...).

Do you think this would work? And would you consider supporting this feature?
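For comparison, this pattern already works in Hugging Face transformers, whose generate() accepts inputs_embeds for decoder-only models; the request here is for an equivalent entry point in vLLM's LLM.generate(). A minimal, self-contained sketch (gpt2 is just a stand-in checkpoint):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tokenizer("vLLM is", return_tensors="pt").input_ids
# Same first step as above: token ids -> embeddings, which could then be
# spliced with vision features before generation.
inputs_embeds = model.get_input_embeddings()(input_ids)  # (1, seq_len, hidden)

output_ids = model.generate(inputs_embeds=inputs_embeds, max_new_tokens=8)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```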

KimmiShi commented 1 year ago

It seems that the worker._prepare_inputs method needs to be modified to support embedding tensor input. Could you support this feature?

hangzhang-nlp commented 1 year ago

Awesome work!! And I have the same need.

zacharyblank commented 1 year ago

Has there been any progress on this? I am looking to achieve something very similar: essentially, I need to be able to pass in a previously calculated embedding so that I don't have to recalculate it as part of a common prompt. Due to my use case, somewhere between 4k and 12k tokens are currently being reprocessed many times for a single request.

pfldy2850 commented 9 months ago

I've been working on PR https://github.com/vllm-project/vllm/pull/1265 for this issue, and I am facing the following problem.

Two tensors that should logically be identical actually have different values. This happens when stacking multiple operations to support generation from input_embeddings, and it leads to results that differ from the output of prompt-based generation.

    import torch

    # `input_embeddings` is the model's token embedding module
    # (e.g. the result of model.get_input_embeddings()).
    test_ids = torch.tensor([[   85,  3069,    44,   318,   257,  1029,    12,  9579,  1996,   290,
              4088,    12, 16814, 32278,   290,  7351,  3113,   329, 27140, 10128,
                13,     0,     0,     0]]).to("cuda")
    zero_ids = torch.zeros(test_ids.shape, dtype=test_ids.dtype).to("cuda")

    embeds = input_embeddings(test_ids)
    embeds_2 = input_embeddings(test_ids) - input_embeddings(zero_ids) + input_embeddings(zero_ids)

    is_equal = torch.equal(embeds, embeds_2)
    print(is_equal)

    # printed: False

Can anyone give me some hints on how to solve this?
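(Side note: one plausible explanation, offered here as an assumption rather than a confirmed diagnosis of the PR, is that floating-point addition and subtraction are not associative, so a - b + b generally differs from a by rounding error, and an exact torch.equal check will fail even when the tensors are numerically close. A minimal reproduction that does not involve vLLM at all:)

```python
import torch

# Floating-point ops are not associative, so (a - b) + b != a bit-for-bit.
a = torch.randn(4, 8, dtype=torch.float16)
b = torch.randn(4, 8, dtype=torch.float16)

print(torch.equal(a, a - b + b))                # typically False
print(torch.allclose(a, a - b + b, atol=1e-2))  # True within a small tolerance
```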

hmellor commented 6 months ago

@WoosukKwon is this on the roadmap?

Andcircle commented 2 months ago

What is the current status on this? =) I also need this feature.

hmellor commented 2 months ago

According to https://github.com/vllm-project/vllm/pull/1265#issuecomment-2015813846, this feature was added in #3042.

Andcircle commented 2 months ago

@hmellor

#3042 is not the feature we actually expected.

Is it possible to add a feature like the following?

    llm = LLM(model="mistral", ...)
    # merge_inputs is a customized function provided by the user; this makes the process more flexible
    inputs_embeds = merge_inputs(texts, images)
    outputs = llm.generate(inputs_embeds=inputs_embeds, ...)
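(A rough sketch of what such a user-provided merge_inputs could look like; the function name, the insert_at position, and all shapes are illustrative assumptions rather than part of any vLLM API. It mirrors the BLIP-2-style splicing shown in the first comment.)

```python
import torch

def merge_inputs(text_embeds: torch.Tensor, image_embeds: torch.Tensor,
                 insert_at: int) -> torch.Tensor:
    """Hypothetical user-side merge: splice already-projected image features
    into the text embedding sequence at a fixed position."""
    prefix = text_embeds[:, :insert_at, :]
    postfix = text_embeds[:, insert_at:, :]
    return torch.cat([prefix, image_embeds, postfix], dim=1)

# Dummy shapes: batch=1, 16 text tokens, 32 image tokens, hidden size 4096.
text_embeds = torch.randn(1, 16, 4096)
image_embeds = torch.randn(1, 32, 4096)
merged = merge_inputs(text_embeds, image_embeds, insert_at=4)
print(merged.shape)  # torch.Size([1, 52, 4096])
```
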
AnyangAngus commented 1 month ago

@hmellor #3042 is not the feature we actually expected.

Is it possible to add a feature like the following?

    llm = LLM(model="mistral", ...)
    # merge_inputs is a customized function provided by the user; this makes the process more flexible
    inputs_embeds = merge_inputs(texts, images)
    outputs = llm.generate(inputs_embeds=inputs_embeds, ...)

Yeah, same request; accepting input embeddings in the llm.generate() function would be the most straightforward approach.

DarkLight1337 commented 1 month ago

FYI this is now supported for multi-modal models via #6613. Perhaps a similar idea could be used to extend this to language-only models.
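For later readers, a rough sketch of the embedding-input path added for multi-modal models in #6613; the model name, prompt template, and tensor shape below are illustrative assumptions and should be checked against the vLLM multi-modal input docs:

```python
import torch
from vllm import LLM

llm = LLM(model="llava-hf/llava-1.5-7b-hf")

# Pre-computed image features from your own vision tower / projector.
# The expected shape is model-specific (assumed here: (num_images, feature_len, hidden_size)).
image_embeds = torch.randn(1, 576, 4096, dtype=torch.float16)

outputs = llm.generate({
    "prompt": "USER: <image>\nWhat is shown in the image?\nASSISTANT:",
    "multi_modal_data": {"image": image_embeds},
})
print(outputs[0].outputs[0].text)
```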

Andcircle commented 1 month ago

FYI this is now supported for multi-modal models via #6613. Perhaps a similar idea could be used to extend this to language-only models.

Thanks @DarkLight1337 for the updates. I checked the demo code: we still provide the two modalities separately (prompt and images), and the merge process is still controlled solely by the vLLM-supported VLM model. It is not flexible enough if we want to use our own merge methods.

Can we do the following, so that we just treat a customized VLM as a PURE language model?

    # Start from a pure language model, NOT an existing VLM
    llm = LLM(model="mistral", ...)
    # merge_inputs is a customized function provided by the user; this makes the process more flexible
    inputs_embeds = merge_inputs(texts, images)
    # The LLM only takes a batch of merged embeddings; it no longer cares whether they
    # come from images / video / audio, it just treats them as pure language-model input
    outputs = llm.generate(inputs_embeds=inputs_embeds, ...)

I saw you also mentioned: "Follow-up TODO: Support initializing VLM with only language model backbone." Is that the same as what I described above?

Again, your help is really, really appreciated.