outlines-dev / outlines

Structured Text Generation
https://outlines-dev.github.io/outlines/
Apache License 2.0

Vision LLMs and Outlines #787

Open MoritzLaurer opened 2 months ago

MoritzLaurer commented 2 months ago

Vision LLMs like Llava or Idefics are becoming more and more popular for processing images and text jointly. For example, Hugging Face will soon release Idefics2, and Meta's Llama3 might also be multimodal.

Vision LLMs are also just LLMs that produce probabilities/logits over tokens, so my understanding is that they should also be compatible with outlines. The main challenge is that preprocessing is a bit more complicated: Hugging Face handles this with a processor class that handles both text and images. These models also have a tokenizer that can be used during decoding/sampling. See e.g. idefics1: docs and usage examples
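
To make the "just logits over tokens" point concrete, here is a minimal sketch of one constrained decoding step. It assumes nothing about outlines internals: the vocabulary, the logits and the `constrain` helper are made up for illustration, and in a vision LLM the logits would come from a forward pass over both text and image inputs.

```python
import math

def constrain(logits, allowed_ids):
    """Return a copy of logits with every token outside allowed_ids set to -inf."""
    allowed = set(allowed_ids)
    return [x if i in allowed else -math.inf for i, x in enumerate(logits)]

def softmax(logits):
    """Numerically stable softmax; entries at -inf end up with probability 0."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical toy vocabulary and one step of next-token logits.
vocab = ["{", "}", "cat", "dog", '"']
logits = [0.1, 0.3, 2.5, 1.0, 0.2]

# Suppose the grammar only allows "{" or a double quote at this position.
masked = constrain(logits, allowed_ids=[0, 4])
probs = softmax(masked)
best = vocab[max(range(len(probs)), key=probs.__getitem__)]
```

The unconstrained favourite ("cat") is ruled out, and the remaining probability mass is renormalised over the allowed tokens only. Nothing in this step depends on how the logits were produced.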

Two questions:

  1. What would be the best "hacky" way of using outlines with a vision LLM? Would subclassing the SequenceGenerator and changing the __call__ and sequence_generator functions to adapt it to a specific model and processor make sense? Are there other places in the codebase that would be affected? https://github.com/outlines-dev/outlines/blob/main/outlines/generate/api.py#L15

  2. Are you open to integrating Vision LLMs more systematically into outlines?

rlouf commented 2 months ago

Thank you for opening an issue!

Vision LLMs are also just LLMs that produce probabilities/logits over tokens, so my understanding is that they should also be compatible with outlines.

Indeed!

  1. What would be the best "hacky" way of using outlines with a vision LLM? Would subclassing the SequenceGenerator and changing the __call__ and sequence_generator functions to adapt it to a specific model and processor make sense? Are there other places in the codebase that would be affected?

You can probably use the PrefixAllowedTokens classes, as demonstrated in this example

https://github.com/outlines-dev/outlines/blob/main/outlines/generate/api.py#L15
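
For readers unfamiliar with the mechanism, the sketch below shows the general shape of a prefix-allowed-tokens callback, i.e. a function of `(batch_id, input_ids)` that returns the token ids permitted at the next step, which is the shape Hugging Face `generate(prefix_allowed_tokens_fn=...)` accepts. It is not the outlines implementation: the token ids, the `make_prefix_allowed_tokens_fn` helper and the "yes/no then stop" rule are hypothetical, and plain lists stand in for tensors.

```python
from typing import Callable, List, Sequence

# Hypothetical token ids for a toy vocabulary.
YES_ID, NO_ID, EOS_ID = 0, 1, 2

def make_prefix_allowed_tokens_fn(prompt_len: int) -> Callable[[int, Sequence[int]], List[int]]:
    """Build a callback mapping (batch_id, input_ids) to the allowed next token ids."""
    def allowed(batch_id: int, input_ids: Sequence[int]) -> List[int]:
        generated = len(input_ids) - prompt_len
        if generated == 0:
            return [YES_ID, NO_ID]  # the first generated token must be "yes" or "no"
        return [EOS_ID]             # after that, only end-of-sequence is allowed
    return allowed

fn = make_prefix_allowed_tokens_fn(prompt_len=3)
step0 = fn(0, [10, 11, 12])          # nothing generated yet
step1 = fn(0, [10, 11, 12, YES_ID])  # one token generated
```

Because the callback only looks at token ids, it does not need to know whether the prompt also carried image inputs.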

  2. Are you open to integrating Vision LLMs more systematically into outlines?

Yes, I would definitely like to have an easier integration!

MoritzLaurer commented 2 months ago

Thanks for the response. I've dug a bit deeper into this and it seems to be more complicated than that. While standard LLM inputs are ['input_ids', 'attention_mask'], the inputs for idefics are ['input_ids', 'attention_mask', 'pixel_values', 'image_attention_mask'] to account for the image input.

See here:

import torch
from transformers import IdeficsForVisionText2Text, AutoProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

# loading a very small version of idefics for faster testing
checkpoint = "HuggingFaceM4/tiny-random-idefics"  #"HuggingFaceM4/idefics-9b-instruct"
# device_map="auto" already places the weights, so no extra .to(device) is needed
model = IdeficsForVisionText2Text.from_pretrained(
    checkpoint,
    device_map="auto",
)

processor = AutoProcessor.from_pretrained(checkpoint)

# We feed to the model an arbitrary sequence of text strings and images. Images can be either URLs or PIL Images.
prompts = [
    [
        "https://upload.wikimedia.org/wikipedia/commons/8/86/Id%C3%A9fix.JPG",
        "User: Please identify all entities in this image and return their names as a python list.",
        "<end_of_utterance>",
        "\nAssistant:"
    ],
]

inputs = processor(prompts, return_tensors="pt").to(device)
print(inputs.keys())
# dict_keys(['input_ids', 'attention_mask', 'pixel_values', 'image_attention_mask'])

I've tested updating your Transformers class, sequence_generator and SequenceGenerator to accept these additional inputs and to use the processor class for input preprocessing in addition to the tokenizer, but it's currently failing on the image_attention_mask. I assume that the image_attention_mask needs to be dynamically updated in SequenceGenerator.__call__ (?), but I don't understand how to do this.

When I call the rewritten classes, I get this error:

from outlines.models import transformers
import sys
print('outlines.models.transformers' in sys.modules)
import importlib
transformers = sys.modules.get('outlines.models.transformers')
if transformers:
    importlib.reload(transformers)

model_ol = transformers.Transformers(model, processor.tokenizer, processor)

from pydantic import BaseModel
from outlines import generate

# Reload outlines.generate so the notebook picks up the local edits
# (sys and importlib are already imported above)
print('outlines.generate' in sys.modules)  # True once the module has been imported
generate = sys.modules.get('outlines.generate')
if generate:
    importlib.reload(generate)

class PydanticSchema(BaseModel):
    entity_names: str

prompts = [
    [
        "https://upload.wikimedia.org/wikipedia/commons/8/86/Id%C3%A9fix.JPG",
        "User: Please identify all entities in this image and return their names as a python list.",
        "<end_of_utterance>",
        "\nAssistant:"
    ],
]

generator = generate.json(model_ol, PydanticSchema)
result = generator(prompts)
print(result)

{
    "name": "ValueError",
    "message": "Attention mask should be of size (1, 1, 1, 16), but is torch.Size([1, 1, 33, 16])",
    "stack": "---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[25], line 28
     26 #model = models.transformers(\"mistralai/Mistral-7B-v0.1\")
     27 generator = generate.json(model_ol, PydanticSchema)
---> 28 result = generator(prompts)
     29 print(result)
     30 # User(name=\"John\", last_name=\"Doe\", id=11)

File ~/outlines/outlines/outlines/generate/api.py:215, in SequenceGenerator.__call__(self, prompts, max_tokens, stop_at, rng)
    213 while True:
    214     try:
--> 215         last_state = next(states)
    216         if max_tokens or stop_sequences:
    217             token_ids = last_state.token_ids

File ~/outlines/outlines/outlines/generate/generator.py:75, in sequence_generator(model, sampler, fsms, token_ids, sequence_weights, attention_masks, pixel_values, image_attention_mask, fsm_states, rng)
     73 print(\"kv_cache:\", kv_cache)
     74 try:
---> 75     logits, kv_cache = model(token_ids, attention_masks, pixel_values, image_attention_mask, kv_cache)
     76     #print(\"kv_cache:\", kv_cache)
     77 except IndexError:  # Exceeding the context length

File ~/outlines/outlines/outlines/models/transformers.py:175, in Transformers.__call__(self, input_ids, attention_mask, pixel_values, image_attention_mask, past_key_values)
    172     input_ids = input_ids[..., -1].unsqueeze(-1)  
    174 #logits, kv_cache, hidden_image_state = self.model.forward(
--> 175 outputs = self.model.forward(
    176     input_ids,
    177     attention_mask,
    178     None,
    179     past_key_values,
    180     None,
    181     pixel_values,
    182     None,
    183     None,
    184     image_attention_mask,
    185     None,
    186     None,
    187     None,
    188     #interpolate_pos_encoding: Optional[bool] = False,
    189     #return_dict: Optional[bool] = None,
    190 )
    191 #print(outputs)
    192 logits = outputs.logits

File /opt/conda/lib/python3.10/site-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    164         output = module._old_forward(*args, **kwargs)
    165 else:
--> 166     output = module._old_forward(*args, **kwargs)
    167 return module._hf_hook.post_forward(module, output)

File /opt/conda/lib/python3.10/site-packages/transformers/models/idefics/modeling_idefics.py:1485, in IdeficsForVisionText2Text.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, pixel_values, image_encoder_embeddings, perceiver_embeddings, image_attention_mask, labels, use_cache, output_attentions, output_hidden_states, interpolate_pos_encoding, return_dict)
   1482 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
   1484 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1485 outputs = self.model(
   1486     input_ids=input_ids,
   1487     attention_mask=attention_mask,
   1488     position_ids=position_ids,
   1489     past_key_values=past_key_values,
   1490     inputs_embeds=inputs_embeds,
   1491     pixel_values=pixel_values,
   1492     image_encoder_embeddings=image_encoder_embeddings,
   1493     perceiver_embeddings=perceiver_embeddings,
   1494     image_attention_mask=image_attention_mask,
   1495     use_cache=use_cache,
   1496     output_attentions=output_attentions,
   1497     output_hidden_states=output_hidden_states,
   1498     interpolate_pos_encoding=interpolate_pos_encoding,
   1499     return_dict=return_dict,
   1500 )
   1502 hidden_states = outputs[0]
   1503 logits = self.lm_head(hidden_states)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /opt/conda/lib/python3.10/site-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    164         output = module._old_forward(*args, **kwargs)
    165 else:
--> 166     output = module._old_forward(*args, **kwargs)
    167 return module._hf_hook.post_forward(module, output)

File /opt/conda/lib/python3.10/site-packages/transformers/models/idefics/modeling_idefics.py:1327, in IdeficsModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, pixel_values, image_encoder_embeddings, perceiver_embeddings, image_attention_mask, use_cache, output_attentions, output_hidden_states, interpolate_pos_encoding, return_dict)
   1310     layer_outputs = self._gradient_checkpointing_func(
   1311         vblock,
   1312         decoder_layer,
   (...)
   1324         self.gated_cross_attn_layers,
   1325     )
   1326 else:
-> 1327     layer_outputs = vblock(
   1328         decoder_layer,
   1329         hidden_states,
   1330         attention_mask=attention_mask,
   1331         position_ids=position_ids,
   1332         past_key_value=past_key_value,
   1333         image_hidden_states=image_hidden_states,
   1334         image_attention_mask=image_attention_mask,
   1335         cross_attention_gate=cross_attention_gate,
   1336         output_attentions=output_attentions,
   1337         use_cache=use_cache,
   1338         layer_idx=idx,
   1339         cross_layer_interval=self.cross_layer_interval,
   1340         gated_cross_attn_layers=self.gated_cross_attn_layers,
   1341     )
   1343 hidden_states = layer_outputs[0]
   1345 if use_cache:

File /opt/conda/lib/python3.10/site-packages/transformers/models/idefics/modeling_idefics.py:1279, in IdeficsModel.forward.<locals>.vblock(main_block, hidden_states, attention_mask, position_ids, past_key_value, image_hidden_states, image_attention_mask, cross_attention_gate, output_attentions, use_cache, layer_idx, cross_layer_interval, gated_cross_attn_layers)
   1277 if layer_idx % cross_layer_interval == 0:
   1278     xblock = gated_cross_attn_layers[layer_idx // cross_layer_interval]
-> 1279     outputs = xblock(
   1280         hidden_states,
   1281         attention_mask=attention_mask,
   1282         image_hidden_states=image_hidden_states,
   1283         image_attention_mask=image_attention_mask,
   1284         cross_attention_gate=cross_attention_gate,
   1285         output_attentions=output_attentions,
   1286         use_cache=use_cache,
   1287         past_key_value=None,  # not implemented
   1288     )
   1289     hidden_states = outputs[0]
   1291 layer_outputs = main_block(
   1292     hidden_states,
   1293     attention_mask=attention_mask,
   (...)
   1297     use_cache=use_cache,
   1298 )

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /opt/conda/lib/python3.10/site-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    164         output = module._old_forward(*args, **kwargs)
    165 else:
--> 166     output = module._old_forward(*args, **kwargs)
    167 return module._hf_hook.post_forward(module, output)

File /opt/conda/lib/python3.10/site-packages/transformers/models/idefics/modeling_idefics.py:889, in IdeficsGatedCrossAttentionLayer.forward(self, hidden_states, attention_mask, image_hidden_states, image_attention_mask, cross_attention_gate, output_attentions, use_cache, past_key_value)
    886 hidden_states = self.input_layernorm(hidden_states)
    888 # Self Attention
--> 889 hidden_states, self_attn_weights, present_key_value = self.cross_attn(
    890     hidden_states=hidden_states,
    891     key_value_states=image_hidden_states,
    892     attention_mask=image_attention_mask,
    893     output_attentions=output_attentions,
    894 )
    895 hidden_states = nn.functional.dropout(hidden_states, p=self.config, training=self.training)
    896 # Fill in zeros for cross_attention hidden_states of tokens attending to no images

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /opt/conda/lib/python3.10/site-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    164         output = module._old_forward(*args, **kwargs)
    165 else:
--> 166     output = module._old_forward(*args, **kwargs)
    167 return module._hf_hook.post_forward(module, output)

File /opt/conda/lib/python3.10/site-packages/transformers/models/idefics/modeling_idefics.py:658, in IdeficsAttention.forward(self, hidden_states, key_value_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache)
    656 if attention_mask is not None:
    657     if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
--> 658         raise ValueError(
    659             f\"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}\"
    660         )
    662 # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
    663 # Reference: https://github.com/pytorch/pytorch/issues/112577.
    664 if query_states.device.type == \"cuda\" and attention_mask is not None:

ValueError: Attention mask should be of size (1, 1, 1, 16), but is torch.Size([1, 1, 33, 16])"
}
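
One possible reading of the shapes in this error: once a KV cache exists, `input_ids` is cut down to the last token (query length 1), while `image_attention_mask` still covers all 33 text positions. Under that assumption, the mask would need the same trimming. The sketch below illustrates the idea with nested lists in place of tensors; `step_inputs` and `shape` are hypothetical helpers, and the torch equivalent of the trimming would be `image_attention_mask[:, -1:, :]`.

```python
from typing import List, Tuple

Mask = List[List[List[int]]]  # (batch, text_len, num_images)

def step_inputs(
    input_ids: List[List[int]], image_attention_mask: Mask, has_cache: bool
) -> Tuple[List[List[int]], Mask]:
    """Hypothetical helper: with a KV cache, keep only the last text position
    of both the token ids and the image attention mask."""
    if not has_cache:
        return input_ids, image_attention_mask
    last_ids = [row[-1:] for row in input_ids]
    last_mask = [rows[-1:] for rows in image_attention_mask]
    return last_ids, last_mask

def shape(x) -> tuple:
    """Shape of a rectangular nested list, for inspection only."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0]
    return tuple(dims)

ids = [list(range(33))]            # one sequence of 33 token ids
mask = [[[1] for _ in range(33)]]  # every text position attends to the single image
ids_step, mask_step = step_inputs(ids, mask, has_cache=True)
```

After trimming, the text dimension of the mask is 1, matching the query length the attention layer expects. The first forward pass (no cache) keeps the full mask.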

rlouf commented 2 months ago

I think that for now it would be much faster / easier to add a logits processor in this file for multimodal models. I am soon going to change the transformers integration to use logits processors (plus add pipelines), and keep SequenceGenerator for our internal experiments with sampling algorithms.
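
As a rough illustration of why a logits processor sidesteps the preprocessing problem, here is a toy sketch. It is not the outlines or transformers implementation: `AllowedTokensProcessor`, the hand-written FSM and `fake_vision_model` are all hypothetical, lists replace tensors, and the batch dimension is dropped. The only thing borrowed from Hugging Face is the call shape of a `LogitsProcessor`, which maps `(input_ids, scores)` to new scores.

```python
import math
from typing import Dict, List, Sequence

class AllowedTokensProcessor:
    """Toy processor: masks every score whose token the FSM does not allow."""

    def __init__(self, transitions: Dict[int, Dict[int, int]], prompt_len: int):
        # transitions[state][token_id] -> next state (hypothetical hand-written FSM)
        self.transitions = transitions
        self.prompt_len = prompt_len

    def _state(self, input_ids: Sequence[int]) -> int:
        """Replay the generated tokens through the FSM to find the current state."""
        state = 0
        for tok in input_ids[self.prompt_len:]:
            state = self.transitions[state][tok]
        return state

    def __call__(self, input_ids: Sequence[int], scores: List[float]) -> List[float]:
        allowed = self.transitions[self._state(input_ids)]
        return [s if i in allowed else -math.inf for i, s in enumerate(scores)]

def fake_vision_model(input_ids, pixel_values) -> List[float]:
    """Stand-in for a multimodal forward pass; always prefers token 3."""
    return [0.0, 0.5, 1.0, 9.0]

# FSM: state 0 allows token 1, state 1 allows token 2, state 2 allows nothing.
fsm = {0: {1: 1}, 1: {2: 2}, 2: {}}
processor = AllowedTokensProcessor(fsm, prompt_len=2)

ids = [7, 8]      # prompt token ids
pixels = [[0.0]]  # image inputs go to the model only, never to the processor
for _ in range(2):
    scores = processor(ids, fake_vision_model(ids, pixels))
    ids.append(max(range(len(scores)), key=scores.__getitem__))
```

The processor never sees `pixel_values` or `image_attention_mask`, so the model's own generation code stays responsible for them.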

We can add models.idefics and models.llava for now, unless you have the equivalent of AutoModelForCausalLM for multimodal models. These changes should be fairly quick.

I am happy to prioritize this.

scottrblock commented 2 months ago

Hi, I'll just chime in to say I'm very happy to see this is being discussed, and I will likely start experimenting with outlines once those vision models are added. Sounds great!

MoritzLaurer commented 3 weeks ago

FYI: Thanks to Outlines, Hugging Face TGI now supports constrained generation with e.g. JSON mode with vision LLMs like Idefics2, see the docs here: https://huggingface.co/docs/text-generation-inference/basic_tutorials/visual_language_models#combining-vision-language-models-with-other-features
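
For orientation, a request body for this kind of call might look roughly like the sketch below. The field names (`inputs`, `parameters`, `grammar` with `type` and `value`) and the markdown-style image syntax are assumptions based on the TGI documentation and should be checked against the linked page; the schema mirrors the PydanticSchema used earlier in this thread, and the prompt text is made up. The snippet only builds the JSON body and sends nothing.

```python
import json

image_url = "https://upload.wikimedia.org/wikipedia/commons/8/86/Id%C3%A9fix.JPG"

# JSON Schema equivalent of the single-field PydanticSchema from earlier in the thread.
schema = {
    "type": "object",
    "properties": {"entity_names": {"type": "string"}},
    "required": ["entity_names"],
}

payload = {
    # Assumed convention: the image is embedded in the prompt as a markdown image.
    "inputs": f"![]({image_url})Please identify all entities in this image.\n\n",
    "parameters": {
        "max_new_tokens": 64,
        # Assumed shape of the grammar parameter for JSON-constrained output.
        "grammar": {"type": "json", "value": schema},
    },
}

body = json.dumps(payload)
```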

houstonlucas commented 2 weeks ago

FYI: Thanks to Outlines, Hugging Face TGI now supports constrained generation with e.g. JSON mode with vision LLMs like Idefics2, see the docs here: https://huggingface.co/docs/text-generation-inference/basic_tutorials/visual_language_models#combining-vision-language-models-with-other-features

On that page there is a link that I expect should go to the example (https://huggingface.co/docs/conceptual/guided-generation), but I'm getting a 404. Is there an active page, or did they take it down for some reason?

MoritzLaurer commented 2 weeks ago

@houstonlucas I think the dead link is actually a mistake in the docs. It should probably go here: https://huggingface.co/docs/text-generation-inference/conceptual/guidance