vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0
21.93k stars 3.1k forks

[RFC]: Multi-modality Support Refactoring #4194

Open ywang96 opened 2 months ago

ywang96 commented 2 months ago

Update [6/11] - We have finished our first refactoring milestone - see details here. Some of the items @DarkLight1337, @xwjiang2010 and I are looking to work on as part of the next milestone are tentatively:

API Changes: See latest update here

Performance related

Model support - Add more vision language models, and better developer-facing documentation

Some of the ideas that we should work on in the future:

As always, please provide feedback and feature requests in this issue. Suggestions and contributions are very welcome!


Original RFC: Multi-modality support was recently brought to vLLM, many thanks to https://github.com/vllm-project/vllm/pull/3042 from @xwjiang2010. Since then we have seen increasing interest in such models (judging from the number of related pull requests and issues). However, there are a few issues with the current design that we should address before we bring in more features around multi-modality.

  1. VisionLanguageConfig and MultiModalData

    • Currently the multi-modal input can be either pixel_values or image_features for simplicity. While this works well with LLaVA-1.5, where pixel_values is the only output of its CLIPImageProcessor, it does not work well for models with more complicated preprocessing that returns multiple outputs (e.g., LLaVA-1.6, Fuyu, etc.). Developers could add additional preprocessing inside the model implementation as a workaround, but this will become unmaintainable over time.

    • The overhead of providing image_feature_size, image_token_id and image_input_shape is pushed to the user, even though these can (and should) be inferred from the model and processor config rather than required at inference time.

  2. The current design assumes multi-modal inputs are already processed to be consumed by the model executable, but vLLM does not have a processor util. This blocks end-to-end inference with vision models on the OpenAI API server.

  3. The current prompt format "<Image>" * 576 + prompt makes the underlying implementation easier (especially when it comes to profiling), but complicates the user experience compared to the HuggingFace format "<Image>\n" + prompt, and it has caused some confusion about what's needed to make multi-modal models work on vLLM.

Proposal: Most items in the above issues have been discussed and addressed in the original LLaVA-1.5 PR as well as https://github.com/vllm-project/vllm/pull/3978. We propose a few high-level design decisions for the refactoring and welcome any feedback!

  1. Adding a processor util - We can leverage the out-of-the-box AutoProcessor from transformers the same way we have been using the tokenizer, as an attribute of LLMEngine (e.g., self.multi_modal_processor = AutoProcessor.from_pretrained(model)). This allows us to support end-to-end inference with the API server as well as the LLM object.

  2. Frontend input format: Because of 1, we can keep the same format as HuggingFace, since that's how users usually discover new models and it makes end-to-end integration tests easier. Preprocessing should be hidden away from the interface and the user. For example, this preprocessing step can be done inside LLMEngine.add_request() around the same place as the code at https://github.com/vllm-project/vllm/blob/a134ef6f5e6c24d3cd459c63557e5db276db25b2/vllm/engine/llm_engine.py#L385-L391 as in the following pseudocode:

    
    if multi_modal_input is None:
        # Text-only request: tokenize as before
        prompt_token_ids = self.encode_request(
            request_id=request_id,
            prompt=prompt,
            prompt_token_ids=prompt_token_ids,
            lora_request=lora_request)
    else:
        # preprocessed_inputs is a dictionary of key (str) -> value (tensor)
        # produced by self.multi_modal_processor
        preprocessed_inputs = self.preprocess_request(
            request_id=request_id,
            prompt=prompt,
            prompt_token_ids=prompt_token_ids,
            lora_request=lora_request,
            multi_modal_input=multi_modal_input)
        prompt_token_ids = preprocessed_inputs.pop("input_ids")
        multi_modal_data = MultiModalData(data=preprocessed_inputs)
    ...
and thus, at the `LLM` level, only image tensors will be required.

3. **Refactor `MultiModalData`**: This object now simply holds the dictionary of multi-modal data needed by the model_executable. At inference time, the data is unpacked in the forward pass, an approach similar to the `transformers` implementation of multi-modal models.
4. **Refactor `VisionLanguageConfig`**: This config is now much simpler. One caveat is that when the image features are dynamic, users may specify an optional `max_feature_size` to help the engine profile the worst-case scenario and potentially abort certain requests.
5. **Regarding the original `image_feature` input type design**: IMO LLaVA is a special case among multi-modal models, since its vision encoder is detached from the language model and can be initialized separately. However, the same argument could be made for the MultiModalProjector, so perhaps passing image_feature (outputs of CLIP) is a design decision that does not generalize to other models. Instead, passing multi-modal embeddings (outputs of CLIP -> Projector) at inference time is more flexible and should work nicely with other models. (**One follow-up question: does it make sense to define a separate `Llava-no-clip` module, since this is so specific to LLaVA, to make our lives easier?**)
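A minimal sketch of the proposed `MultiModalData` (item 3), assuming it is a thin wrapper around a dict whose keys match the keyword arguments of the model's `forward()`. `ToyVisionLanguageModel` and `run_model` are hypothetical, and plain lists stand in for tensors so the example runs without PyTorch.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class MultiModalData:
    # Processor outputs keyed by the model's forward() parameter names,
    # e.g. {"pixel_values": ..., "image_sizes": ...}
    data: Dict[str, Any] = field(default_factory=dict)


class ToyVisionLanguageModel:
    # Hypothetical model: forward() declares the multi-modal inputs it needs
    def forward(self, input_ids: List[int],
                pixel_values: Optional[List[float]] = None,
                image_sizes: Optional[List[int]] = None) -> Dict[str, Any]:
        return {
            "num_tokens": len(input_ids),
            "has_image": pixel_values is not None,
            "image_sizes": image_sizes,
        }


def run_model(model: ToyVisionLanguageModel, input_ids: List[int],
              mm_data: Optional[MultiModalData] = None) -> Dict[str, Any]:
    # The engine stays model-agnostic: it just unpacks the dict as kwargs
    kwargs = mm_data.data if mm_data is not None else {}
    return model.forward(input_ids, **kwargs)


out = run_model(ToyVisionLanguageModel(), [1, 2, 3],
                MultiModalData(data={"pixel_values": [0.1, 0.2],
                                     "image_sizes": [336, 336]}))
```

Because the engine never inspects the keys, supporting a model with extra processor outputs only requires that model's `forward()` to accept them.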

With the above changes, as an end user you should ideally be able to do something like the following:

import requests
from PIL import Image
from vllm import LLM
from vllm.config import VisionLanguageConfig

model_id = "llava-hf/llava-v1.6-mistral-7b-hf"
llm = LLM(model=model_id, multi_modal_input_type=VisionLanguageConfig.IMAGE_INPUT_TYPE.IMAGE)  # This can also be EMBEDDINGS

prompt = "<image>\nUSER: What's the content of the image?\nASSISTANT:"

url = "https://www.ilankelman.org/stopsigns/australia.jpg"
image = Image.open(requests.get(url, stream=True).raw)

llm.generate(prompt, ..., multi_modal_input=image)

Under the hood, the pipeline is:

prompt, image
-> prompt_token_ids, MultiModalData(data=preprocessed_inputs)  # through preprocessing within engine.add_request()
-> prompt_token_ids, pixel_values, image_sizes  # through unpacking in the model's forward() implementation


I will follow up with a series of PRs for the refactoring, but please leave any feedback, since this is a pretty significant interface change.

ywang96 commented 2 months ago

cc @DarkLight1337 @Isotr0py @alsichcan

DarkLight1337 commented 2 months ago

Thank you for kickstarting this conversation!

Re: Issues

I fully agree with the issues you have pointed out. I would like to add that the current prompt format is hardly extensible to multi-image input if we plan to pursue that further down the line. In #3978, I proposed some ways of tackling the issue at the level of the OpenAI-compatible server. I have thought about them more and decided that they alone cannot provide the required flexibility, as explained below:

If there are only a small number of standard methods, we can provide a config option to choose which method to apply. I have added the image_openai attribute to VisionLanguageConfig to facilitate this.

I am not confident that this assumption would hold for very long, given the fast-changing pace of the field.

A more flexible option would be to pass the image(s) to the chat template (e.g. by setting the images attribute alongside role and content). This transfers the burden of implementation to the maintainers of the model on HuggingFace, making it more likely that vLLM users have to implement their own template. I have created a ConversationMessage class to represent the dictionary for each message.

I feel that this should be limited to cases where we only have to pass a single <image> token. The requirement of duplicating image tokens according to feature size should not be a concern of the chat template.

Not to mention that you still have to manually duplicate the <image> tokens when using the vLLM engine directly.

Re: Proposals

Here are my own thoughts on each proposal:

1. Adding a processor util

I think that we should move this responsibility outside of the Engine class. This is because multi-modal input isn't necessarily limited to image data, so we should expect more data types to be added in the future. To avoid having to modify the core Engine logic each time, we can wrap the data with processor objects (with a common interface to process the data) before passing them into the Engine. This way, we can easily add new data types by simply defining a new processor class. For your reference, I have implemented this pattern in #4197.

2. Frontend input format

My comments on this are similar to those on Proposal 1. However, #4197 only refactors MultiModalData to define the data processing logic. To avoid excessively duplicating the logic of encode_request, we should find a way to let MultiModalData control only parts of the process. Also, in my view of MultiModalData, the processing logic should remain independent of the model architecture. I guess this is where Proposal 3 comes in: HuggingFace processors should output dictionaries with keys that match the parameter names of model.forward().

3. Refactor MultiModalData

I have refactored this class in #4197 according to this description, and it works well enough to support the image_size parameter of LLaVA-NeXT as shown in #4199.

4. Refactor VisionLanguageConfig

Currently in #4197, MultiModalData has to accept ModelConfig and VisionLanguageConfig separately. Perhaps we can make VisionLanguageConfig an attribute of ModelConfig so we do not have to pass in multiple parameters. Using this approach, we only have to add more attributes to ModelConfig instead of having to pass more config objects around in order to support additional multi-modal data types.

Regarding max_feature_size, refer to my comments on Proposal 5.

5. Regarding the original image_feature as input type design

Instead of indirectly specifying the input shapes through the config, we can have each model implement a method that returns a dictionary mapping each keyword argument to its required input shape. For LLaVA, the feature size can be inferred from the HuggingFace config.json if we consider image size and patch size. To support profiling, we can slightly extend this to have the model define the maximum possible input shapes.
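To illustrate inferring the feature size from config.json instead of asking the user for it, here is a sketch assuming a CLIP-style vision tower whose vision_config exposes `image_size` and `patch_size`, with the CLS token dropped (as in LLaVA's default feature selection). With LLaVA-1.5's values (`image_size=336`, `patch_size=14`) this gives the 576 tokens used by the current prompt format. `get_image_feature_size` and `get_max_input_shapes` are hypothetical names for the per-model hook.

```python
from dataclasses import dataclass
from typing import Dict, Tuple


@dataclass
class VisionConfig:
    # Subset of the HuggingFace vision_config fields needed here
    image_size: int
    patch_size: int
    num_channels: int = 3


def get_image_feature_size(cfg: VisionConfig) -> int:
    # One feature per patch on a square grid; the CLS token is excluded
    patches_per_side = cfg.image_size // cfg.patch_size
    return patches_per_side * patches_per_side


def get_max_input_shapes(cfg: VisionConfig) -> Dict[str, Tuple[int, ...]]:
    # Hypothetical hook: worst-case shape of each forward() kwarg,
    # which the profiler could use to build dummy inputs
    return {
        "pixel_values": (1, cfg.num_channels, cfg.image_size, cfg.image_size),
    }


llava_1_5 = VisionConfig(image_size=336, patch_size=14)
print(get_image_feature_size(llava_1_5))  # 576
print(get_max_input_shapes(llava_1_5))    # {'pixel_values': (1, 3, 336, 336)}
```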

Is the unconventional prompt format "<image>" * image_feature_size + prompt mainly to support profiling? While implementing LLaVA-NeXT, I was under the impression that this is used to simplify the generation of the attention masks. Perhaps @xwjiang2010 would have a better idea.

jeejeelee commented 2 months ago

@ywang96 Thanks for driving the integration of more MM models into vLLM. :heart_eyes:

It seems that there is no plan to refactor the vision encoder (a TODO in the LLaVA implementation).

In my view, we should prioritize this, with performance being my main consideration.

By refactoring the vision encoder, we can establish an integration standard for MM models, similar to the one for our LLM integrations. This will not only ensure inference performance but also provide integration guidelines for the community.

If I have misunderstood, please correct me. Thanks again for your work!

Isotr0py commented 2 months ago

I generally agree with @DarkLight1337's opinion about moving the processing logic out of the Engine to avoid modifying core code frequently. However, I think it's difficult to keep the processing logic fully independent of the model architecture.

For example, FuyuProcessor and Idefics2Processor pad input_ids to account for image_feature_size during preprocessing, while LlavaProcessor doesn't (I guess this is also why "<image>" * image_feature_size + prompt is used for LLaVA). This means that we need to pad input_ids for LLaVA manually. (Maybe there is a better way to handle this? 🤔)
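For processors like LlavaProcessor that leave a single image token in input_ids, the manual padding could look like the sketch below. It assumes the image placeholder has a known token ID and a fixed feature size per image; `expand_image_tokens` is a hypothetical helper and the ID `32000` is only an illustrative value.

```python
from typing import List


def expand_image_tokens(input_ids: List[int], image_token_id: int,
                        image_feature_size: int) -> List[int]:
    """Repeat each image placeholder so it occupies one position per feature."""
    expanded: List[int] = []
    for token_id in input_ids:
        if token_id == image_token_id:
            expanded.extend([image_token_id] * image_feature_size)
        else:
            expanded.append(token_id)
    return expanded


# A prompt with one image placeholder, expanded to 4 feature positions
print(expand_image_tokens([1, 32000, 29871, 13], image_token_id=32000,
                          image_feature_size=4))
# [1, 32000, 32000, 32000, 32000, 29871, 13]
```

After this step, each image occupies exactly as many positions as it has embeddings, which is what the existing attention-mask computation expects.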

ywang96 commented 2 months ago

cc @robertgshaw2-neuralmagic @mgoin (since NM is planning to work on Whisper)

Thank you all for the feedback so far! I plan to address all the feedback together after meeting with the core devs and getting more perspectives from other community members who are working, or plan to work, on multi-modal models.

Some quick ones that I can answer now:

It seems that there is no plan to refactor vision encoder (todo in llava).

@jeejeelee This will need to be done regardless since it's inside the model implementation, whereas this RFC is more about how we want to support multi-modal models in general, and thus focuses on the interface and component pattern.

However, I think it's difficult to keep the processing logics fully independent from the model architecture.

@DarkLight1337 @Isotr0py If this is just about where the processor should live, I'm indifferent to whether it lives inside LLMEngine or not. The tricky part IMO is that we would then need to rework the interface of LLMEngine to consume the outputs of AutoProcessor as-is.

I was under the impression that this is used to simplify the generation of the attention masks.

@DarkLight1337 That's correct too, but I'm worried that as the model gets more and more complicated, this approach might not be generalizable.

imarcin-rbx commented 2 months ago

Since LLMEngine already supports an output processor interface (e.g. SequenceGroupOutputProcessor), would it be reasonable to also add an InputProcessor interface within the engine?

This way, the engine can check for the existence of an input processor, while the implementation (in this case, LLaVA's single-image processing) can live outside of the engine. Its implementation could be based on AutoProcessor, as suggested.

As for processing inputs other than an image tag, or inputs in varying formats, the engine could have only a generic input processor executor; within the model executor's code, it would be up to the model implementation to define an input processor and pass it on to the engine.
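A rough sketch of the suggested input-processor hook, assuming the engine only checks whether the model supplied a processor and otherwise falls back to plain tokenization. `InputProcessor`, `SingleImageInputProcessor`, and `prepare_inputs` are hypothetical names, not existing vLLM APIs; the tokenizer is a plain callable so the sketch stays self-contained.

```python
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional

Tokenizer = Callable[[str], List[int]]


class InputProcessor(ABC):
    """Model-provided hook that turns a raw request into model inputs."""

    @abstractmethod
    def process(self, prompt: str, multi_modal_input: Any) -> Dict[str, Any]:
        ...


class SingleImageInputProcessor(InputProcessor):
    # Hypothetical LLaVA-style processor: tokenizes text, attaches the image
    def __init__(self, tokenizer: Tokenizer) -> None:
        self.tokenizer = tokenizer

    def process(self, prompt: str, multi_modal_input: Any) -> Dict[str, Any]:
        return {"input_ids": self.tokenizer(prompt),
                "pixel_values": multi_modal_input}


def prepare_inputs(prompt: str, tokenizer: Tokenizer,
                   input_processor: Optional[InputProcessor] = None,
                   multi_modal_input: Any = None) -> Dict[str, Any]:
    # Generic executor: only checks whether the model registered a processor
    if input_processor is not None and multi_modal_input is not None:
        return input_processor.process(prompt, multi_modal_input)
    return {"input_ids": tokenizer(prompt)}
```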

DarkLight1337 commented 2 months ago


@Isotr0py Perhaps we could follow a registry pattern and have each model separately register how to preprocess the inputs? If the model does not do so, then the default implementation would be to pass the data to HuggingFace processors.
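A minimal sketch of such a registry, assuming preprocessing functions are keyed by the HuggingFace `model_type` and registered with a decorator; models that register nothing fall back to a default function that stands in for calling the HuggingFace processor. `MULTI_MODAL_PREPROCESSORS`, `register_preprocessor`, and the toy functions are hypothetical.

```python
from typing import Any, Callable, Dict

PreprocessFn = Callable[[str, Any], Dict[str, Any]]

MULTI_MODAL_PREPROCESSORS: Dict[str, PreprocessFn] = {}


def register_preprocessor(model_type: str) -> Callable[[PreprocessFn], PreprocessFn]:
    def wrapper(fn: PreprocessFn) -> PreprocessFn:
        MULTI_MODAL_PREPROCESSORS[model_type] = fn
        return fn
    return wrapper


def default_preprocess(prompt: str, image: Any) -> Dict[str, Any]:
    # Placeholder for passing the data to the HuggingFace processor as-is
    return {"prompt": prompt, "pixel_values": image}


@register_preprocessor("llava")
def llava_preprocess(prompt: str, image: Any) -> Dict[str, Any]:
    # Model-specific adjustments (e.g. expanding image tokens) go here
    inputs = default_preprocess(prompt, image)
    inputs["expanded"] = True
    return inputs


def preprocess(model_type: str, prompt: str, image: Any) -> Dict[str, Any]:
    # Registered function if present, otherwise the default path
    fn = MULTI_MODAL_PREPROCESSORS.get(model_type, default_preprocess)
    return fn(prompt, image)
```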

Isotr0py commented 2 months ago


Yes, I agree that we can use a processor registry to solve this. It also seems that transformers_utils/configs could be a good reference for this.

DarkLight1337 commented 2 months ago


I have added an implementation of the processor registry to #4197.

Edit: I have also moved the specification of dummy data (for profiling) to the top-level registry. Each model can define its own dummy data by registering a factory function.
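Registering a dummy-data factory per model for profiling could look like the following, assuming the profiler asks each model for worst-case inputs given a sequence length. All names and the LLaVA-1.5-like constants are assumptions for illustration; the dummy image is a nested list of zeros rather than a real tensor so the sketch has no dependencies.

```python
from typing import Callable, Dict, List, Tuple

# Factory signature: seq_len -> (dummy token IDs, dummy multi-modal kwargs)
DummyData = Tuple[List[int], Dict[str, object]]
DummyDataFactory = Callable[[int], DummyData]

_DUMMY_DATA_FACTORIES: Dict[str, DummyDataFactory] = {}


def register_dummy_data(model_type: str) -> Callable[[DummyDataFactory], DummyDataFactory]:
    def wrapper(factory: DummyDataFactory) -> DummyDataFactory:
        _DUMMY_DATA_FACTORIES[model_type] = factory
        return factory
    return wrapper


@register_dummy_data("llava")
def llava_dummy_data(seq_len: int) -> DummyData:
    # Assumed LLaVA-1.5-like values, for illustration only
    image_token_id, feature_size, image_size = 32000, 576, 336
    if seq_len < feature_size:
        raise ValueError("seq_len must fit all image feature tokens")
    # Reserve the maximum number of image positions, pad the rest with 0
    token_ids = [image_token_id] * feature_size + [0] * (seq_len - feature_size)
    # Zero "image" of shape (channels, height, width) as nested lists
    pixel_values = [[[0.0] * image_size for _ in range(image_size)]
                    for _ in range(3)]
    return token_ids, {"pixel_values": pixel_values}


def get_dummy_data(model_type: str, seq_len: int) -> DummyData:
    factory = _DUMMY_DATA_FACTORIES.get(model_type)
    if factory is None:
        # Text-only fallback: padding tokens, no multi-modal kwargs
        return [0] * seq_len, {}
    return factory(seq_len)
```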

DarkLight1337 commented 2 months ago


To solve the prompt format problem for LLaVA, I think we have to also deal with generating the attention masks in the processing framework. That would mean abstracting some of the logic of ModelRunner._prepare_prompt.

DarkLight1337 commented 2 months ago

Just a heads up that #4228 will introduce another vision language model to vLLM, so our discussion should take that into account as well.

ywang96 commented 2 months ago

I discussed this with @zhuohan123 offline, in particular regarding this comment:

To avoid having to modify the core Engine logic each time, we can wrap the data with processor objects (with a common interface to process the data) before passing them into the Engine.

If vLLM is going to use the out-of-the-box AutoProcessor (which includes the tokenizer) anyway, then it's logical to make it an attribute of the engine (similar to what we did with the tokenizer). For now, for the sake of simplicity, we could add something like self.processor = AutoProcessor.from_pretrained(model_id) to this section if the model is an MM model. https://github.com/vllm-project/vllm/blob/15436806912d7ad9371c8bcf6a46857590c107d2/vllm/engine/llm_engine.py#L136-L139

Then, at inference time, depending on whether the request has multi-modal data, we process it with either self.tokenizer or self.processor.

(IMO eventually, there really shouldn't be a separation between how we preprocess text data and multi-modal data, as they should all go through one InputProcessor class, but that is probably a bigger engineering refactoring that we can leave for later.)

We can also add an additional parameter at the engine level to indicate that we're feeding the engine an already processed dictionary of tensors, so that the preprocessing step with self.processor is skipped. (This is very similar to prompt vs. prompt_token_ids.)
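The dispatch described above could be sketched as follows, assuming the processor is called with `text=` and `images=` keyword arguments (as HuggingFace processors such as LlavaProcessor accept) and that an already processed dict skips preprocessing, analogous to prompt vs. prompt_token_ids. `EngineInputPreparer` and `is_preprocessed` are hypothetical names; stub callables stand in for the real tokenizer and processor.

```python
from typing import Any, Callable, Dict, Optional


class EngineInputPreparer:
    def __init__(self, tokenizer: Callable[..., Dict[str, Any]],
                 processor: Optional[Callable[..., Dict[str, Any]]] = None):
        self.tokenizer = tokenizer
        # Only set for multi-modal models, e.g. loaded via AutoProcessor
        self.processor = processor

    def prepare(self, prompt: Optional[str] = None,
                multi_modal_data: Any = None,
                is_preprocessed: bool = False) -> Dict[str, Any]:
        if is_preprocessed:
            # Caller already ran the processor; pass the tensors through
            return dict(multi_modal_data)
        if multi_modal_data is None:
            # Text-only request: tokenizer path, as today
            return self.tokenizer(prompt)
        if self.processor is None:
            raise ValueError("This model does not accept multi-modal input")
        return self.processor(text=prompt, images=multi_modal_data)
```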

@DarkLight1337 @Isotr0py WDYT? Do you see any issue with this design?

DarkLight1337 commented 2 months ago


This is somewhat similar to #4166 where I load the processing logic using AutoProcessor instead of AutoTokenizer for testing the HuggingFace implementation.

I think one potential issue of this design is that the direct dependency on HuggingFace (which we have no control over) would complicate efforts to apply additional preprocessing specific to certain HuggingFace processors (e.g. to adapt to our interface).

Since @Isotr0py's comment, I have refactored the code in #4197 to use a registry pattern to apply the preprocessor, so that the MultiModalData class itself no longer has any preprocessing logic.

ywang96 commented 2 months ago

@DarkLight1337 Thanks for sharing your thoughts! @zhuohan123 and I actually discussed the use of AutoProcessor.

I think the point is that vLLM already relies on AutoTokenizer today, and most of the model implementations in vLLM are based on the implementations of those models in transformers, so I don't really think having this dependency is a big issue. Using AutoProcessor also allows us to abstract away from images in particular, so that the same interface will work for other modalities (e.g., Whisper) as well.

The original design of the prompt interface isn't very clean and is very specific to LLaVA-1.5. I would like to emphasize that not every MM model has a "vision tower + projector + LM" architecture, so IMO the input format should really be one of raw inputs (images), processed inputs (outputs of AutoProcessor) or embeddings (prompt embeddings + MM embeddings).
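One way to express these three accepted input forms is a tagged union, sketched below under the assumption that each request carries exactly one of them. The class names, fields, and the `pipeline_stages` helper are hypothetical; tensors are typed as `Any`.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Union


@dataclass
class RawInputs:
    # e.g. PIL images; vLLM would run the processor itself
    prompt: str
    images: List[Any]


@dataclass
class ProcessedInputs:
    # Output of AutoProcessor, keyed by forward() argument names
    tensors: Dict[str, Any]


@dataclass
class EmbeddingInputs:
    # Prompt embeddings plus MM embeddings, bypassing processor and encoders
    prompt_embeds: Any
    multi_modal_embeds: Any


MultiModalInputs = Union[RawInputs, ProcessedInputs, EmbeddingInputs]


def pipeline_stages(inputs: MultiModalInputs) -> str:
    # Which stages the engine still has to run for each input form
    if isinstance(inputs, RawInputs):
        return "processor + encoders + LM"
    if isinstance(inputs, ProcessedInputs):
        return "encoders + LM"
    if isinstance(inputs, EmbeddingInputs):
        return "LM only"
    raise TypeError(f"Unsupported input type: {type(inputs).__name__}")
```

Note that the embedding form makes no assumption about how the MM embeddings were produced, which is what keeps it usable for architectures without a "vision tower + projector" split.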

I will also be working on a PR so we can cross review each other's work.

zhuohan123 commented 2 months ago

One thing to add is that we would like to keep vLLM's end-user API easy to use. Having AutoProcessor outside of vLLM requires the user to create and pick the correct processor for the specific model they are using, which can be error-prone. So I lean towards having AutoProcessor in vLLM, so that an end user can directly feed a raw image (e.g. a JPEG image) to vLLM.

DarkLight1337 commented 2 months ago


In this case, we would have to refactor the computation of attention masks so that it can accept a single <image> token for LLaVA, since that is what its HuggingFace processor expects. How can we integrate this into vLLM's computation of the attention masks?

Isotr0py commented 2 months ago

Regarding #4228, I think there may be cases where some MM models don't have a Processor implemented.

In this case, we would have to refactor the computation of attention masks so that it can accept a single <image> token for LLaVA, since that is what its HuggingFace processor expects.

@DarkLight1337 IMO, one possible solution is to inherit from and modify the LLaVA processor to handle the num_features calculation, input_ids padding, etc., so that the current attention mask computation code can create the right attention masks.

DarkLight1337 commented 2 months ago


I like the idea of simply inheriting from the existing HuggingFace processor. How should we ensure that our implementation is loaded instead of the HuggingFace one?

DarkLight1337 commented 2 months ago

Also, I think that we should wrap the input prompt to LLM.generate in order to better distinguish the kwargs to pass to the HF processor from the other arguments to LLM.generate. It is rather awkward right now that we have to pass a list of multi-modal data with length equal to the number of input prompts. If we use the HF processor directly, the multi-modal inputs would become part of those kwargs instead of a separate MultiModalData instance.

Edit: Opened #4328
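As an illustration of wrapping each prompt together with its own multi-modal data (instead of passing a separate list aligned with the prompts), here is a sketch assuming a per-request TypedDict. The names `MultiModalPrompt` and `normalize_prompts` are hypothetical and not necessarily what #4328 implements.

```python
from typing import Any, Dict, List, Sequence, TypedDict, Union


class MultiModalPrompt(TypedDict, total=False):
    prompt: str
    # Forwarded as kwargs to the HF processor, e.g. {"images": image}
    multi_modal_data: Dict[str, Any]


PromptLike = Union[str, MultiModalPrompt]


def normalize_prompts(
        prompts: Union[PromptLike, Sequence[PromptLike]]) -> List[MultiModalPrompt]:
    # Accept a single prompt or a batch; bare strings become text-only requests
    if isinstance(prompts, (str, dict)):
        prompts = [prompts]
    normalized: List[MultiModalPrompt] = []
    for p in prompts:
        normalized.append({"prompt": p} if isinstance(p, str) else p)
    return normalized
```

Each request then carries its own data, so the lengths of two parallel lists can no longer get out of sync.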

DarkLight1337 commented 2 months ago

I have noticed that when using distributed inference on LLaVA-NeXT (#4199), there is a bug where the image tokens are not sent to the workers, resulting in an error when trying to merge the vision embeddings. This doesn't happen with LLaVA-1.5 because the model fits on a single GPU. Does anyone have a setup where LLaVA-1.5 is loaded across multiple GPUs, to check whether this issue also occurs in the existing vLLM code?

Edit: Never mind, it's just a typo in the chat template I passed to the command for running the OpenAI-compatible server. To avoid such confusion in the future, I have opened #4292 to detect whether the string looks like a file path.

Isotr0py commented 2 months ago

How should we ensure that our implementation is loaded instead of the HuggingFace one?

I think we can refer to get_config() in transformers_utils/config.py, but search the registered processors first and then fall back to AutoProcessor, so that get_processor() could be:

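from typing import Dict, Optional, Type

from transformers import AutoProcessor, ProcessorMixin

# Maps a HuggingFace model_type to a vLLM-specific processor class
_PROCESSOR_REGISTRY: Dict[str, Type[ProcessorMixin]] = {}


def get_processor(model: str,
                  model_type: str,
                  trust_remote_code: bool,
                  revision: Optional[str] = None,
                  code_revision: Optional[str] = None) -> ProcessorMixin:
    # Prefer a processor registered by vLLM for this model type
    if model_type in _PROCESSOR_REGISTRY:
        processor_class = _PROCESSOR_REGISTRY[model_type]
        processor = processor_class.from_pretrained(model,
                                                    revision=revision,
                                                    code_revision=code_revision)
        return processor
    # Otherwise fall back to the HuggingFace AutoProcessor
    try:
        processor = AutoProcessor.from_pretrained(
            model,
            trust_remote_code=trust_remote_code,
            revision=revision,
            code_revision=code_revision)
    except ValueError:
        # do something else (e.g. handle models without a processor)
        raise
    return processor
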
DarkLight1337 commented 2 months ago


To be honest, I'm not a big fan of having to potentially add multiple files in different places* for each new model, but I guess that would work for now. Further down the line, we could consider adopting a more explicit interface for adding new models to vLLM.

*Currently, we have to add a new file in model_executor/models and possibly transformers_utils/configs. After adding multi-modal support, we also have to worry about transformers_utils/processors.

Oliver-ss commented 2 months ago

Hi guys,

It seems the current prefix caching does not work with multi-modal models like LLaVA, because a block's hash only takes the preceding token IDs into account, and the image placeholder tokens are always the same even though the corresponding images might differ.

Do you have any ideas about:

  1. How can we give the KV blocks of a multi-modal request the correct hash ID?
  2. How can we reuse KV cache blocks that contain images? An image might span several blocks, but it is impossible to reuse only part of it because the image encoder needs to take the whole image.
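One possible direction for the first question, sketched under the assumption that each image can be hashed from its raw bytes: mix a content digest of every image overlapping a block into that block's hash, chained with the previous block's hash. `block_hash` and its parameters are hypothetical, not vLLM's prefix-caching API, and this does not address the second question of partially reusing an image.

```python
import hashlib
from typing import Optional, Sequence


def block_hash(prev_hash: Optional[str], token_ids: Sequence[int],
               image_bytes: Sequence[bytes] = ()) -> str:
    """Hash one KV block, chained to the hash of the preceding block."""
    h = hashlib.sha256()
    # Chaining makes the hash depend on the whole prefix, not just this block
    h.update((prev_hash or "").encode())
    h.update(b"|")
    h.update(",".join(str(t) for t in token_ids).encode())
    # Image placeholder tokens are identical for every image, so mix in a
    # digest of the content of each image that overlaps this block
    for data in image_bytes:
        h.update(b"|")
        h.update(hashlib.sha256(data).digest())
    return h.hexdigest()
```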

Thanks!

DarkLight1337 commented 1 month ago

For reference, I have compiled a list of GH issues that are related to this topic (updated periodically):

General issues:

Multi-modal models (help wanted!):

Multi-modal models (in progress):

juberti commented 1 month ago

Hi folks, just stumbled across this issue. It's great to see the proposed steps here as these were all things I ran into when implementing support for our multimodal audio model - in particular, not requiring pre-padding of the input with the special token is a huge simplification when the input (audio in our case) can be of arbitrary size.

The one other thing that might make sense here would be a rename of VisionLanguageConfig to MultimodalLanguageConfig or something similar, rather than having to add separate AudioLanguageConfig and other types in the future throughout all the layers of vLLM. This should be fairly straightforward with your refactoring as I think the only thing remaining in VisionLanguageConfig that we will care about will be the special token id.

DarkLight1337 commented 1 month ago

The current prompt format "<Image>" * 576 + prompt makes the underlying implementation easier (especially when it comes to profiling), but complicates the user experience compared to the HuggingFace format "<Image>\n" + prompt, and it has caused some confusion about what's needed to make multi-modal models work on vLLM.

While experimenting with the OpenAI GPT-4V API, I found that this is quite vulnerable to token injection, where the user includes <image> tokens in the text prompt. This causes the model to crash due to mismatched shapes when filling the image token positions with the embeddings.
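A simple guard against this, sketched under the assumption that each image expands to a fixed number of placeholder tokens: count the image tokens in the tokenized prompt and reject the request before the forward pass if the count does not match the number of images supplied. `validate_image_tokens` is a hypothetical helper and `32000` is only an illustrative token ID.

```python
from typing import List


def validate_image_tokens(input_ids: List[int], num_images: int,
                          image_token_id: int, image_feature_size: int) -> None:
    # Count placeholders in the tokenized prompt (including any the user
    # typed into the text) and compare with what the images will fill
    found = sum(1 for t in input_ids if t == image_token_id)
    expected = num_images * image_feature_size
    if found != expected:
        raise ValueError(
            f"Prompt contains {found} image tokens, expected {expected} "
            f"for {num_images} image(s)")
```

Raising a `ValueError` here lets the server reject the single offending request instead of crashing the model on a shape mismatch.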

ywang96 commented 1 month ago

For folks who came across this RFC, I have been working closely with @DarkLight1337 on several PRs:

The goal is to support end-to-end GPT4V-compatible inference by the upcoming major release/meetup.

nukes commented 1 month ago

Hi folks, I think that when we refactor the code, we should consider not only multi-modal input but also multi-modal output.

vLLM currently seems to support only text output; what about image/audio output?
Take audio output as an example: there is an audio codec decoder (equivalent to the text tokenizer's decoder) that decodes the audio tokens back into an audio signal. How can we fuse the audio codec model into vLLM to provide a seamless user experience?

Think about GPT-4o: can we use vLLM to host a GPT-4o-level model?

ywang96 commented 1 month ago

Hi folks, I think that when we refactor the code, we should consider not only multi-modal input but also multi-modal output.

Hey @nukes! I personally totally agree with this, but for now multi-modal output isn't yet in the scope of the project until there's a reasonable amount of open source model support and interest in it.

As usual, this is an open-source project, so any contributions/suggestions are welcome!

AmazDeng commented 1 month ago

Can vLLM support direct input of inputs_emb now? If so, we can leverage vLLM's inference capabilities with minimal changes to the model inference code. Moreover, since model architectures are diverse, having vLLM support direct input of inputs_emb would greatly enhance its applicability. Otherwise, each new model would require redevelopment, which is time-consuming and labor-intensive.

Can the generate method accept embeddings as input? How is the progress on this?

DarkLight1337 commented 1 month ago

Can vLLM support direct input of inputs_emb now? If so, we can leverage vLLM's inference capabilities with minimal changes to the model inference code. Moreover, since model architectures are diverse, having vLLM support direct input of inputs_emb would greatly enhance its applicability. Otherwise, each new model would require redevelopment, which is time-consuming and labor-intensive.

Can the generate method accept embeddings as input? How is the progress on this?

I don't think this is even supported for regular LLMs yet. We can work on that after supporting basic multimodal inputs. #5101 might also be of interest to you.

juberti commented 1 month ago

@nukes, once the PRs mentioned by @ywang96 land, I'm going to send a PR to integrate our audio model (https://ultravox.ai) using the same OpenAI interface (i.e., using image_url). It doesn't handle output yet, but this will get us halfway there.
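For reference, the "OpenAI interface" here is the GPT-4V-style Chat Completions message, in which a user turn's content is a list of typed parts and images arrive as image_url parts. The sketch below only builds that request body. The helper name build_image_chat_request is hypothetical, and whether a particular vLLM build accepts this payload depends on the PRs above having landed. The model name and image URL are the ones used elsewhere in this thread.

```python
import json


def build_image_chat_request(model: str, question: str, image_url: str) -> dict:
    """Build an OpenAI Chat Completions request body with one text part and one image part."""
    return {
        "model": model,
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": question},
                    {"type": "image_url", "image_url": {"url": image_url}},
                ],
            }
        ],
    }


body = build_image_chat_request(
    model="llava-hf/llava-v1.6-mistral-7b-hf",
    question="What is shown in this image?",
    image_url="https://h2o-release.s3.amazonaws.com/h2ogpt/bigben.jpg",
)
print(json.dumps(body, indent=2))  # JSON to POST to the server's /v1/chat/completions
```

Keeping the placeholder out of the user's text entirely is also what makes this format resistant to the token-injection problem mentioned earlier: the server, not the user, decides where image tokens go.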

nukes commented 4 weeks ago

@nukes, once the PRs mentioned by @ywang96 land, I'm going to send a PR to integrate our audio model (https://ultravox.ai) using the same OpenAI interface (i.e., using image_url). It doesn't handle output yet, but this will get us halfway there.

Nice work! Just starred your repo and joined the Discord channel!

DarkLight1337 commented 4 weeks ago

#4197 has finally been merged, with additional tests implemented in #5215.

The current prompt format "<Image>" * 576 + prompt makes the underlying implementation easier (especially when it comes to profiling), but complicates the user experience compared to the HuggingFace format "<Image>\n" + prompt, and that has caused some confusion about what's needed to make multi-modal work on vLLM.

I'm working on a new series of PRs that aim to resolve this issue by giving developers a way to indicate how to insert image tokens into the prompt on a per-model basis (instead of having to specify it via VisionLanguageConfig):
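One way to picture "per-model" image token insertion is a small registry keyed by model architecture, where each entry knows how its model expects a single image placeholder to appear in the prompt. This is a hypothetical sketch, not the design of the PRs above. IMAGE_PROMPT_FORMATS and format_image_prompt are made-up names. The LLaVA-1.5 template is assumed from the llava-hf model cards, and the LLaVA-NeXT (Mistral) template matches the [INST] example later in this thread.

```python
from typing import Callable, Dict

# Hypothetical registry: architecture name -> function that places one
# image placeholder into the user's text the way that model expects.
PromptFormatter = Callable[[str], str]

IMAGE_PROMPT_FORMATS: Dict[str, PromptFormatter] = {
    # LLaVA-1.5 chat format (assumed from the llava-hf model cards).
    "LlavaForConditionalGeneration": lambda text: f"USER: <image>\n{text}\nASSISTANT:",
    # LLaVA-NeXT (Mistral) format, as in the [INST] example in this thread.
    "LlavaNextForConditionalGeneration": lambda text: f"[INST] <image>\n{text} [/INST]",
}


def format_image_prompt(architecture: str, text: str) -> str:
    """Return the model-specific prompt containing a single <image> placeholder."""
    try:
        formatter = IMAGE_PROMPT_FORMATS[architecture]
    except KeyError:
        raise ValueError(
            f"No image prompt format registered for {architecture!r}"
        ) from None
    return formatter(text)


print(format_image_prompt("LlavaNextForConditionalGeneration", "What is shown in this image?"))
```

The user then writes only their question. Expanding the single placeholder to the full number of image tokens stays an internal step of the engine.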

akshay-loci commented 3 weeks ago

@DarkLight1337 @ywang96 Thanks for the ongoing work on improving multimodality support, looking forward to using it! Are these changes going to be part of the next release?

ywang96 commented 3 weeks ago

@DarkLight1337 @ywang96 Thanks for the ongoing work on improving multimodality support, looking forward to using it! Are these changes going to be part of the next release?

@akshay-loci we're on track to merge GPT-4V-compatible inference (i.e., the user-facing changes) mentioned here before the next release. The developer-facing changes are still WIP, and I will summarize another roadmap after the release.

For model support, it's best effort for now and I'm trying to see if we can add one or two more before the release.

ywang96 commented 2 weeks ago

For people who are following this thread and multi-modality support in general, we have made an update in the issue at the top so please take a look!

CatherineSue commented 6 days ago

Remove unnecessary attributes from VisionLanguageConfig, such as input_token_id, image_input_shape, and image_feature_size, since these values can be inferred from the HuggingFace config. It is not very friendly to have the user copy the required values from the HuggingFace repo.

@DarkLight1337 QQ: how can image_feature_size be found from the HuggingFace repo? I can't seem to find a fixed image_feature_size from model config. Does this issue have a WIP PR?

ywang96 commented 6 days ago

how can image_feature_size be found from the HuggingFace repo?

@CatherineSue To clarify, for now we need image_feature_size because the scheduler needs to know in advance how much space the image embeddings take up in the final embeddings passed to the language model. This will no longer be needed once we add the prompt preprocessing logic into vLLM (e.g., "<image> what's in this image?" -> "<image>" * 576 + " what's in this image?"). This preprocessing logic can get very complicated, especially when the final image embedding shape depends on the image size/resolution itself, and currently it is implemented in the model forward pass in transformers (so it won't work in vLLM).

We're still discussing the best way to design this preprocessing logic into vLLM to make it easier for both users and developers, but ideas are welcome!
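For fixed-resolution models like LLaVA-1.5, the 576 in the example above is not a standalone config field. It follows from the vision tower's vision_config: a CLIP ViT with image_size 336 and patch_size 14 produces (336 // 14)^2 = 24^2 = 576 patch embeddings. The sketch below derives that number and performs the "<image>" -> "<image>" * 576 expansion. Both helper names are hypothetical and do not describe vLLM's eventual design. This simple formula does not cover models such as LLaVA-NeXT, whose token count depends on the input image size.

```python
def clip_image_feature_size(image_size: int, patch_size: int) -> int:
    """Number of patch embeddings a CLIP/ViT encoder yields for a square image."""
    patches_per_side = image_size // patch_size
    return patches_per_side * patches_per_side


def expand_image_placeholder(prompt: str, feature_size: int, placeholder: str = "<image>") -> str:
    """Repeat a single image placeholder so it covers every image embedding."""
    if prompt.count(placeholder) != 1:
        raise ValueError(f"Expected exactly one {placeholder!r} in the prompt")
    return prompt.replace(placeholder, placeholder * feature_size)


# Values from a LLaVA checkpoint's vision_config (CLIP ViT-L/14 at 336 px).
vision_config = {"image_size": 336, "patch_size": 14}
size = clip_image_feature_size(vision_config["image_size"], vision_config["patch_size"])
print(size)  # 576
expanded = expand_image_placeholder("<image> what's in this image?", size)
```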

FennFlyer commented 6 days ago

I'm a bit confused about running llava-hf/llava-v1.6-mistral-7b-hf in the latest version of the Docker image. The first issue I had was the config checker looking for several key/value pairs in the config.json under text_config during startup that appear to have been moved to vision_config, such as num_attention_heads. When I copied in the needed key/values from mistralai/Mistral-7B-Instruct-v0.2 config.json, I got further in the startup, but now it's looking for image_size, which I don't see as a parameter defined in the VLM docs page. I also don't see this parameter being passed anywhere in the examples, am I missing something?

ywang96 commented 6 days ago

Hey @FennFlyer - dynamic image input shapes are not supported yet, so you might need to take a look at the example configuration mentioned in #4199. This does mean your input image will be reshaped to the configuration you specified, so the output could differ from the HuggingFace implementation. This is a limitation we're aware of, and we're currently working on it!

~EDIT: I realized your issue might not be related to this - I think HuggingFace recently updated their config files for some of these VLMs, so I need to look into this issue.~

ywang96 commented 6 days ago

@FennFlyer I was able to run the example below on the main branch

import requests
from io import BytesIO

from PIL import Image

from vllm import LLM
from vllm.multimodal.image import ImagePixelData
from vllm import SamplingParams

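llm = LLM(
    model="llava-hf/llava-v1.6-mistral-7b-hf",
    image_input_type="pixel_values",
    image_token_id=32000,  # id of the <image> token in this model's vocabulary
    image_input_shape="1,3,336,336",  # fixed input shape; images are reshaped to this
    image_feature_size=1176,  # must equal the number of <image> tokens in the prompt
)

# The prompt must contain exactly image_feature_size copies of <image>.
prompt = "[INST] " + "<image>" * 1176 + "\nWhat is shown in this image? [/INST]"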
url = "https://h2o-release.s3.amazonaws.com/h2ogpt/bigben.jpg"
image = Image.open(BytesIO(requests.get(url).content))
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=100)

outputs = llm.generate({
    "prompt": prompt,
    "multi_modal_data": ImagePixelData(image),
}, sampling_params=sampling_params)

generated_text = ""
for o in outputs:
    generated_text += o.outputs[0].text

print(f"LLM output:{generated_text}")

Output:

LLM output: The image shows a nighttime scene of the iconic Big Ben clock tower, also known as the Elizabeth Tower, which is located at the north end of the Palace of Westminster in London, UK. The tower is illuminated and stands out against the dark sky. In the foreground, there is a blurred view of traffic on a busy city street, which suggests the photo was taken from a moving vehicle or in a crowded area where motion blur is present.
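
Why 1176 here and not 576 as for LLaVA-1.5? LLaVA-NeXT also encodes high-resolution tiles chosen from image_grid_pinpoints, so its token count depends on the input size. Below is a re-derivation of that arithmetic as I understand it, modeled on the "anyres" logic in HuggingFace transformers. Treat it as a sketch that may differ from vLLM's code in edge cases; both function names are hypothetical. For a 336x336 input, the best pinpoint is 336x672 (a 1x2 tile grid, so 24x48 patches). Unpadding back to the square aspect ratio leaves 24x24 = 576, plus 24 newline tokens, plus the 576-token base image: 576 + 576 + 24 = 1176.

```python
from typing import Sequence, Tuple


def select_best_resolution(
    original_size: Tuple[int, int], candidates: Sequence[Sequence[int]]
) -> Tuple[int, int]:
    """Pick the (height, width) pinpoint that keeps the most image detail,
    breaking ties by the least wasted (padded) area."""
    orig_h, orig_w = original_size
    best, best_effective, best_wasted = None, 0, float("inf")
    for h, w in candidates:
        scale = min(w / orig_w, h / orig_h)
        down_w, down_h = int(orig_w * scale), int(orig_h * scale)
        effective = min(down_w * down_h, orig_w * orig_h)
        wasted = w * h - effective
        if effective > best_effective or (effective == best_effective and wasted < best_wasted):
            best, best_effective, best_wasted = (h, w), effective, wasted
    return best


def llava_next_feature_size(
    height: int, width: int, image_size: int, patch_size: int,
    grid_pinpoints: Sequence[Sequence[int]],
) -> int:
    npatches = image_size // patch_size          # 336 // 14 = 24 patches per side
    base = npatches * npatches                   # tokens for the downscaled overview image
    best_h, best_w = select_best_resolution((height, width), grid_pinpoints)
    cur_h = npatches * (best_h // image_size)    # patch rows across the high-res tiles
    cur_w = npatches * (best_w // image_size)    # patch columns across the high-res tiles
    # Drop the padding added when the image was fit into the tile grid.
    if width / height > cur_w / cur_h:
        cur_h = (height * cur_w) // width
    else:
        cur_w = (width * cur_h) // height
    newline = cur_h                              # one image_newline token per row
    return base + cur_h * cur_w + newline


# image_grid_pinpoints from llava-v1.6-mistral-7b-hf's config.json
GRID_PINPOINTS = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]
print(llava_next_feature_size(336, 336, 336, 14, GRID_PINPOINTS))  # 1176
```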
chuyishang commented 6 days ago

@ywang96 @DarkLight1337 Hi, not sure if this is the correct place to post about this but thank you for your hard work on multimodal support for vLLM! I'm pretty interested in contributing to the multimodal effort specifically, so I was wondering if you guys had any suggestions on good starter issues to contribute to. I spent the last week familiarizing myself with parts of the codebase, but it seems most of the labeled starter issues don't have to do with multimodal support specifically. If this is not the best place to ask, would it be possible to direct me to the correct place?

FennFlyer commented 6 days ago

Hey @FennFlyer - dynamic image input shapes are not supported yet, so you might need to take a look at the example configuration mentioned in #4199. This does mean your input image will be reshaped to the configuration you specified, so the output could differ from the HuggingFace implementation. This is a limitation we're aware of, and we're currently working on it!

Just want to make sure I'm understanding correctly: hosting this model via Docker is not working right now? I would need to use a standalone script and instantiate an LLM object for one-shot inference? Here are the relevant bits of my docker-compose in case I'm wrong about that:

command: ["--model", "${VLLM_IMAGE_MODEL_ID}", "--gpu-memory-utilization", "0.75", "--host", "0.0.0.0", "--root-path", "/vllm-server", "--image-input-type", "pixel_values", "--image-token-id", "32000", "--image-input-shape", "1,3,336,336", "--image-feature-size", "576", "--chat-template", "template_llava.jinja"]

Thanks for the assist! I'm currently maintaining an Ollama image just for LLaVA alongside our multiple vLLM instances, so being able to consolidate on one (much more performant) hosting image would be great!

DarkLight1337 commented 6 days ago

Can you provide more details about the error you're facing? By the way, check that the path to the chat template is correct.

FennFlyer commented 6 days ago

Sure! Here's the error I'm getting. I get all the usual startup output, then hit this error. It definitely looks like I'm passing a startup parameter incorrectly, but I'm having trouble tracing where the image size that llava_next.py checks is coming from. I haven't modified the vLLM image at all, and I've verified that I'm running vllm/vllm-openai:v0.5.0.post1. I have the chat template in the model folder with all the other HuggingFace files and weights.

vllm-llava-server  | [rank0]: Traceback (most recent call last):
vllm-llava-server  | [rank0]:   File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
vllm-llava-server  | [rank0]:     return _run_code(code, main_globals, None,
vllm-llava-server  | [rank0]:   File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
vllm-llava-server  | [rank0]:     exec(code, run_globals)
vllm-llava-server  | [rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/entrypoints/openai/api_server.py", line 196, in <module>
vllm-llava-server  | [rank0]:     engine = AsyncLLMEngine.from_engine_args(
vllm-llava-server  | [rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 398, in from_engine_args
vllm-llava-server  | [rank0]:     engine = cls(
vllm-llava-server  | [rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 349, in __init__
vllm-llava-server  | [rank0]:     self.engine = self._init_engine(*args, **kwargs)
vllm-llava-server  | [rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 473, in _init_engine
vllm-llava-server  | [rank0]:     return engine_class(*args, **kwargs)
vllm-llava-server  | [rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/llm_engine.py", line 236, in __init__
vllm-llava-server  | [rank0]:     self._initialize_kv_caches()
vllm-llava-server  | [rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/llm_engine.py", line 313, in _initialize_kv_caches
vllm-llava-server  | [rank0]:     self.model_executor.determine_num_available_blocks())
vllm-llava-server  | [rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/executor/gpu_executor.py", line 75, in determine_num_available_blocks
vllm-llava-server  | [rank0]:     return self.driver_worker.determine_num_available_blocks()
vllm-llava-server  | [rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
vllm-llava-server  | [rank0]:     return func(*args, **kwargs)
vllm-llava-server  | [rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/worker/worker.py", line 162, in determine_num_available_blocks
vllm-llava-server  | [rank0]:     self.model_runner.profile_run()
vllm-llava-server  | [rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
vllm-llava-server  | [rank0]:     return func(*args, **kwargs)
vllm-llava-server  | [rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/worker/model_runner.py", line 844, in profile_run
vllm-llava-server  | [rank0]:     self.execute_model(seqs, kv_caches)
vllm-llava-server  | [rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
vllm-llava-server  | [rank0]:     return func(*args, **kwargs)
vllm-llava-server  | [rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/worker/model_runner.py", line 749, in execute_model
vllm-llava-server  | [rank0]:     hidden_states = model_executable(
vllm-llava-server  | [rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
vllm-llava-server  | [rank0]:     return self._call_impl(*args, **kwargs)
vllm-llava-server  | [rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
vllm-llava-server  | [rank0]:     return forward_call(*args, **kwargs)
vllm-llava-server  | [rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/llava_next.py", line 383, in forward
vllm-llava-server  | [rank0]:     image_input = self._parse_and_validate_image_input(**kwargs)
vllm-llava-server  | [rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/llava_next.py", line 196, in _parse_and_validate_image_input
vllm-llava-server  | [rank0]:     raise ValueError("Incorrect type of image sizes. "
vllm-llava-server  | [rank0]: ValueError: Incorrect type of image sizes. Got type: <class 'NoneType'>
vllm-llava-server exited with code 0
DarkLight1337 commented 6 days ago

Sure! Here's the error I'm getting.

Are you loading the model from a remote HuggingFace repo or from a local one? Which model do you intend to use?

xwjiang2010 commented 5 days ago

Update [6/25]

Synced with @ywang96 and @DarkLight1337 offline. These are a series of tasks for the next week. The goal is to lay the foundation with a better user and developer API. This will greatly benefit the work of supporting other VLMs, passing embeddings, and adding other modalities down the road. See #5818 for upcoming API changes.

FennFlyer commented 5 days ago

@DarkLight1337 I am loading from a local repo using the same setup as for the other local chat models I serve with vLLM. I clone the HF repo and point to the directory in the docker-compose file with this config:

volumes:
      - ${MODEL_VOL}/${VLLM_IMAGE_MODEL_ID}:/vllm-workspace/${VLLM_IMAGE_MODEL_ID}
command: ["--model", "${VLLM_IMAGE_MODEL_ID}", "--gpu-memory-utilization", "0.75", "--host", "0.0.0.0", "--root-path", "/vllm-server",
      "--image-input-type", "pixel_values", "--image-token-id", "32000", "--image-input-shape", "1,3,336,336", "--image-feature-size", "576",
      "--chat-template", "template_llava.jinja"]

The model ID is loaded from my .env and points to the cloned HF model directory. I am using llava-hf/llava-v1.6-mistral-7b-hf and the directory structure is the following:

config.json
generation_config.json
.git
.gitattributes
model-00001-of-00004.safetensors
model-00002-of-00004.safetensors
model-00003-of-00004.safetensors
model-00004-of-00004.safetensors
model.safetensors.index.json
preprocessor_config.json
README.md
special_tokens_map.json
template_llava.jinja
tokenizer_config.json
tokenizer.json
tokenizer.model

Here is the current config.json I'm using. As I mentioned in my initial post, vLLM seems to look in text_config for several key/value pairs, but in the most recent pull from HF those are missing from that section, which causes startup to fail at an earlier step, so I added them back in from the mistralai/Mistral-7B-Instruct-v0.2 config.json.

{
  "architectures": [
    "LlavaNextForConditionalGeneration"
  ],
  "ignore_index": -100,
  "image_grid_pinpoints": [
    [
      336,
      672
    ],
    [
      672,
      336
    ],
    [
      672,
      672
    ],
    [
      1008,
      336
    ],
    [
      336,
      1008
    ]
  ],
  "image_token_index": 32000,
  "model_type": "llava_next",
  "projector_hidden_act": "gelu",
  "text_config": {
    "_name_or_path": "mistralai/Mistral-7B-Instruct-v0.2",
    "architectures": [
      "LlavaLlamaForCausalLM"
    ],
    "attention_dropout": 0.0,
    "bos_token_id": 1,
    "eos_token_id": 2,
    "hidden_act": "silu",
    "hidden_size": 4096,
    "initializer_range": 0.02,
    "intermediate_size": 14336,
    "max_position_embeddings": 32768,
    "model_type": "llava",
    "num_attention_heads": 32,
    "num_hidden_layers": 32,
    "num_key_value_heads": 8,
    "rms_norm_eps": 1e-05,
    "rope_theta": 1000000.0,
    "sliding_window": null,
    "tie_word_embeddings": false,
    "torch_dtype": "bfloat16",
    "vocab_size": 32064
  },
  "torch_dtype": "float16",
  "transformers_version": "4.39.0.dev0",
  "use_image_newline_parameter": true,
  "vision_config": {
    "hidden_size": 1024,
    "image_size": 336,
    "intermediate_size": 4096,
    "model_type": "clip_vision_model",
    "num_attention_heads": 16,
    "num_hidden_layers": 24,
    "patch_size": 14,
    "projection_dim": 768,
    "vocab_size": 32000
  },
  "vision_feature_layer": -2,
  "vision_feature_select_strategy": "default",
  "vocab_size": 32064
}
ywang96 commented 5 days ago
llava-hf/llava-v1.6-mistral-7b-hf

@FennFlyer could you open a separate issue so we can keep this issue clean just for discussion around multi-modality refactoring?

BTW - I just checked the latest HuggingFace repo, and it does have text_config in the config.json file. https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf/blob/main/config.json