huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Incorrect docstring of `get_anyres_image_grid_shape` #31588

Closed: DarkLight1337 closed this issue 1 month ago

DarkLight1337 commented 3 months ago

Upon inspecting the source code, it appears the `image_size` tuple should be in the form `(height, width)`, not `(width, height)` as the docstring states.

https://github.com/huggingface/transformers/blob/aab08297903de0ae39d4a6d87196b5056d76f110/src/transformers/models/llava_next/modeling_llava_next.py#L52
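To make the convention concrete, here is a minimal sketch of an "anyres" grid-shape computation in which every size tuple is `(height, width)`. The helper names `select_best_resolution_hw` and `grid_shape_hw` are hypothetical, and the selection rule (maximize the effective downscaled resolution, then minimize wasted pixels) is an assumption modeled on the LLaVA-NeXT approach, not a copy of the transformers code.

```python
from typing import List, Tuple

Size = Tuple[int, int]  # (height, width) throughout this sketch


def select_best_resolution_hw(image_size: Size, candidates: List[Size]) -> Size:
    """Pick the (height, width) candidate that preserves the most image detail.

    Hypothetical helper: prefer the largest effective (downscaled) resolution,
    then the fewest wasted pixels.
    """
    orig_h, orig_w = image_size
    best, best_effective, min_wasted = candidates[0], -1, float("inf")
    for cand_h, cand_w in candidates:
        # Scale the image to fit inside the candidate while keeping its aspect ratio.
        scale = min(cand_w / orig_w, cand_h / orig_h)
        down_h, down_w = int(orig_h * scale), int(orig_w * scale)
        effective = min(down_h * down_w, orig_h * orig_w)
        wasted = cand_h * cand_w - effective
        if effective > best_effective or (effective == best_effective and wasted < min_wasted):
            best, best_effective, min_wasted = (cand_h, cand_w), effective, wasted
    return best


def grid_shape_hw(image_size: Size, candidates: List[Size], patch_size: int) -> Size:
    """Return (num_patch_height, num_patch_width) for an image of size (height, width)."""
    best_h, best_w = select_best_resolution_hw(image_size, candidates)
    return best_h // patch_size, best_w // patch_size


if __name__ == "__main__":
    pinpoints = [(336, 672), (672, 336), (672, 672), (1008, 336), (336, 1008)]
    # A wide image: 300 pixels high, 800 pixels wide.
    print(grid_shape_hw((300, 800), pinpoints, 336))  # (1, 3): one row, three columns
    # Passing the same image as (width, height) silently transposes the grid.
    print(grid_shape_hw((800, 300), pinpoints, 336))  # (3, 1)
```

Swapping the tuple order does not raise an error; it only transposes the grid, which is why this kind of mismatch is easy to miss.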

amyeroberts commented 3 months ago

@DarkLight1337 Would you like to open a PR to fix this?

cc @zucchini-nlp to confirm, as I think this was raised elsewhere and there's a double inversion happening (?)

DarkLight1337 commented 3 months ago

After looking at the code a bit more, now I am more confused. It seems that the LLaVA-NeXT model treats it as `(width, height)` but still works correctly. Or is that just incorrect variable naming?

zucchini-nlp commented 3 months ago

Hey!

Yes, several people have noticed this issue, and I can confirm that our implementation matches the LLaVA-NeXT implementation exactly. There are naming discrepancies between the two, which is confusing, but they all come from the way it's done in the original repo.

But if we go by what I understand to be the correct order, then there is a "bug" in both implementations: LLaVA-NeXT treats it as `(width, height)` up to some point in the modeling code, where the order is swapped back to `(height, width)` (they permute the image to "height, width", not "width, height").
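The "double inversion" can be illustrated with a toy example: if one function mislabels a `(height, width)` tuple as `(width, height)`, and the code consuming its result unpacks it with the same mislabeled names and then uses them in height-first positions, the two swaps cancel and the output is still correct. The functions below are hypothetical and only illustrate the naming issue; they are not the LLaVA-NeXT or transformers code.

```python
from typing import List, Tuple


def grid_shape_mislabeled(image_size: Tuple[int, int], patch_size: int) -> Tuple[int, int]:
    """Names claim (width, height) in and out.

    In this sketch the caller actually passes (height, width), so `width` below
    really holds the height.
    """
    width, height = image_size
    return width // patch_size, height // patch_size


def arrange_rows(patches: List[str], num_rows: int, num_cols: int) -> List[List[str]]:
    """Lay a row-major flat list of patches out as num_rows x num_cols."""
    return [patches[r * num_cols:(r + 1) * num_cols] for r in range(num_rows)]


def layout_double_inversion(image_hw: Tuple[int, int], patches: List[str],
                            patch_size: int) -> List[List[str]]:
    # First inversion: the tuple really is (height, width) but is named (width, height).
    num_patch_width, num_patch_height = grid_shape_mislabeled(image_hw, patch_size)
    # Second inversion: the "width" value is passed where the row count is expected.
    return arrange_rows(patches, num_patch_width, num_patch_height)


def layout_correct(image_hw: Tuple[int, int], patches: List[str],
                   patch_size: int) -> List[List[str]]:
    height, width = image_hw
    return arrange_rows(patches, height // patch_size, width // patch_size)


if __name__ == "__main__":
    patches = [f"p{i}" for i in range(6)]
    # A 672-high, 1008-wide image split into 336-pixel patches: 2 rows x 3 columns.
    print(layout_double_inversion((672, 1008), patches, 336))
    print(layout_double_inversion((672, 1008), patches, 336) == layout_correct((672, 1008), patches, 336))
```

The variable names are wrong, but the numbers still end up in the right places, which is consistent with the observation that the model works correctly despite the confusing naming.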

I raised the question with the LLaVA authors a week ago and haven't gotten a reply yet, so I wouldn't change anything in transformers until the authors confirm whether it's a bug or intended behavior. The only thing I can do for now is add a small comment in the code clarifying the point. I could align the naming with the LLaVA-NeXT repo by using `(width, height)` order in processing, but that would raise more questions about why we use an incorrect order during image processing.

Hope this clarifies it a bit ;)

DarkLight1337 commented 3 months ago

Thanks for the clarification! Let's wait until the authors respond then.

yinsong1986 commented 2 months ago

Hi @DarkLight1337 @zucchini-nlp

I also found this potential issue here https://github.com/huggingface/transformers/issues/31529

Sharing some of my observations below:

FYI: I think the HF implementation converts the image to an `np.array` (https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_next/image_processing_llava_next.py#L703) before dividing it into patches, so the image shape is always `(height, width)`. That means the image input on this line should be `(height, width)`. Also, https://github.com/huggingface/transformers/blob/12b1620e615592fbf099d4ec44af7b9f2d1b48aa/src/transformers/models/llava_next/modeling_llava_next.py#L656 should return `num_patch_height, num_patch_width`.
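A minimal, dependency-free sketch of the array-based path: an image stored row-major (as a NumPy array is) has shape `(height, width)`, so dividing it into patches naturally yields a grid of `(num_patch_height, num_patch_width)`. For reference, a PIL image reports `Image.size` as `(width, height)`, whereas `np.asarray(image).shape` for an RGB image is `(height, width, channels)`. The function name `divide_rows_into_patches` and the list-of-rows image representation are assumptions made to keep the example self-contained; they are not the transformers API.

```python
from typing import List, Tuple

Image = List[List[int]]  # row-major: image[y][x], like a NumPy array indexed [row, col]


def divide_rows_into_patches(image: Image, patch_size: int) -> Tuple[List[Image], Tuple[int, int]]:
    """Split a row-major image into square patches.

    Assumes height and width are multiples of patch_size. Returns the patches in
    row-major order and the grid as (num_patch_height, num_patch_width).
    """
    height, width = len(image), len(image[0])  # array-style shape: (height, width)
    patches = []
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            patches.append([row[left:left + patch_size] for row in image[top:top + patch_size]])
    return patches, (height // patch_size, width // patch_size)


if __name__ == "__main__":
    # A 2-pixel-high, 6-pixel-wide image whose pixel values encode their column.
    image = [[x for x in range(6)] for _ in range(2)]
    patches, grid = divide_rows_into_patches(image, 2)
    print(grid)        # (1, 3): one patch row, three patch columns
    print(patches[2])  # [[4, 5], [4, 5]]
```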

In comparison, the original implementation (https://github.com/haotian-liu/LLaVA/blob/c121f0432da27facab705978f83c4ada465e46fd/llava/mm_utils.py#L143) uses the original PIL image size, `(width, height)`, before dividing into patches. After dividing into patches, the shape of the patches becomes `(height, width)`. To the best of my knowledge, this is why you see it swap from `(width, height)` to `(height, width)` afterwards.
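By contrast, here is a sketch of the PIL-style path described above, where sizes are `(width, height)`: the grid counts come out in `(width, height)` order, but the crops are generated top to bottom, left to right, so the resulting patch layout is row-major, i.e. `(height, width)`. The helper names are hypothetical, and the cropping loop is an assumption about how the original patch division iterates, not a copy of it. Boxes use PIL's `(left, upper, right, lower)` crop-box convention.

```python
from typing import List, Tuple


def grid_from_pil_size(size_wh: Tuple[int, int], patch_size: int) -> Tuple[int, int]:
    """Grid counts derived from a PIL-style size, returned as (num_patch_width, num_patch_height)."""
    width, height = size_wh
    return width // patch_size, height // patch_size


def crop_boxes(size_wh: Tuple[int, int], patch_size: int) -> List[Tuple[int, int, int, int]]:
    """PIL-style (left, upper, right, lower) boxes, scanned row by row."""
    width, height = size_wh
    return [
        (left, top, left + patch_size, top + patch_size)
        for top in range(0, height, patch_size)   # outer loop: rows (height)
        for left in range(0, width, patch_size)   # inner loop: columns (width)
    ]


def as_grid(items: list, num_rows: int, num_cols: int) -> List[list]:
    """Reshape a row-major flat list into num_rows x num_cols."""
    return [items[r * num_cols:(r + 1) * num_cols] for r in range(num_rows)]


if __name__ == "__main__":
    size = (1008, 336)  # PIL order: 1008 wide, 336 high
    num_patch_width, num_patch_height = grid_from_pil_size(size, 336)
    boxes = crop_boxes(size, 336)
    # The crops are row-major, so the grid must be indexed (height, width): the swap.
    grid = as_grid(boxes, num_patch_height, num_patch_width)
    print(len(grid), len(grid[0]))  # 1 3
```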

Cheers!

zucchini-nlp commented 2 months ago

@yinsong1986 wow, thanks for digging into this! That totally makes sense; for some reason I was assuming that the LLaVA impl also uses arrays.

In that case, this line also needs to be adjusted accordingly: https://github.com/huggingface/transformers/blob/aab08297903de0ae39d4a6d87196b5056d76f110/src/transformers/models/llava_next/modeling_llava_next.py#L656

Would you like to make a PR for this? It will also require checking logits equivalence between the HF impl and the LLaVA impl, just to make sure we're not introducing hidden bugs.
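A logits-equivalence check usually comes down to comparing the two models' outputs on the same inputs within a small tolerance. The sketch below does this on plain Python lists using the same elementwise criterion as `torch.allclose` (`|a - b| <= atol + rtol * |b|`); the function name and the default tolerances are assumptions, and with real tensors you would more likely use `torch.allclose` or `torch.testing.assert_close` directly.

```python
from typing import Sequence


def assert_logits_close(hf_logits: Sequence[float], ref_logits: Sequence[float],
                        rtol: float = 1e-5, atol: float = 1e-4) -> float:
    """Raise AssertionError if any pair violates |a - b| <= atol + rtol * |b|.

    Hypothetical helper; the default tolerances are illustrative assumptions.
    Returns the maximum absolute difference so it can be logged.
    """
    if len(hf_logits) != len(ref_logits):
        raise AssertionError(f"length mismatch: {len(hf_logits)} vs {len(ref_logits)}")
    max_diff = 0.0
    for i, (a, b) in enumerate(zip(hf_logits, ref_logits)):
        diff = abs(a - b)
        max_diff = max(max_diff, diff)
        if diff > atol + rtol * abs(b):
            raise AssertionError(f"logit {i} differs: {a} vs {b} (|diff|={diff})")
    return max_diff


if __name__ == "__main__":
    # Tiny differences pass; the returned value is the largest absolute gap.
    print(assert_logits_close([0.1, -2.0, 3.5], [0.1, -2.0, 3.50001]))
```

Combining an absolute and a relative tolerance avoids spurious failures on large-magnitude logits while still flagging real mismatches on small ones.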