LLaVA-VL / LLaVA-NeXT


[BUG] Function `prepare_inputs_labels_for_multimodal` flattens batch data #153

Open guyazran opened 3 months ago

guyazran commented 3 months ago

In the file `llava/model/llava_arch.py`, the class `LlavaMetaForCausalLM` has a function `prepare_inputs_labels_for_multimodal` that is called from both `generate` and `forward`. In lines 411 and 412, the input embeds are truncated:

```python
new_input_embeds = [x[:tokenizer_model_max_length] for x, modality in zip(new_input_embeds, modalities)]
new_labels = [x[:tokenizer_model_max_length] for x, modality in zip(new_labels, modalities)]
```

When I run with images, `modalities` is simply the list `["images"]`, so if there are multiple inputs in `new_input_embeds`, everything past the first entry is silently dropped by the `zip`. Removing `modalities` from these lines fixes the issue for me.
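
A minimal sketch of the fix described above, dropping `modalities` from the zip so every sequence in the batch is truncated (variable names taken from the snippet above; the rest of the function is unchanged):

```python
# Truncate every entry in the batch rather than zipping against `modalities`,
# which silently discards entries when len(modalities) < len(new_input_embeds).
new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
```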

VittorioRossi commented 3 months ago

I encountered a problem with batch processing while using the `LlavaQwenForCausalLM` model. Thanks to @guyazran's suggestion, I found a solution.

The right approach is to provide a list of all modalities to the `modalities` parameter. I don't think this is a bug, but rather a deliberate design choice by the developers: it allows flexibility when a batch mixes several modalities.

To summarize: for correct batch processing, we need to add one entry to the `modalities` parameter for each prompt in the batch when calling the `generate` function.
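
As a concrete illustration, a hypothetical batched `generate` call is sketched below. The `images`, `image_sizes`, and `max_new_tokens` arguments follow common LLaVA-NeXT usage but may differ in your version, and the exact modality string (`"image"` vs. `"images"`) depends on the release you are running:

```python
# Hypothetical batched generate() call: one entry in `modalities` per prompt,
# so prepare_inputs_labels_for_multimodal keeps every sequence in the batch.
batch_size = input_ids.shape[0]
output_ids = model.generate(
    input_ids,                          # (batch_size, seq_len) token ids
    images=image_tensors,               # preprocessed image tensors for the batch
    image_sizes=image_sizes,            # original (width, height) per image
    modalities=["image"] * batch_size,  # one entry per prompt, not just ["image"]
    max_new_tokens=128,
)
```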

guyazran commented 3 months ago

Thanks @VittorioRossi, I hadn't thought about it like that. This makes perfect sense. But still, the default removes the rest of the batch without any warning or explanation. This is confusing behavior, especially because the place where it happens (lines 411 and 412) doesn't actually use the modality value from the `zip`. If the point is to deliberately limit the batch size to match the number of provided modalities, then I suggest implementing one of these sensible behaviors: