huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Providing several prompt_images and prompt_masks to seggpt leads to RuntimeError #30196

Closed MSchnei closed 3 months ago

MSchnei commented 3 months ago

System Info

Who can help?

@amyeroberts @EduardoPach @zucchini-nlp

Information

Tasks

Reproduction

I tried providing several prompt_images and prompt_masks to SegGPT, using this toy modification of the provided example:

from transformers import SegGptImageProcessor, SegGptModel
from PIL import Image
import requests

image_input_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_2.jpg"
image_prompt_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_1.jpg"
mask_prompt_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_1_target.png"

image_input = Image.open(requests.get(image_input_url, stream=True).raw)
image_prompt = Image.open(requests.get(image_prompt_url, stream=True).raw)
mask_prompt = Image.open(requests.get(mask_prompt_url, stream=True).raw).convert("L")

checkpoint = "BAAI/seggpt-vit-large"
model = SegGptModel.from_pretrained(checkpoint)
image_processor = SegGptImageProcessor.from_pretrained(checkpoint)
inputs = image_processor(images=image_input, prompt_images=[image_prompt, image_prompt], prompt_masks=[mask_prompt, mask_prompt], return_tensors="pt")
outputs = model(**inputs)
list(outputs.last_hidden_state.shape)

Expected behavior

When trying to run the above code, I get the following error:

Traceback (most recent call last):
  File "/Users/marianschneider/git/visprompt/.venv/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3508, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-2-9c7245202514>", line 17, in <module>
    outputs = model(**inputs)
  File "/Users/marianschneider/git/visprompt/.venv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/marianschneider/git/visprompt/.venv/lib/python3.9/site-packages/transformers/models/seggpt/modeling_seggpt.py", line 804, in forward
    pixel_values = torch.cat((prompt_pixel_values, pixel_values), dim=2)
RuntimeError: Sizes of tensors must match except in dimension 2. Expected size 2 but got size 1 for tensor number 1 in the list.

I would have expected the code above to work.

The code works fine if I also provide several input images, like so:

inputs = image_processor(images=[image_input, image_input], prompt_images=[image_prompt, image_prompt], prompt_masks=[mask_prompt, mask_prompt], return_tensors="pt")

but this means I need to duplicate my input image and also get several predicted masks for the same image (which one to pick?), so it's not ideal.

Digging deeper into the error: the model concatenates the input image and the prompt image along the height dimension (modeling_seggpt.py, line 804 in the traceback above). If I provide more prompt images than input images, the prompt tensor has a larger batch size than the input tensor, so the concatenation fails with a size mismatch in dimension 0 (the batch dimension); see the sketch below.
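
To make the failure mode concrete, here is a minimal sketch of that mismatch outside the model (the 448x448 shapes are assumed for illustration, not taken from the thread):

import torch

# Assumed toy shapes: two prompt images but only one input image.
prompt_pixel_values = torch.randn(2, 3, 448, 448)  # batch dimension is 2
pixel_values = torch.randn(1, 3, 448, 448)         # batch dimension is 1

try:
    # modeling_seggpt.py stacks prompt and input along the height axis (dim=2),
    # but torch.cat still requires every other dimension, including batch (dim=0), to match.
    torch.cat((prompt_pixel_values, pixel_values), dim=2)
except RuntimeError as e:
    print(e)  # Sizes of tensors must match except in dimension 2 ...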

EduardoPach commented 3 months ago

@MSchnei if you take a closer look at the shapes of your inputs, you'll see that inputs.pixel_values has a different shape than your prompt tensors.

You should've done:

inputs = image_processor(images=[image_input, image_input], prompt_images=[image_prompt, image_prompt], prompt_masks=[mask_prompt, mask_prompt], return_tensors="pt")
MSchnei commented 3 months ago

Thank you for the fast reply, @EduardoPach

Yes, I mention the same workaround in the issue description as well. But my question was more: is this really the behavior we want here? As mentioned above, it means users need to duplicate their input image and also get several predicted masks for the same image.

Maybe that's not as big an issue as I thought. Running this over a couple of test images, the predicted masks for the same image are always identical, so I assume I can always safely pick the first one?

This is one test I ran:

import torch
from datasets import load_dataset
from transformers import SegGptImageProcessor, SegGptForImageSegmentation
from PIL import Image

checkpoint = "BAAI/seggpt-vit-large"
image_processor = SegGptImageProcessor.from_pretrained(checkpoint)
model = SegGptForImageSegmentation.from_pretrained(checkpoint)

dataset_id = "EduardoPacheco/FoodSeg103"
ds = load_dataset(dataset_id, split="train")
# Number of labels in FoodSeg103 (not including background)
num_labels = 103

image_input = ds[4]["image"]
ground_truth = ds[4]["label"]
image_prompt = ds[29]["image"]
mask_prompt = ds[29]["label"]

inputs = image_processor(
    images=[image_input, image_input],
    prompt_images=[image_prompt, image_prompt],
    prompt_masks=[mask_prompt, mask_prompt],
    num_labels=num_labels,
    return_tensors="pt"
)

with torch.no_grad():
    outputs = model(**inputs)

target_sizes = image_input.size[::-1]
mask = image_processor.post_process_semantic_segmentation(outputs, [target_sizes, target_sizes], num_labels=num_labels)

assert torch.equal(mask[0], mask[1])
EduardoPach commented 3 months ago

I'd say this is the way we want it to be, because it keeps images and prompts paired; and yes, it's expected that running the same image with the same prompt in a batched manner gives the same output.

If you're interested in using one image with several prompts, I would suggest doing something like:

image_prompts = [image_prompt_1, ..., image_prompt_n]
mask_prompts = [...]
image_input = [image_input] * len(image_prompts)

....

and then you can try to use the feature_ensemble option of SegGptForImageSegmentation (see the docs).
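
For reference, a sketch of how that could be wired up, assuming (per the docs) that feature_ensemble is an argument of SegGptForImageSegmentation.forward; variable names reuse those from the snippets above and are illustrative:

import torch

# Repeat the single query image so every prompt is paired with it.
image_prompts = [image_prompt, image_prompt]
mask_prompts = [mask_prompt, mask_prompt]
image_inputs = [image_input] * len(image_prompts)

inputs = image_processor(
    images=image_inputs,
    prompt_images=image_prompts,
    prompt_masks=mask_prompts,
    num_labels=num_labels,
    return_tensors="pt",
)

with torch.no_grad():
    # feature_ensemble=True asks the model to combine features across the prompts in the batch.
    outputs = model(**inputs, feature_ensemble=True)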

I plan on writing a guide to make SegGPT easier to use, as it's quite a different model.

MSchnei commented 3 months ago

Thanks for the pointer to the feature_ensemble keyword. I can confirm that this works:

inputs = image_processor(
    images=[image_input, image_input],
    prompt_images=[image_prompt, image_prompt],
    prompt_masks=[mask_prompt, mask_prompt],
    num_labels=num_labels,
    return_tensors="pt",
    feature_ensemble=True
)

with torch.no_grad():
    outputs = model(**inputs)

Closing this issue because the workaround of passing several images works well enough.