huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

LLaVA `torch.compile` implementation #29891

Open sheepymeh opened 3 months ago

sheepymeh commented 3 months ago

Feature request

As per #28981, LLaVA is planned to receive torch.compile support. Since LLaVA is composed of a vision tower and an LLM, both of which can be compiled separately with fullgraph=True (once support has been added, which is not yet the case for Mistral), it seems much easier to compile the two parts separately as well.

Motivation

The _merge_input_ids_with_image_features function that connects the two parts is difficult to compile, because PyTorch does not yet support many of the functions it uses when input sizes are dynamic. Dynamic sizes are unavoidable here, since the number of image tokens in the input can change.
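To make the data dependence concrete, here is a minimal pure-Python sketch of the merged sequence length, mirroring the `max_embed_dim` computation in `_merge_input_ids_with_image_features` (each image placeholder is replaced by `num_image_patches` embeddings). The helper name and the token ids are hypothetical.

```python
def merged_sequence_length(input_ids, image_token_id, num_image_patches):
    """Length of one sequence after every image placeholder is expanded into its patch embeddings."""
    num_image_tokens = sum(1 for token in input_ids if token == image_token_id)
    # Each placeholder already occupies one slot, so it adds (num_image_patches - 1) more.
    return len(input_ids) + num_image_tokens * (num_image_patches - 1)


IMAGE_TOKEN_ID = 32000  # hypothetical placeholder id
one_image = [1, IMAGE_TOKEN_ID, 330, 349]
two_images = [1, IMAGE_TOKEN_ID, 330, IMAGE_TOKEN_ID]
print(merged_sequence_length(one_image, IMAGE_TOKEN_ID, 576))   # 579
print(merged_sequence_length(two_images, IMAGE_TOKEN_ID, 576))  # 1154
```

Both inputs have 4 tokens, yet the merged lengths differ, so the output shape of the merge depends on the values inside `input_ids` and not only on their shape.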

Your contribution

I'd love to try submitting a PR, but I'm not sure what the best approach is given the current circumstances.

ArthurZucker commented 3 months ago

Thanks! Mistral should be fairly easy to implement; just follow the updates made to Llama! 🤗 Generate will be refactored to support compile soon!

FYI @gante and @zucchini-nlp

sheepymeh commented 3 months ago

I've checked that the following code runs:

```python
import os
from functools import partial

import requests
import torch
from PIL import Image

from transformers import LlavaNextForConditionalGeneration, LlavaNextProcessor, StaticCache

os.environ["TOKENIZERS_PARALLELISM"] = "true"

with torch.inference_mode():
    processor = LlavaNextProcessor.from_pretrained("llava-mistral")
    model = LlavaNextForConditionalGeneration.from_pretrained("llava-mistral", low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda()

    # Give the language model a static KV cache so its shapes are fixed, then compile
    # the language model and the vision tower separately, each as a single graph.
    static_cache = partial(StaticCache, dtype=torch.float16)
    model.language_model._setup_cache(static_cache, max_batch_size=1, max_cache_len=4096)
    model.language_model.compile(fullgraph=True)
    model.vision_tower.compile(fullgraph=True)

    url = "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
    image = Image.open(requests.get(url, stream=True).raw)
    prompt = "[INST] <image>"

    # A single forward pass through the full model; the merge step in between stays uncompiled.
    inputs = processor(prompt, image, return_tensors="pt").to(model.device)
    output = model(**inputs)

print(output)
```

gante commented 3 months ago

Hey @sheepymeh 👋

Given that the individual models are compilable (according to your script above), the next step would be to rewrite the logic between the language model and the vision tower so that it is compilable as well, allowing model.forward to be compiled. It shouldn't require API changes, so we welcome contributions 💪

The process is very model-dependent, often requiring padding or clever manipulations to enable it. My suggestion would be to iteratively call the compiled forward pass, see where it crashes, rewrite, and repeat until it runs. Then, once the whole thing compiles, do a second pass to optimize performance (with the aid of a profiler): after 2-3 warm-up forward passes, the compiled forward should be significantly faster than the uncompiled one.
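The warm-up-then-measure comparison can be sketched with a small timing helper. This is a generic sketch, not a transformers utility; `benchmark` is a hypothetical name, and for CUDA work you would pass `sync=torch.cuda.synchronize` so that queued GPU kernels are included in the measured time.

```python
import time
from typing import Callable, Optional


def benchmark(fn: Callable[[], object], warmup: int = 3, iters: int = 10,
              sync: Optional[Callable[[], None]] = None) -> float:
    """Mean wall-clock seconds per call of `fn`, measured after `warmup` untimed calls."""
    # Warm-up calls absorb one-off costs such as torch.compile tracing and are not timed.
    for _ in range(warmup):
        fn()
    if sync is not None:
        sync()  # wait for asynchronous work from the warm-up to finish
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    if sync is not None:
        sync()  # make sure the timed calls have actually completed
    return (time.perf_counter() - start) / iters


# Usage sketch (assumes `model` and `inputs` from the script earlier in the thread):
#   eager = benchmark(lambda: model(**inputs), sync=torch.cuda.synchronize)
#   compiled_forward = torch.compile(model.forward, fullgraph=True)
#   compiled = benchmark(lambda: compiled_forward(**inputs), sync=torch.cuda.synchronize)
```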

Looking forward to the PR! If you get stuck, let us know :)

sheepymeh commented 3 months ago

Thank you for your suggestions. I'm currently working on a PR, and so far I've been creating fixed-length tensors for everything. For example, final_embedding would be initialized to a shape of (batch, max_position_embeddings, embed_dim). However, this creates a lot of unnecessary padding, all of which would be passed into the LLM. How could I mitigate this, or is it an acceptable compromise? For reference, max_position_embeddings for llava-1.6-mistral is 32768, so most inputs would be heavily padded.
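To put the padding cost in numbers: if every sequence is padded to `max_position_embeddings`, the fraction of positions that are padding is `1 - seq_len / max_len`. A quick sketch, where the helper name is hypothetical and the sequence lengths are illustrative rather than measured:

```python
def padding_fraction(seq_len: int, max_len: int = 32768) -> float:
    """Fraction of a fixed-length buffer of size `max_len` that is padding for a `seq_len`-token input."""
    if not 0 <= seq_len <= max_len:
        raise ValueError("seq_len must be between 0 and max_len")
    return 1 - seq_len / max_len


print(padding_fraction(2048))  # 0.9375
print(padding_fraction(8192))  # 0.75
```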

sheepymeh commented 3 months ago

For reference, here is the (very unoptimized) version I'm working on:

```python
def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
    num_images, num_image_patches, embed_dim = image_features.shape
    batch_size, sequence_length = input_ids.shape
    # Left padding <=> no sequence ends with a pad token.
    left_padding = (~torch.any(input_ids[:, -1] == torch.tensor(self.pad_token_id))).to(torch.int8)  # CHANGE: multiply instead of using booleans

    # 1. Create a mask to know where special image tokens are
    special_image_token_mask = input_ids == self.config.image_token_index
    num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
    # Compute the maximum embed dimension
    max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
    max_pos_embed = self.language_model.config.max_position_embeddings  # CHANGE: automatically use maximum position embeddings
    text_tokens = input_ids != self.config.image_token_index  # CHANGE: use a boolean mask instead of torch.where

    # 2. Compute the positions where text should be written
    # Calculate new positions for text tokens in merged image-text sequence.
    # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
    # `torch.cumsum` computes how each image token shifts subsequent text token positions.
    # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
    new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
    nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
    new_token_positions += nb_image_pad[:, None] * left_padding  # offset for left padding
    text_to_overwrite = new_token_positions.clone()
    last_token = new_token_positions.max(dim=-1).values + 1
    text_to_overwrite[~text_tokens] = -1  # set to -1 to place image tokens in "waste" element

    # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
    # set the corresponding tensors into their correct target device.
    target_device = inputs_embeds.device
    text_tokens, text_to_overwrite = (
        text_tokens.to(target_device),
        text_to_overwrite.to(target_device),
    )
    attention_mask = attention_mask.to(target_device)

    # 3. Create the full embedding
    # CHANGE: pad to max seq len by default and create a "waste" element at the end
    final_embedding = torch.zeros(
        batch_size, max_pos_embed + 1, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
    )
    final_attention_mask = torch.zeros(
        batch_size, max_pos_embed + 1, dtype=attention_mask.dtype, device=inputs_embeds.device
    )
    if labels is not None:
        final_labels = torch.full(
            (batch_size, max_pos_embed + 1), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
        )

    # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
    # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
    # I'm not sure how I could do this in a vectorized way
    for batch in range(batch_size):
        final_embedding[batch, text_to_overwrite[batch]] = inputs_embeds[batch]
        final_attention_mask[batch, text_to_overwrite[batch]] = attention_mask[batch]
        if labels is not None:
            final_labels[batch, text_to_overwrite[batch]] = labels[batch]
    pad_mask = (
        torch.arange(max_pos_embed + 1, device=target_device).unsqueeze(dim=0).repeat(batch_size, 1)
        >= last_token.to(target_device).unsqueeze(-1)
    )
    final_embedding[pad_mask] = self.pad_token_id
    final_attention_mask[pad_mask] = 0

    # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
    image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
    image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)

    # if image_to_overwrite.sum() != image_features.shape[:-1].numel():
    #     raise ValueError(
    #         f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
    #         f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
    #     )

    image_features = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
    final_embedding[image_to_overwrite] = image_features
    final_attention_mask |= image_to_overwrite
    position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)

    # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
    batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
    indices_to_mask = new_token_positions[batch_indices, pad_indices]

    final_embedding[batch_indices, indices_to_mask] = 0

    if labels is None:
        final_labels = None

    # CHANGE: incompatible with torch.compile; try to remove the additional tokens
    final_embedding, final_attention_mask, final_labels, position_ids = (
        final_embedding[:, :max_pos_embed],
        final_attention_mask[:, :max_pos_embed],
        final_labels[:, :max_pos_embed] if final_labels is not None else None,
        position_ids[:, :max_pos_embed],
    )

    return final_embedding, final_attention_mask, final_labels, position_ids
```
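Step 2 of this function is easiest to follow on a toy input. The pure-Python sketch below reproduces `new_token_positions = cumsum(mask * (num_image_patches - 1) + 1) - 1` for a single unpadded sequence; the helper name and token ids are hypothetical.

```python
from itertools import accumulate


def new_token_positions(input_ids, image_token_id, num_image_patches):
    """Position of each input token in the merged sequence (no padding offset)."""
    # An image placeholder advances the position by num_image_patches, a text token by 1.
    steps = [(num_image_patches - 1) * (token == image_token_id) + 1 for token in input_ids]
    # The running sum gives 1-based end positions; subtract 1 for zero-based indexing.
    return [position - 1 for position in accumulate(steps)]


IMAGE_TOKEN_ID = 32000  # hypothetical placeholder id
tokens = [1, IMAGE_TOKEN_ID, 330, 349]  # e.g. ["hey", "<image>", "how", "are"]
positions = new_token_positions(tokens, IMAGE_TOKEN_ID, num_image_patches=576)
print(positions)  # [0, 576, 577, 578]
text_positions = [p for p, t in zip(positions, tokens) if t != IMAGE_TOKEN_ID]
print(text_positions)  # [0, 577, 578]
```

The text tokens land at 0, 577 and 578, leaving slots 1 to 576 free for the 576 image features.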
ArthurZucker commented 3 months ago

> Thank you for your suggestions. I'm currently working on a PR and my current progress is on creating fixed-length tensors for everything. For example, final_embedding would be initialized to a shape of (batch, max_position_embeddings, embed_dim). However, this would create a lot of unnecessary and wasteful padding, which would be passed into the LLM. How could I mitigate this or is this an acceptable compromise? For reference, the max_position_embeddings for llava-1.6-mistral is 32768, which would be padded a lot of the time.

When you generate, what you mostly want to be fast is the decoding part, where input_ids is only 1 token. That should not require many changes, as the final embeddings would be the same size as the input ids.

One thing that could help with this is to pre-process the strings, so that the input of the model is the only thing whose shape changes. Instead of creating the final embedding in the model, you replace "Hey <image> is this nice?" with "Hey <image><image>.......<image> is this nice?" in the processor. That way, the embedding created already has the correct shape.
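A minimal sketch of that processor-side expansion, assuming the number of feature tokens per image is already known (for LLaVA-NeXT it depends on each image's resolution, so it would have to be computed per image). `expand_image_tokens` is a hypothetical helper, not part of `LlavaNextProcessor`.

```python
def expand_image_tokens(prompt, tokens_per_image, image_token="<image>"):
    """Repeat each image placeholder so the prompt already has one token per image feature."""
    pieces = prompt.split(image_token)
    if len(pieces) - 1 != len(tokens_per_image):
        raise ValueError(
            f"prompt has {len(pieces) - 1} image placeholders but {len(tokens_per_image)} counts were given"
        )
    expanded = [pieces[0]]
    for count, text_after in zip(tokens_per_image, pieces[1:]):
        expanded.append(image_token * count)  # one placeholder per image feature
        expanded.append(text_after)
    return "".join(expanded)


print(expand_image_tokens("Hey <image> is this nice?", [3]))
# Hey <image><image><image> is this nice?
```

With the placeholders expanded before tokenization, `inputs_embeds` already has its final length, so the model only has to overwrite the image positions in place instead of allocating a new, larger tensor.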

sheepymeh commented 3 months ago

This would completely break compatibility with previous code and with the original LLaVA codebase. Is it advisable to maintain two code paths in the processor (one with a single token per image and one with multiple)?
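One way to keep both code paths is for the model to detect which kind of prompt it received by counting placeholder tokens. This is a hypothetical sketch (the name `prompt_is_expanded` and the per-image counts are assumptions), not existing transformers logic:

```python
def prompt_is_expanded(input_ids, image_token_id, tokens_per_image):
    """True if every image already has all of its placeholder tokens,
    False if the prompt uses the legacy single placeholder per image."""
    num_placeholders = sum(1 for token in input_ids if token == image_token_id)
    if num_placeholders == sum(tokens_per_image):
        return True   # new path: overwrite placeholder embeddings in place
    if num_placeholders == len(tokens_per_image):
        return False  # legacy path: merge with _merge_input_ids_with_image_features
    raise ValueError(
        f"found {num_placeholders} image tokens, expected {len(tokens_per_image)} or {sum(tokens_per_image)}"
    )


IMAGE_TOKEN_ID = 32000  # hypothetical placeholder id
print(prompt_is_expanded([1, IMAGE_TOKEN_ID, 330], IMAGE_TOKEN_ID, [576]))              # False
print(prompt_is_expanded([1] + [IMAGE_TOKEN_ID] * 576 + [330], IMAGE_TOKEN_ID, [576]))  # True
```

On real tensors this branch depends on the values in `input_ids`, so it is probably best decided before entering the compiled region rather than inside it.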

sheepymeh commented 3 months ago

Thank you for the help so far. I've managed to implement @ArthurZucker's suggestion successfully in the preprocessor. However, the unpad_image function:

```python
def unpad_image(tensor, original_size):
    """
    Unpads a PyTorch tensor of a padded and resized image.

    Args:
        tensor (`torch.Tensor`):
            The image tensor, assumed to be of shape (num_channels, height, width).
        original_size (`tuple`):
            The original size of the image (height, width).

    Returns:
        `torch.Tensor`: The unpadded image tensor.
    """
    original_height, original_width = original_size
    current_height, current_width = tensor.shape[1:]

    original_aspect_ratio = original_width / original_height
    current_aspect_ratio = current_width / current_height

    if original_aspect_ratio > current_aspect_ratio:
        scale_factor = current_width / original_width
        new_height = int(original_height * scale_factor)
        padding = (current_height - new_height) // 2
        unpadded_tensor = tensor[:, padding : current_height - padding, :]
    else:
        scale_factor = current_height / original_height
        new_width = int(original_width * scale_factor)
        padding = (current_width - new_width) // 2
        unpadded_tensor = tensor[:, :, padding : current_width - padding]

    return unpadded_tensor
```

requires original_size to be passed into the model through the preprocessor, and it cannot be a PyTorch tensor because slicing by data-dependent tensor values is unsupported:

```
torch._dynamo.exc.Unsupported: Dynamic slicing on data-dependent value is not supported

from user code:
   File "/home/sheepymeh/transformers/src/transformers/models/llava_next/modeling_llava_next.py", line 427, in unpad_image
    unpadded_tensor = tensor[:, padding : current_height - padding, :]
```

Is it possible to somehow pass these values in as a list of ints? Trying that causes this error when using the .to() function:

```
Traceback (most recent call last):
  File "/home/sheepymeh/transformers/test.py", line 14, in <module>
    inputs = inputs.to("cuda:0")
  File "/home/sheepymeh/transformers/src/transformers/feature_extraction_utils.py", line 229, in to
    if torch.is_floating_point(v):
TypeError: is_floating_point(): argument 'input' (position 1) must be Tensor, not list
```

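For reference, the crop bounds in `unpad_image` depend only on the two sizes. The sketch below re-derives them from plain Python ints using the same arithmetic as `unpad_image`; when `original_size` is a tensor, the same computation yields data-dependent slice bounds, which is what Dynamo rejects in the error above. `unpad_bounds` is a hypothetical helper.

```python
def unpad_bounds(original_size, current_size):
    """Return (axis, start, stop) of the crop that unpad_image would take.

    axis is 1 (height) or 2 (width) for a (num_channels, height, width) tensor.
    """
    original_height, original_width = original_size
    current_height, current_width = current_size

    original_aspect_ratio = original_width / original_height
    current_aspect_ratio = current_width / current_height

    if original_aspect_ratio > current_aspect_ratio:
        # Padding was added above and below the image: crop rows.
        scale_factor = current_width / original_width
        new_height = int(original_height * scale_factor)
        padding = (current_height - new_height) // 2
        return 1, padding, current_height - padding
    # Otherwise padding was added left and right: crop columns.
    scale_factor = current_height / original_height
    new_width = int(original_width * scale_factor)
    padding = (current_width - new_width) // 2
    return 2, padding, current_width - padding


print(unpad_bounds((100, 200), (50, 50)))  # (1, 12, 38)
print(unpad_bounds((200, 100), (50, 50)))  # (2, 12, 38)
```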
ArthurZucker commented 1 month ago

Not super sure, but Llava does not require this, only Llava-Next does.

zucchini-nlp commented 1 month ago

Yes, it's only Llava-Next. Btw, not sure how helpful this info is, but in the linked PR we had to convert "image sizes" to a list inside "modeling.py", because there were mismatches between how resolutions are calculated in the image processor and on the modeling side.

sheepymeh commented 1 month ago

Thanks for the input. Unfortunately, it's not possible to convert tensors to lists in a compiled PyTorch model, so we would need to implement a vectorized version of the image_size handling first.