huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Mask2Former post-processing RLE #25486

Closed vjsrinivas closed 1 year ago

vjsrinivas commented 1 year ago

System Info

Who can help?

@amyeroberts I was trying to fine-tune Mask2Former with my own custom dataset, but I ran into an error when calling Mask2FormerImageProcessor.post_process_instance_segmentation. I get the following error when I set return_coco_annotation=True and a relatively low confidence threshold:

segmentation[pred_masks[j] == 1] = current_segment_id
TypeError: only integer tensors of a single element can be converted to an index

Could the issue be that convert_segmentation_to_rle is called within the query loop rather than outside it? See https://github.com/huggingface/transformers/blob/0ebe7ae16076f727ac40c47f8f9167013c4596d8/src/transformers/models/mask2former/image_processing_mask2former.py#L1031. The segmentation tensor turns into a List[List], which might be causing the TypeError.
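
To illustrate the failure mode, here is a minimal, self-contained sketch (toy shapes and values, not the processor's actual code) showing that the boolean-mask assignment only works while segmentation is still a tensor, and raises the reported TypeError once it has been replaced by a List[List]:

import torch

# Toy stand-ins for the processor's internals: a running segmentation map
# and two binary query masks.
segmentation = torch.zeros((4, 4))
pred_masks = torch.randint(0, 2, (2, 4, 4))

# While segmentation is a tensor, the boolean-mask assignment is fine:
segmentation[pred_masks[0] == 1] = 1

# If the RLE conversion runs inside the query loop, segmentation becomes a
# plain Python List[List[int]] before the next query is processed ...
segmentation = [[0, 3, 5, 2]]  # stand-in for an RLE result

# ... so the next iteration's assignment fails with:
# TypeError: only integer tensors of a single element can be converted to an index
segmentation[pred_masks[1] == 1] = 2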

Information

Tasks

Reproduction

It's not practical to share my custom training loop, but I recreated the situation with the ADE20K example for MaskFormer. Note that I stop training within the first iteration and set the confidence threshold to 0.001 (the error also occurs at 0.01, 0.1, etc.). The error still occurs when I run a full epoch on my custom dataset.

from datasets import load_dataset
import torch
from tqdm.auto import tqdm
import pandas as pd
import numpy as np
import albumentations as A
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from transformers import MaskFormerImageProcessor

dataset = load_dataset("scene_parse_150", "instance_segmentation")
data = pd.read_csv('./instanceInfo100_train.txt',
                   sep='\t', header=0, on_bad_lines='warn')
data.head(5)

id2label = {id: label.strip() for id, label in enumerate(data["Object Names"])}
print(id2label)

example = dataset['train'][1]
image = example['image']

seg = np.array(example['annotation'])
instance_seg = seg[:, :, 1]   # green channel encodes instance IDs
class_id_map = seg[:, :, 0]   # red channel encodes semantic categories
class_labels = np.unique(class_id_map)

# create mapping between instance IDs and semantic category IDs
inst2class = {}
for label in class_labels:
    instance_ids = np.unique(instance_seg[class_id_map == label])
    inst2class.update({i: label for i in instance_ids})
print(inst2class)

processor = MaskFormerImageProcessor(reduce_labels=True, ignore_index=255, do_resize=False, do_rescale=False, do_normalize=False)

class ImageSegmentationDataset(Dataset):
    """Image segmentation dataset."""

    def __init__(self, dataset, processor, transform=None):
        """
        Args:
            dataset
        """
        self.dataset = dataset
        self.processor = processor
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image = np.array(self.dataset[idx]["image"].convert("RGB"))

        instance_seg = np.array(self.dataset[idx]["annotation"])[:,:,1]
        class_id_map = np.array(self.dataset[idx]["annotation"])[:,:,0]
        class_labels = np.unique(class_id_map)

        inst2class = {}
        for label in class_labels:
            instance_ids = np.unique(instance_seg[class_id_map == label])
            inst2class.update({i: label for i in instance_ids})

        # apply transforms
        if self.transform is not None:
            transformed = self.transform(image=image, mask=instance_seg)
            image, instance_seg = transformed['image'], transformed['mask']
            # convert to C, H, W
            image = image.transpose(2,0,1)

        if class_labels.shape[0] == 1 and class_labels[0] == 0:
            # Some images have no annotations (everything is ignored)
            inputs = self.processor([image], return_tensors="pt")
            inputs = {k: v.squeeze() for k, v in inputs.items()}
            inputs["class_labels"] = torch.tensor([0])
            inputs["mask_labels"] = torch.zeros((0, inputs["pixel_values"].shape[-2], inputs["pixel_values"].shape[-1]))
        else:
            inputs = self.processor([image], [instance_seg], instance_id_to_semantic_id=inst2class, return_tensors="pt")
            inputs = {k: v.squeeze() if isinstance(v, torch.Tensor) else v[0] for k, v in inputs.items()}

        return inputs

ADE_MEAN = np.array([123.675, 116.280, 103.530]) / 255
ADE_STD = np.array([58.395, 57.120, 57.375]) / 255

# note that you can include more fancy data augmentation methods here
train_transform = A.Compose([
    A.Resize(width=512, height=512),
    A.Normalize(mean=ADE_MEAN, std=ADE_STD),
])

train_dataset = ImageSegmentationDataset(dataset["train"], processor=processor, transform=train_transform)

def collate_fn(batch):
    pixel_values = torch.stack([example["pixel_values"] for example in batch])
    pixel_mask = torch.stack([example["pixel_mask"] for example in batch])
    class_labels = [example["class_labels"] for example in batch]
    mask_labels = [example["mask_labels"] for example in batch]
    return {"pixel_values": pixel_values, "pixel_mask": pixel_mask, "class_labels": class_labels, "mask_labels": mask_labels}

train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)

# %%
batch = next(iter(train_dataloader))
for k,v in batch.items():
  if isinstance(v, torch.Tensor):
    print(k,v.shape)
  else:
    print(k,len(v))

# %%
from transformers import MaskFormerForInstanceSegmentation

# Replace the head of the pre-trained model
# We specify ignore_mismatched_sizes=True to replace the already fine-tuned classification head by a new one
model = MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-base-ade",
                                                          id2label=id2label,
                                                          ignore_mismatched_sizes=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)

running_loss = 0.0
num_samples = 0
for epoch in range(1):
  print("Epoch:", epoch)
  model.train()
  for idx, batch in enumerate(tqdm(train_dataloader)):
      # Reset the parameter gradients
      optimizer.zero_grad()

      # Forward pass
      outputs = model(
              pixel_values=batch["pixel_values"].to(device),
              mask_labels=[labels.to(device) for labels in batch["mask_labels"]],
              class_labels=[labels.to(device) for labels in batch["class_labels"]],
      )

      # Backward propagation
      loss = outputs.loss
      loss.backward()

      batch_size = batch["pixel_values"].size(0)
      running_loss += loss.item()
      num_samples += batch_size

      if idx % 100 == 0:
        print("Loss:", running_loss/num_samples)

      # Optimization
      optimizer.step()

      if idx == 1:
         break

########## SAMPLE VALIDATION LOOP: ############

processor = MaskFormerImageProcessor()
model.eval()
with torch.no_grad():
    for idx, batch in enumerate(tqdm(train_dataloader)):
        outputs = model(
              pixel_values=batch["pixel_values"].to(device)
        )

        coco_out = processor.post_process_instance_segmentation(outputs, threshold=0.001, return_coco_annotation=True)
        print(coco_out)

Expected behavior

Mask2FormerImageProcessor.post_process_instance_segmentation should not error out, regardless of the model's segmentation output.
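
Until a fix lands, a possible workaround is an untested sketch like the one below; it assumes convert_segmentation_to_rle is importable from the module linked above and accepts a (height, width) segmentation tensor. The idea is to skip return_coco_annotation=True and convert the returned segmentation maps to RLE afterwards:

from transformers.models.mask2former.image_processing_mask2former import convert_segmentation_to_rle

# continuing the validation loop above: `processor` and `outputs` as before
results = processor.post_process_instance_segmentation(outputs, threshold=0.001)
# one list of RLEs per image, one RLE per segment id in its segmentation map
rle_per_image = [convert_segmentation_to_rle(result["segmentation"]) for result in results]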

amyeroberts commented 1 year ago

Hi @vjsrinivas, thanks for reporting!

I've opened #25497 which should resolve this issue

vjsrinivas commented 1 year ago

@amyeroberts thanks for the quick reply! Do we pip install from the GitHub project for these kinds of hotfixes?

amyeroberts commented 1 year ago

@vjsrinivas Yes, once the PR is merged in, you'll need to install from source to have the current changes in main. They will be included in the next version release.
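
For reference, installing from source once the PR is merged typically looks like this (standard pip-from-git syntax, not specific to this fix):

pip install git+https://github.com/huggingface/transformers.git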