Closed: MSchnei closed this issue 7 months ago
@MSchnei if you take a closer look at the shapes of your inputs you'll see that your inputs.pixel_values
have a different shape when compared to your prompts.
You should've done:
inputs = image_processor(
    images=[image_input, image_input],
    prompt_images=[image_prompt, image_prompt],
    prompt_masks=[mask_prompt, mask_prompt],
    return_tensors="pt"
)
Thank you for the fast reply, @EduardoPach
Yes, that's the same workaround I mention in the issue description. But my question was more: is this really what we want here? As mentioned above, it means users need to duplicate their input image and also get several predicted masks for the same image.
Maybe that's not as big an issue as I thought. Running this over a couple of test images, the predicted masks for the same image are always identical, so I assume I can always safely pick the first one?
This is one test I ran:
import torch
from datasets import load_dataset
from transformers import SegGptImageProcessor, SegGptForImageSegmentation
from PIL import Image
checkpoint = "BAAI/seggpt-vit-large"
image_processor = SegGptImageProcessor.from_pretrained(checkpoint)
model = SegGptForImageSegmentation.from_pretrained(checkpoint)
dataset_id = "EduardoPacheco/FoodSeg103"
ds = load_dataset(dataset_id, split="train")
# Number of labels in FoodSeg103 (not including background)
num_labels = 103
image_input = ds[4]["image"]
ground_truth = ds[4]["label"]
image_prompt = ds[29]["image"]
mask_prompt = ds[29]["label"]
inputs = image_processor(
    images=[image_input, image_input],
    prompt_images=[image_prompt, image_prompt],
    prompt_masks=[mask_prompt, mask_prompt],
    num_labels=num_labels,
    return_tensors="pt"
)

with torch.no_grad():
    outputs = model(**inputs)
target_sizes = image_input.size[::-1]
mask = image_processor.post_process_semantic_segmentation(outputs, [target_sizes, target_sizes], num_labels=num_labels)
assert torch.equal(mask[0], mask[1])
I'd say that is the way we want it to be, because this way we have images and prompts paired, and yeah, it's expected that if you're running the same image with the same prompts, just in a batched manner, you'd get the same output.
If you're interested in using one image with several prompts, I would suggest you do something like:
image_prompts = [image_prompt_1, ..., image_prompt_n]
mask_prompts = [...]
image_input = [image_input] * len(image_prompts)
...
and then you can try to use the feature_ensemble flag from SegGptForImageSegmentation (see the docs). I plan on writing a guide to make the use of SegGPT easier, as it's quite a different model.
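A fuller sketch of that one-image-with-several-prompts pattern, reusing the FoodSeg103 setup from the snippet above (the second prompt index, ds[31], is an arbitrary placeholder, not something from this thread), might look like:

# One input image paired with several different prompt image/mask pairs;
# the input image is simply repeated so images and prompts stay aligned.
image_prompts = [ds[29]["image"], ds[31]["image"]]  # ds[31] is an arbitrary extra prompt
mask_prompts = [ds[29]["label"], ds[31]["label"]]
images = [image_input] * len(image_prompts)

inputs = image_processor(
    images=images,
    prompt_images=image_prompts,
    prompt_masks=mask_prompts,
    num_labels=num_labels,
    return_tensors="pt"
)

with torch.no_grad():
    # feature_ensemble=True asks the model to ensemble features across the prompts
    outputs = model(**inputs, feature_ensemble=True)

target_size = image_input.size[::-1]
masks = image_processor.post_process_semantic_segmentation(
    outputs, [target_size] * len(image_prompts), num_labels=num_labels
)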
Thanks for the pointer to the feature_ensemble
keyword. Can confirm that this works:
inputs = image_processor(
    images=[image_input, image_input],
    prompt_images=[image_prompt, image_prompt],
    prompt_masks=[mask_prompt, mask_prompt],
    num_labels=num_labels,
    return_tensors="pt",
    feature_ensemble=True
)

with torch.no_grad():
    outputs = model(**inputs)
Closing this issue because the workaround of passing several images
works well enough.
System Info
transformers version: 4.39.3
Who can help?
@amyeroberts @EduardoPach @zucchini-nlp
Information
Tasks
An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
Reproduction
I tried providing several prompt_images and prompt_masks to SegGPT, using this toy modification of the provided example:
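A minimal sketch of such a call (reusing the names from the FoodSeg103 example in the comments above, with a hypothetical second prompt pair) is a single input image combined with a list of prompts:

inputs = image_processor(
    images=image_input,                                  # a single input image
    prompt_images=[image_prompt, second_image_prompt],   # but several prompts (second pair is hypothetical)
    prompt_masks=[mask_prompt, second_mask_prompt],
    num_labels=num_labels,
    return_tensors="pt"
)
outputs = model(**inputs)  # fails with a shape mismatch when image and prompt are concatenated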
Expected behavior
When trying to run the above code, I get a shape-mismatch error. I would have expected the code to work: the type of prompt_images and prompt_masks here is ImageInput, which allows for a list of PIL images as the input.
The code works fine if I also provide several (duplicated) input images, but this means I need to duplicate my input image and also get several predicted masks for the same image (which one to pick?), so it's not ideal.
Delving deeper into the error, this results from the fact that we are trying to concatenate the input image and the prompt image in the height direction here. However, if I provide more than one prompt image, the batch size will be larger for prompt images than for the input images. Consequently, the concatenation fails because of the dimension mismatch in dimension 0 (batch dimension).
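As a minimal illustration of that mismatch, independent of SegGPT itself (448x448 is just an example spatial size):

import torch

prompt_pixel_values = torch.randn(2, 3, 448, 448)  # batch of 2 prompt images
pixel_values = torch.randn(1, 3, 448, 448)          # batch of 1 input image

# Concatenating along the height axis (dim=2) requires every other dimension,
# including the batch dimension, to match; with batch sizes 2 vs. 1 this
# raises a RuntimeError about mismatched sizes in dimension 0.
torch.cat((prompt_pixel_values, pixel_values), dim=2)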