huggingface / transformers

LlavaNextProcessor bug in `_get_unpadded_features` #33261

Closed: laurentd-lunit closed this issue 1 week ago

laurentd-lunit commented 1 week ago

System Info

Who can help?

@zu

Information

Tasks

Reproduction

There is a typo in `LlavaNextProcessor._get_unpadded_features`: `current_width` and `current_height` are inverted. This can cause errors due to a mismatch between the image feature size computed by the processor and the one computed by the vision branch in `LlavaNextForConditionalGeneration`. I ran into this while running an example inference script.
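
For reference, the processor-side count is expected to mirror the unpadding that `unpad_image` performs in `modeling_llava_next.py`. Below is a minimal sketch (not the actual library code; the function name and exact rounding are illustrative) of what a consistent `_get_unpadded_features` would compute for a given image size and tile grid; swapping `current_height` and `current_width` in this computation is what produces the mismatch reported below.

def _get_unpadded_features_sketch(height, width, patches_height, patches_width, scale_height, scale_width):
    # Feature-grid size for the selected tile grid: rows/columns of vision tokens.
    current_height = patches_height * scale_height
    current_width = patches_width * scale_width

    # Mirror unpad_image: strip the padding that was added along whichever
    # dimension was padded when the image was fitted to the best resolution.
    original_aspect_ratio = width / height
    current_aspect_ratio = current_width / current_height
    if original_aspect_ratio > current_aspect_ratio:
        new_height = (height * current_width) // width
        padding = (current_height - new_height) // 2
        current_height -= 2 * padding
    else:
        new_width = (width * current_height) // height
        padding = (current_width - new_width) // 2
        current_width -= 2 * padding

    unpadded_features = current_height * current_width
    newline_features = current_height
    return unpadded_features, newline_features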

Here is a code snippet to reproduce the issue:

from transformers import LlavaNextProcessor
from transformers.models.llava_next.processing_llava_next import select_best_resolution
from transformers.models.llava_next.modeling_llava_next import unpad_image, get_anyres_image_grid_shape
import torch

POSSIBLE_RESOLUTIONS = [
    [336, 672],
    [672, 336],
    [672, 672],
    [1008, 336],
    [336, 1008],
]
processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
HEIGHT = 500
WIDTH = 316
VISION_MODEL_INPUT_SIZE = 336
PATCH_SIZE = 14
PATCH_DIM = VISION_MODEL_INPUT_SIZE // PATCH_SIZE

# Reproduce pre-processing steps in the processor
height_best_resolution, width_best_resolution = select_best_resolution(
    [HEIGHT, WIDTH], POSSIBLE_RESOLUTIONS
)
scale_height = height_best_resolution // VISION_MODEL_INPUT_SIZE
scale_width = width_best_resolution // VISION_MODEL_INPUT_SIZE
patches_height = VISION_MODEL_INPUT_SIZE // PATCH_SIZE
patches_width = VISION_MODEL_INPUT_SIZE // PATCH_SIZE
unpadded_features, newline_features = processor._get_unpadded_features(HEIGHT, WIDTH, patches_height, patches_width, scale_height, scale_width)
num_unpad_features_from_processor = unpadded_features

# Reproduce computation of unpadded features in the vision branch
# Equivalent to:
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_next/modeling_llava_next.py#L676-L684
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
    (HEIGHT, WIDTH),
    POSSIBLE_RESOLUTIONS,
    VISION_MODEL_INPUT_SIZE,
)
unpad_features_from_vision = unpad_image(torch.randn(128, num_patch_height*PATCH_DIM, num_patch_width*PATCH_DIM), (HEIGHT, WIDTH))
num_unpad_features_from_vision = unpad_features_from_vision.shape[1] * unpad_features_from_vision.shape[2]

# Should be equal
assert num_unpad_features_from_processor == num_unpad_features_from_vision, f"Not equal: From processor: {num_unpad_features_from_processor}, from vision {num_unpad_features_from_vision}"
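
Working the vision-branch side through by hand for this input (assuming `select_best_resolution` picks the `[672, 336]` pinpoint for a 500 x 316 image, so the tile grid is 2 x 1 and the feature grid is 48 x 24 before unpadding):

# Hand computation of num_unpad_features_from_vision for HEIGHT=500, WIDTH=316.
# Original aspect ratio 316 / 500 = 0.632 > current aspect ratio 24 / 48 = 0.5,
# so unpad_image strips rows that correspond to vertical padding.
new_height = int(500 * 24 / 316)      # 37
padding = (48 - new_height) // 2      # 5
rows_kept = 48 - 2 * padding          # 38
expected_features = rows_kept * 24    # 912

With the inverted width/height in `_get_unpadded_features`, the processor returns a different count and the assertion fails.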

Expected behavior

No assertion error.

LysandreJik commented 1 week ago

cc @zucchini-nlp maybe