Closed: MoritzLaurer closed this issue 3 months ago.
Thank you for opening an issue!
> Vision LLMs are also just LLMs that produce probabilities/logits over tokens, so my understanding is that they should also be compatible with outlines.
Indeed!
> - What would be the best "hacky" way of using outlines with a vision LLM? Would subclassing the `SequenceGenerator` and changing the `__call__` and `sequence_generator` functions to adapt it to a specific model and processor make sense? Are there other places in the codebase that would be affected?
You can probably use the `PrefixAllowedTokens` classes, as demonstrated in this example: https://github.com/outlines-dev/outlines/blob/main/outlines/generate/api.py#L15
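For illustration only (this is not outlines' own API): since `IdeficsForVisionText2Text` exposes a standard `generate()`, the same idea can be sketched with transformers' generic `prefix_allowed_tokens_fn` hook, leaving the multimodal inputs untouched and only filtering which tokens may be sampled. Here `fsm` and `allowed_token_ids_for` are hypothetical stand-ins for the FSM-driven constraint logic, and `model`, `processor`, and `inputs` are assumed to be an idefics checkpoint and its processed multimodal inputs (as in the snippet further down).

```python
# Minimal sketch, not outlines' actual implementation: constrain a vision LLM
# through transformers' `prefix_allowed_tokens_fn`, which works for any model
# that implements `generate()`, including IdeficsForVisionText2Text.
def make_prefix_fn(fsm, prompt_len):
    def prefix_allowed_tokens_fn(batch_id, input_ids):
        generated = input_ids[prompt_len:].tolist()
        # hypothetical helper: map the tokens generated so far to the set of
        # token ids the FSM allows next
        return allowed_token_ids_for(fsm, generated)
    return prefix_allowed_tokens_fn

outputs = model.generate(
    **inputs,  # includes pixel_values and image_attention_mask from the processor
    max_new_tokens=64,
    prefix_allowed_tokens_fn=make_prefix_fn(fsm, inputs["input_ids"].shape[1]),
)
print(processor.tokenizer.decode(outputs[0], skip_special_tokens=True))
```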
> - Are you open to integrating Vision LLMs more systematically into outlines?
Yes, I would definitely like to have an easier integration!
Thanks for the response. I've dug a bit deeper into this and it seems to be more complicated than that. While standard LLM inputs are `['input_ids', 'attention_mask']`, the inputs for idefics are `['input_ids', 'attention_mask', 'pixel_values', 'image_attention_mask']` to account for the image input. See here:
```python
import torch
from transformers import IdeficsForVisionText2Text, AutoProcessor, BitsAndBytesConfig

device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

# loading a very small version of idefics for faster testing
checkpoint = "HuggingFaceM4/tiny-random-idefics"  # "HuggingFaceM4/idefics-9b-instruct"

model = IdeficsForVisionText2Text.from_pretrained(
    checkpoint,
    device_map="auto",
).to(device)
processor = AutoProcessor.from_pretrained(checkpoint)

# We feed to the model an arbitrary sequence of text strings and images. Images can be either URLs or PIL Images.
prompts = [
    [
        "https://upload.wikimedia.org/wikipedia/commons/8/86/Id%C3%A9fix.JPG",
        "User: Please identify all entities in this image and return their names as a python list.",
        "<end_of_utterance>",
        "\nAssistant:",
    ],
]

inputs = processor(prompts, return_tensors="pt").to(device)
print(inputs.keys())
# dict_keys(['input_ids', 'attention_mask', 'pixel_values', 'image_attention_mask'])
```
I've tested updating your `Transformers` class, `sequence_generator`, and `SequenceGenerator` to accept these additional inputs and to use the `processor` class for input preprocessing in addition to the tokenizer. It currently fails on the `image_attention_mask`. I assume that the `image_attention_mask` probably needs to be dynamically updated in `SequenceGenerator.__call__` (?), but I don't understand how to do this.
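A hedged sketch of what that dynamic update might look like inside the adapted `Transformers.__call__` (an untested assumption, not a verified fix): once a KV cache exists, only the last token is fed to the model, so the image cross-attention mask presumably has to be sliced to that same last position, mirroring what is already done for `input_ids`:

```python
# Sketch only: mirror the input_ids slicing on the image cross-attention mask
# when generating with a KV cache, so its query dimension matches q_len == 1.
if past_key_values is not None:
    input_ids = input_ids[..., -1].unsqueeze(-1)
    # image_attention_mask has shape (batch, seq_len, num_images); keep only
    # the row for the token currently being generated
    image_attention_mask = image_attention_mask[:, -1:, :]
```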
When I call the rewritten classes, I get this error:
```python
from outlines.models import transformers
import sys
print('outlines.models.transformers' in sys.modules)
import importlib
transformers = sys.modules.get('outlines.models.transformers')
if transformers:
    importlib.reload(transformers)

model_ol = transformers.Transformers(model, processor.tokenizer, processor)

from pydantic import BaseModel
from outlines import generate
import sys
print('outlines.generate' in sys.modules)  # This should return True if it's recognized
import importlib
generate = sys.modules.get('outlines.generate')
if generate:
    importlib.reload(generate)

class PydanticSchema(BaseModel):
    entity_names: str

prompts = [
    [
        "https://upload.wikimedia.org/wikipedia/commons/8/86/Id%C3%A9fix.JPG",
        "User: Please identify all entities in this image and return their names as a python list.",
        "<end_of_utterance>",
        "\nAssistant:",
    ],
]

generator = generate.json(model_ol, PydanticSchema)
result = generator(prompts)
print(result)
```
{
"name": "ValueError",
"message": "Attention mask should be of size (1, 1, 1, 16), but is torch.Size([1, 1, 33, 16])",
"stack": "---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[25], line 28
26 #model = models.transformers(\"mistralai/Mistral-7B-v0.1\")
27 generator = generate.json(model_ol, PydanticSchema)
---> 28 result = generator(prompts)
29 print(result)
30 # User(name=\"John\", last_name=\"Doe\", id=11)
File ~/outlines/outlines/outlines/generate/api.py:215, in SequenceGenerator.__call__(self, prompts, max_tokens, stop_at, rng)
213 while True:
214 try:
--> 215 last_state = next(states)
216 if max_tokens or stop_sequences:
217 token_ids = last_state.token_ids
File ~/outlines/outlines/outlines/generate/generator.py:75, in sequence_generator(model, sampler, fsms, token_ids, sequence_weights, attention_masks, pixel_values, image_attention_mask, fsm_states, rng)
73 print(\"kv_cache:\", kv_cache)
74 try:
---> 75 logits, kv_cache = model(token_ids, attention_masks, pixel_values, image_attention_mask, kv_cache)
76 #print(\"kv_cache:\", kv_cache)
77 except IndexError: # Exceeding the context length
File ~/outlines/outlines/outlines/models/transformers.py:175, in Transformers.__call__(self, input_ids, attention_mask, pixel_values, image_attention_mask, past_key_values)
172 input_ids = input_ids[..., -1].unsqueeze(-1)
174 #logits, kv_cache, hidden_image_state = self.model.forward(
--> 175 outputs = self.model.forward(
176 input_ids,
177 attention_mask,
178 None,
179 past_key_values,
180 None,
181 pixel_values,
182 None,
183 None,
184 image_attention_mask,
185 None,
186 None,
187 None,
188 #interpolate_pos_encoding: Optional[bool] = False,
189 #return_dict: Optional[bool] = None,
190 )
191 #print(outputs)
192 logits = outputs.logits
File /opt/conda/lib/python3.10/site-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
164 output = module._old_forward(*args, **kwargs)
165 else:
--> 166 output = module._old_forward(*args, **kwargs)
167 return module._hf_hook.post_forward(module, output)
File /opt/conda/lib/python3.10/site-packages/transformers/models/idefics/modeling_idefics.py:1485, in IdeficsForVisionText2Text.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, pixel_values, image_encoder_embeddings, perceiver_embeddings, image_attention_mask, labels, use_cache, output_attentions, output_hidden_states, interpolate_pos_encoding, return_dict)
1482 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1484 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1485 outputs = self.model(
1486 input_ids=input_ids,
1487 attention_mask=attention_mask,
1488 position_ids=position_ids,
1489 past_key_values=past_key_values,
1490 inputs_embeds=inputs_embeds,
1491 pixel_values=pixel_values,
1492 image_encoder_embeddings=image_encoder_embeddings,
1493 perceiver_embeddings=perceiver_embeddings,
1494 image_attention_mask=image_attention_mask,
1495 use_cache=use_cache,
1496 output_attentions=output_attentions,
1497 output_hidden_states=output_hidden_states,
1498 interpolate_pos_encoding=interpolate_pos_encoding,
1499 return_dict=return_dict,
1500 )
1502 hidden_states = outputs[0]
1503 logits = self.lm_head(hidden_states)
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
File /opt/conda/lib/python3.10/site-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
164 output = module._old_forward(*args, **kwargs)
165 else:
--> 166 output = module._old_forward(*args, **kwargs)
167 return module._hf_hook.post_forward(module, output)
File /opt/conda/lib/python3.10/site-packages/transformers/models/idefics/modeling_idefics.py:1327, in IdeficsModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, pixel_values, image_encoder_embeddings, perceiver_embeddings, image_attention_mask, use_cache, output_attentions, output_hidden_states, interpolate_pos_encoding, return_dict)
1310 layer_outputs = self._gradient_checkpointing_func(
1311 vblock,
1312 decoder_layer,
(...)
1324 self.gated_cross_attn_layers,
1325 )
1326 else:
-> 1327 layer_outputs = vblock(
1328 decoder_layer,
1329 hidden_states,
1330 attention_mask=attention_mask,
1331 position_ids=position_ids,
1332 past_key_value=past_key_value,
1333 image_hidden_states=image_hidden_states,
1334 image_attention_mask=image_attention_mask,
1335 cross_attention_gate=cross_attention_gate,
1336 output_attentions=output_attentions,
1337 use_cache=use_cache,
1338 layer_idx=idx,
1339 cross_layer_interval=self.cross_layer_interval,
1340 gated_cross_attn_layers=self.gated_cross_attn_layers,
1341 )
1343 hidden_states = layer_outputs[0]
1345 if use_cache:
File /opt/conda/lib/python3.10/site-packages/transformers/models/idefics/modeling_idefics.py:1279, in IdeficsModel.forward.<locals>.vblock(main_block, hidden_states, attention_mask, position_ids, past_key_value, image_hidden_states, image_attention_mask, cross_attention_gate, output_attentions, use_cache, layer_idx, cross_layer_interval, gated_cross_attn_layers)
1277 if layer_idx % cross_layer_interval == 0:
1278 xblock = gated_cross_attn_layers[layer_idx // cross_layer_interval]
-> 1279 outputs = xblock(
1280 hidden_states,
1281 attention_mask=attention_mask,
1282 image_hidden_states=image_hidden_states,
1283 image_attention_mask=image_attention_mask,
1284 cross_attention_gate=cross_attention_gate,
1285 output_attentions=output_attentions,
1286 use_cache=use_cache,
1287 past_key_value=None, # not implemented
1288 )
1289 hidden_states = outputs[0]
1291 layer_outputs = main_block(
1292 hidden_states,
1293 attention_mask=attention_mask,
(...)
1297 use_cache=use_cache,
1298 )
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
File /opt/conda/lib/python3.10/site-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
164 output = module._old_forward(*args, **kwargs)
165 else:
--> 166 output = module._old_forward(*args, **kwargs)
167 return module._hf_hook.post_forward(module, output)
File /opt/conda/lib/python3.10/site-packages/transformers/models/idefics/modeling_idefics.py:889, in IdeficsGatedCrossAttentionLayer.forward(self, hidden_states, attention_mask, image_hidden_states, image_attention_mask, cross_attention_gate, output_attentions, use_cache, past_key_value)
886 hidden_states = self.input_layernorm(hidden_states)
888 # Self Attention
--> 889 hidden_states, self_attn_weights, present_key_value = self.cross_attn(
890 hidden_states=hidden_states,
891 key_value_states=image_hidden_states,
892 attention_mask=image_attention_mask,
893 output_attentions=output_attentions,
894 )
895 hidden_states = nn.functional.dropout(hidden_states, p=self.config, training=self.training)
896 # Fill in zeros for cross_attention hidden_states of tokens attending to no images
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
File /opt/conda/lib/python3.10/site-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
164 output = module._old_forward(*args, **kwargs)
165 else:
--> 166 output = module._old_forward(*args, **kwargs)
167 return module._hf_hook.post_forward(module, output)
File /opt/conda/lib/python3.10/site-packages/transformers/models/idefics/modeling_idefics.py:658, in IdeficsAttention.forward(self, hidden_states, key_value_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache)
656 if attention_mask is not None:
657 if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
--> 658 raise ValueError(
659 f\"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}\"
660 )
662 # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
663 # Reference: https://github.com/pytorch/pytorch/issues/112577.
664 if query_states.device.type == \"cuda\" and attention_mask is not None:
ValueError: Attention mask should be of size (1, 1, 1, 16), but is torch.Size([1, 1, 33, 16])"
}
I think that for now it would be much faster / easier to add a logits processor in this file for multimodal models. I am soon going to change the `transformers` integration to use logits processors (plus add pipelines), and keep `SequenceGenerator` for our internal experiments with sampling algorithms.

We can add `models.idefics` and `models.llava` for now, unless you have the equivalent of `AutoModelForCausalLM` for multimodal models. These changes should be fairly quick.
I am happy to prioritize this.
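To illustrate the direction (a rough sketch under assumptions, not outlines' actual implementation): a logits processor only touches the scores, so the multimodal inputs can be handled entirely by `model.generate()`. Here `get_allowed_ids` is a hypothetical stand-in for the FSM-driven constraint logic, and `model`/`inputs` are the idefics objects from the snippets above.

```python
import torch
from transformers import LogitsProcessor, LogitsProcessorList

class ConstrainedLogitsProcessor(LogitsProcessor):
    """Illustrative only: mask every logit that the constraint does not allow."""

    def __init__(self, get_allowed_ids, prompt_len):
        self.get_allowed_ids = get_allowed_ids  # hypothetical FSM hook
        self.prompt_len = prompt_len

    def __call__(self, input_ids, scores):
        mask = torch.full_like(scores, float("-inf"))
        for i, seq in enumerate(input_ids):
            allowed = self.get_allowed_ids(seq[self.prompt_len:].tolist())
            mask[i, allowed] = 0.0
        return scores + mask

# pixel_values and image_attention_mask are passed straight through generate();
# the constraint never has to know the model is multimodal.
outputs = model.generate(
    **inputs,
    max_new_tokens=64,
    logits_processor=LogitsProcessorList(
        [ConstrainedLogitsProcessor(get_allowed_ids, inputs["input_ids"].shape[1])]
    ),
)
```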
Hi, I'll just chime in to say I'm very happy to see this is being discussed, and I will likely start experimenting with outlines once those vision models are added, sounds great!
FYI: Thanks to Outlines, Hugging Face TGI now supports constrained generation with e.g. JSON mode with vision LLMs like Idefics2, see the docs here: https://huggingface.co/docs/text-generation-inference/basic_tutorials/visual_language_models#combining-vision-language-models-with-other-features
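A rough sketch of how that looks against a running TGI endpoint (parameter names follow the TGI guidance docs linked above, so treat the exact payload as an assumption for your TGI version; the endpoint URL and JSON schema are placeholders): the image is embedded in the prompt as markdown and the schema is passed through the `grammar` parameter.

```python
import requests

# hypothetical JSON schema for the entity-extraction example in this thread
schema = {
    "type": "object",
    "properties": {"entity_names": {"type": "array", "items": {"type": "string"}}},
    "required": ["entity_names"],
}

payload = {
    # for TGI vision models the image is passed inline as markdown
    "inputs": "![](https://upload.wikimedia.org/wikipedia/commons/8/86/Id%C3%A9fix.JPG)"
    "Please identify all entities in this image.",
    "parameters": {"max_new_tokens": 128, "grammar": {"type": "json", "value": schema}},
}

# assumes a TGI server (e.g. serving Idefics2) is running locally on port 8080
response = requests.post("http://localhost:8080/generate", json=payload)
print(response.json()["generated_text"])
```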
> FYI: Thanks to Outlines, Hugging Face TGI now supports constrained generation with e.g. JSON mode with vision LLMs like Idefics2, see the docs here: https://huggingface.co/docs/text-generation-inference/basic_tutorials/visual_language_models#combining-vision-language-models-with-other-features
On that page there is a link that I expect should go to an example (https://huggingface.co/docs/conceptual/guided-generation), but I'm getting a 404. Is there an active page, or did they take it down for some reason?
@houstonlucas I think the dead link is actually a mistake in the docs. It should probably go here: https://huggingface.co/docs/text-generation-inference/conceptual/guidance
Vision LLMs like Llava or Idefics are becoming more and more popular for processing both images and text conjointly. For example, Hugging Face will soon release Idefics2, and Meta's Llama3 might also be multimodal.

Vision LLMs are also just LLMs that produce probabilities/logits over tokens, so my understanding is that they should also be compatible with outlines. The main challenge is that preprocessing is a bit more complicated: Hugging Face handles this with a `processor` class that covers both text and images. These models also have a `tokenizer` that can be used during decoding/sampling. See e.g. idefics1: docs and usage examples.

Two questions:

- What would be the best "hacky" way of using outlines with a vision LLM? Would subclassing the `SequenceGenerator` and changing the `__call__` and `sequence_generator` functions to adapt it to a specific model and processor make sense? Are there other places in the codebase that would be affected? https://github.com/outlines-dev/outlines/blob/main/outlines/generate/api.py#L15
- Are you open to integrating Vision LLMs more systematically into outlines?