outlines-dev / outlines

Structured Text Generation
https://outlines-dev.github.io/outlines/
Apache License 2.0

Support for multi-modal models #662

Closed: rlouf closed this 1 week ago

rlouf commented 5 months ago

Presentation of the new feature

There are more and more accessible multi-modal models out there, such as llava, and constrained generation applies to every auto-regressive text generation model regardless of its input: the logit masking only touches the decoder's next-token distribution, so adding image inputs changes nothing about the constraint machinery.

Where does it fit in Outlines?

Maybe the most reasonable way would be to let users pass (prompt, image) tuples to the API functions and use multipledispatch to dispatch on both the model and the prompt type. Or we could create a new MultimodalModel class and keep dispatching only on the model type, as we currently do.

We need to make sure users can't unknowingly shoot themselves in the foot; the MultimodalModel class would make this easy.
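
For illustration, here is roughly what the two options could look like from the user's side. This is a hypothetical sketch: neither API exists yet, and the constructors and call signatures are placeholders, not Outlines' actual interface.

    import outlines
    from PIL import Image

    prompt = "<image>\nUSER: What's the content of the image?\nASSISTANT:"
    image = Image.open("stop_sign.jpg")

    # Option 1: keep the existing model type and dispatch on the input,
    # so a (prompt, image) tuple routes to a multimodal code path.
    model = outlines.models.transformers("llava-hf/llava-1.5-7b-hf")  # hypothetical
    generator = outlines.generate.text(model)
    answer = generator((prompt, image))

    # Option 2: a dedicated MultimodalModel, so dispatch stays on the model
    # type and passing a bare string can be rejected up front.
    model = outlines.models.MultimodalModel("llava-hf/llava-1.5-7b-hf")  # hypothetical
    generator = outlines.generate.text(model)
    answer = generator(prompt, image)

Option 2 makes the foot-gun protection straightforward: a generator built from a MultimodalModel can validate at call time that it actually received an image.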

My main concern is that we might need to make the generator more complex, or duplicate part of it.

Are you willing to open a PR?

Yes, although I'd appreciate it if someone else were willing to take the lead. Happy to help with the design.

lapp0 commented 5 months ago

Here is what the transformers interface looks like:

https://github.com/huggingface/transformers/blob/354775bc5755c4a6c47e008d28f27f8ccdcf8f8f/src/transformers/models/llava/modeling_llava.py#L377-L395

    >>> from PIL import Image
    >>> import requests
    >>> from transformers import AutoProcessor, LlavaForConditionalGeneration

    >>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
    >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

    >>> prompt = "<image>\nUSER: What's the content of the image?\nASSISTANT:"
    >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> inputs = processor(text=prompt, images=image, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(**inputs, max_length=30)
    >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "\nUSER: What's the content of the image?\nASSISTANT: The image features a stop sign on a street corner"

inputs contains input_ids, attention_mask, and pixel_values.
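
For reference, a quick way to inspect what the processor returns (the text sequence length here is illustrative; the pixel_values shape follows llava-1.5's 336x336 vision tower):

    >>> {name: tuple(t.shape) for name, t in inputs.items()}
    {'input_ids': (1, 23), 'attention_mask': (1, 23), 'pixel_values': (1, 3, 336, 336)}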

I agree regarding the complexity: the generator would need to manage pixel_values as well. The main difference would be augmenting the attention mask, which would need to happen in sequence_generator, since this augmentation is applied on every forward pass. Here is how transformers does it:

https://github.com/huggingface/transformers/blob/354775bc5755c4a6c47e008d28f27f8ccdcf8f8f/src/transformers/models/llava/modeling_llava.py#L430-L433
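
To make the extra bookkeeping concrete, here is a minimal sketch of a constrained generation loop for llava. It assumes greedy decoding, batch size 1, and no KV cache, and logits_processor is a placeholder for Outlines' FSM-based token masking, not an existing function:

    import torch

    @torch.no_grad()
    def constrained_generate(model, processor, prompt, image, logits_processor, max_new_tokens=32):
        inputs = processor(text=prompt, images=image, return_tensors="pt")
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        pixel_values = inputs["pixel_values"]

        for _ in range(max_new_tokens):
            # pixel_values is passed on every forward pass; the llava forward
            # re-expands the attention mask over the image patch positions
            # itself (the lines linked above), so the generator only has to
            # track the text-level mask.
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                pixel_values=pixel_values,
            )
            next_logits = outputs.logits[:, -1, :]
            # Constrained generation: mask out tokens the FSM disallows.
            next_logits = logits_processor(input_ids, next_logits)
            next_token = next_logits.argmax(dim=-1, keepdim=True)
            input_ids = torch.cat([input_ids, next_token], dim=-1)
            # The text-level attention mask must grow by one each step.
            attention_mask = torch.cat(
                [attention_mask, attention_mask.new_ones((1, 1))], dim=-1
            )
            if next_token.item() == processor.tokenizer.eos_token_id:
                break
        return processor.batch_decode(input_ids, skip_special_tokens=True)[0]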

I propose:

Would love to know your thoughts!

Reichenbachian commented 4 months ago

Any updates on this thread?

Kamakshi8104 commented 4 months ago

Hey! I just wanted to know whether multimodal models can be used with the connector being implemented in issue #728.

rlouf commented 4 months ago

Yes, you should be able to use this with multimodal models!