huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Malfunctioning of transformers.image_transforms.pad #34775

Open John6666cat opened 18 hours ago

John6666cat commented 18 hours ago

Reproduction

from transformers.image_transforms import pad
import numpy as np
import torch
from PIL import Image

# Example image as a NumPy array
image = np.random.rand(224, 224, 3)  # Height x Width x Channels
image_pil = np.array(Image.fromarray(image, 'RGB')) # Round-trip through PIL (float64 input with an explicit 'RGB' mode)
image_uint8 = (image * 255.0).astype(np.uint8)

# Define padding: ((before_height, after_height), (before_width, after_width))
padding = ((0, 0), (112, 112))  # Pads width to make it 448

# Apply padding
padded_image = pad(image, padding=padding)
padded_image_pil = pad(image_pil, padding=padding)
padded_image_uint8 = pad(image_uint8, padding=padding)
print("Original Image Shape:", image.shape)
print("Padded Image Shape:", padded_image.shape)
print("Padded Image Shape (PIL):", padded_image_pil.shape)
print("Padded Image Shape (uint8):", padded_image_uint8.shape)

image_torch = torch.tensor(image).permute(2, 0, 1).unsqueeze(0)
padded_image_torch = torch.tensor(padded_image).permute(2, 0, 1).unsqueeze(0)
padded_image_pil_torch = torch.tensor(padded_image_pil).permute(2, 0, 1).unsqueeze(0)
padded_image_uint8_torch = torch.tensor(padded_image_uint8).permute(2, 0, 1).unsqueeze(0)

print("Original Image Shape (Torch):", image_torch.shape)
print("Padded Image Shape (Torch):", padded_image_torch.shape)
print("Padded Image Shape (PIL) (Torch):", padded_image_pil_torch.shape)
print("Padded Image Shape (uint8) (Torch):", padded_image_uint8_torch.shape)

# Save images
original_im = Image.fromarray(image, 'RGB')
padded_im = Image.fromarray(padded_image, 'RGB')
padded_im_pil = Image.fromarray(padded_image_pil, 'RGB')
padded_im_uint8 = Image.fromarray(padded_image_uint8, 'RGB')
original_im.save("_pad_original.png") # normal
padded_im.save("_pad_padded.png") # strange
padded_im_pil.save("_pad_padded_pil.png") # normal
padded_im_uint8.save("_pad_padded_uint8.png") # relatively normal

Expected behavior

After receiving a report on the Hugging Face forum that padding in the transformers library was behaving strangely, I investigated and found the approximate cause. It seems that the pad function (built on numpy.pad) returns strange results when it receives an ndarray that is not uint8. As a simple workaround, the array can be converted to a Pillow Image once, but that makes the code dependent on Pillow. If the library converted to uint8 on its own, it might be troublesome to judge the numerical range of the image. I opened an issue instead of a PR because I couldn't think of a good implementation.
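For illustration, a rough sketch of the convert-to-uint8 idea (the to_uint8_guess helper and its range heuristic below are hypothetical, not part of transformers, and the heuristic is exactly the troublesome part):

import numpy as np

def to_uint8_guess(image: np.ndarray) -> np.ndarray:
    # Hypothetical helper: coerce an image array to uint8 before padding.
    # Guessing the value range is the fragile part: a float image could be
    # in [0, 1], [-0.5, 0.5], or already in [0, 255].
    if image.dtype == np.uint8:
        return image
    image = image.astype(np.float64)
    if image.min() < 0.0:        # assume roughly [-0.5, 0.5]
        image = image + 0.5
    if image.max() <= 1.0:       # assume normalized [0, 1]
        image = image * 255.0
    return np.clip(np.round(image), 0, 255).astype(np.uint8)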

Reference

https://discuss.huggingface.co/t/clipvisionmodel-padding-problem/124187 (the post that helped me find the problem)
https://discord.com/channels/879548962464493619/1301002234963820604 (investigating preprocessor bugs)

Dependencies

transformers==4.46.2
torch==2.4.0
numpy<2

Demo

https://huggingface.co/spaces/John6666/transformers_padding_bug_test

qubvel commented 11 hours ago

Hi @John6666cat, thanks for submitting the issue with a detailed example and Demo! I'm trying to reproduce it on my side with a real image, but padding seems to be working fine. Here is a colab with the code.

It looks like the issue is not in pad, nor a uint8/float dtype issue. I suppose the problem might be related to converting the numpy image to a PIL image.
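If that's the case, a minimal way to check it might be the following (my assumption here is that passing an explicit mode to Image.fromarray makes PIL reinterpret the raw float bytes instead of converting the values):

import numpy as np
from PIL import Image

image = np.random.rand(8, 8, 3)  # float64 in [0, 1)

# Round-trip with an explicit mode: PIL may read the raw float64 bytes
# as if they were uint8 RGB triples, producing garbage pixels.
roundtrip = np.array(Image.fromarray(image, "RGB"))

# Round-trip after an explicit uint8 conversion: lossless.
as_uint8 = (image * 255).astype(np.uint8)
roundtrip_ok = np.array(Image.fromarray(as_uint8))

print(np.array_equal(roundtrip, as_uint8))     # False if the bytes were reinterpreted
print(np.array_equal(roundtrip_ok, as_uint8))  # True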

I also wrote this simple test to ensure padding is correct and equal to 0 for each padded border:

import numpy as np
from transformers.image_transforms import pad

# Example image as a NumPy array
image = np.random.rand(224, 224, 3)  # Height x Width x Channels

# Define padding: ((before_height, after_height), (before_width, after_width))
padding = ((0, 0), (112, 112))  # Pads width to make it 448

# Apply padding
padded_image = pad(image, padding=padding)

# check padding
assert padded_image.shape == (224, 448, 3), f"Expected padded image shape to be (224, 448, 3) but got {padded_image.shape}"
assert padded_image[:, :112].sum() == 0, "Expected padding to be 0"
assert padded_image[:, -112:].sum() == 0, "Expected padding to be 0"

Please let me know if you investigate it further.

John6666cat commented 4 hours ago

Hi @qubvel, thank you for confirming.😀 It seems there is a problem with the conversion to PIL.Image. When I cast an image read by PIL back to float32 and passed it through padding, I was able to produce an ndarray that PIL could not save normally. However, in your Colab the exact same method shows no problem...

If there is any difference, it would only be in Image.save() versus IPython.display. Perhaps there is some bug in Image.save() when dealing with RGB pixels that are not uint8.

from transformers.image_transforms import pad
import numpy as np
from PIL import Image

def to_float(image):
  return image.astype("float32") / 255.0 - 0.5

def to_uint8(image):
  return ((image + 0.5) * 255.0).astype("uint8")

# Example image as a NumPy array
image = np.random.rand(224, 224, 3)  # Height x Width x Channels
image_pil = to_float(np.array(Image.fromarray(image, 'RGB'))) # Round-trip through PIL, then back to float32 with to_float

# Define padding: ((before_height, after_height), (before_width, after_width))
padding = ((0, 0), (112, 112))  # Pads width to make it 448

# Apply padding
padded_image = pad(image, padding=padding)
padded_image_pil = pad(image_pil, padding=padding)
print("Original Image Shape:", image.shape)
print("Padded Image Shape:", padded_image.shape)
print("Padded Image Shape (PIL):", padded_image_pil.shape)

# Save images
original_im = Image.fromarray(image, 'RGB')
padded_im = Image.fromarray(padded_image, 'RGB')
padded_im_pil = Image.fromarray(padded_image_pil, 'RGB')
print("Original", image.dtype)
print("Padded", padded_image.dtype)
print("Padded (PIL)", padded_image_pil.dtype)
original_im.save("_pad_original.png") # normal
padded_im.save("_pad_padded.png") # strange
padded_im_pil.save("_pad_padded_pil.png") # strange
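To narrow down whether the corruption happens in Image.save() or already in Image.fromarray(), a check along these lines might help (a sketch, not verified here):

import numpy as np
from PIL import Image

image = np.random.rand(224, 224, 3)  # float64 in [0, 1)

pil_image = Image.fromarray(image, "RGB")
before_save = np.array(pil_image)
pil_image.save("_check_save.png")
after_save = np.array(Image.open("_check_save.png"))

# If True, save() writes exactly what the Image object holds, and any
# strangeness must already have been introduced by the fromarray() step.
print("save() preserved pixels:", np.array_equal(before_save, after_save))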

To summarize so far: if the processing is done with NumPy alone, there is no problem. Also, if all input images are first saved through PIL.Image.save(), as Gradio does, they are processed as uint8, so this problem does not appear in that context (see the sketch below). I use Gradio a lot, so I never noticed this before... I haven't tried torchvision or cv2, but depending on when the conversion happens, they may run into the same thing.
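For reference, the uint8-first path that avoids the problem (roughly what happens when everything goes through PIL first) would look like this:

from transformers.image_transforms import pad
import numpy as np
from PIL import Image

image = np.random.rand(224, 224, 3)

# Convert to uint8 *before* any PIL round-trip or padding.
image_uint8 = (image * 255.0).astype(np.uint8)
padded = pad(image_uint8, padding=((0, 0), (112, 112)))

# Without an explicit mode, fromarray infers uint8 RGB correctly.
Image.fromarray(padded).save("_pad_padded_safe.png")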

Also, I don't know whether these padded float32 images are recognized correctly when handled by transformers... There have been reports of accuracy dropping by several tens of percent after passing through a preprocessor, so in the worst case this may be related. I'm looking for the specific image, but I haven't found it yet.

Edit: related reports.

https://github.com/fpgaminer/joycaption/blob/main/scripts/batch-caption.py#L192
https://github.com/xorbitsai/inference/issues/2493
https://github.com/Valdanitooooo/chat_with_qwen2_vl_test