WalBouss / GEM

[CVPR24] Official Implementation of GEM (Grounding Everything Module)
MIT License
86 stars 4 forks source link

No evaluation codes #5

Open duan-song opened 5 months ago

duan-song commented 5 months ago

Hi,

Thanks for open-sourcing the GEM code. However, I cannot reproduce the mIoU scores on Pascal VOC, Pascal Context, ADE20K, and OpenImages30K that are reported in the CVPR 2024 paper. Would the authors be willing to release the evaluation code? [attached image: 微信图片_20240606231652]

letitiabanana commented 5 months ago

Hi authors,

Thanks for the great work! I tried to add evaluation code myself and got an mIoU of 15 on the VOC dataset, which deviates significantly from the reported number. I believe there must be some discrepancy between my reimplementation and your code. Could you please release the evaluation code?

Thanks a lot!

WalBouss commented 4 months ago

Hi, thanks for your interest in our work and your feedback! I don’t have time to clean up all the evaluation pipelines, but here is the one for PascalVOC. I will try to push it to the repo whenever I have time:

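Evaluation pipeline for PascalVOC. Don’t forget to change the path to the PascalVOC dataset (`root_path_voc`). (You can download the dataset at http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar.) Note that the dataset loader below looks for the data under `<root_path_voc>/VOCdevkit/VOC2012/` (reading `ImageSets/Segmentation/val.txt`, `JPEGImages/` and `SegmentationClass/`), so make sure the extracted folder matches that layout: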

from tqdm import tqdm
import torch
import torch.nn.functional as F
from torchmetrics.classification import MulticlassJaccardIndex
from einops import rearrange
import gem

class ZeroShotSegmentation(torch.nn.Module):
    """Zero-shot semantic segmentation evaluator built on a GEM model.

    Patch features from the GEM visual encoder are compared with text embeddings
    of ``'a photo of a {class}.'`` prompts. Every pixel receives its best-matching
    foreground class, and pixels whose softmax confidence is below a threshold are
    reassigned to the background class (index 0).
    """

    def __init__(self, model, tokenizer, model_name, patch_size=16, device='cpu'):
        super().__init__()

        self.model_name = model_name
        self.device = device

        self.gem_model = model
        self.gem_model.to(device)
        self.gem_model.eval()
        self.patch_size = patch_size
        self.tokenizer = tokenizer

    def _get_text_embedding(self, classes: list):
        """Encode one prompt per class name into L2-normalised text embeddings.

        Returns:
            The normalised embeddings with an extra leading dimension of size 1
            (assumes ``encode_text`` returns ``(num_prompts, embed_dim)``, giving
            ``(1, len(classes), embed_dim)`` — TODO confirm).
        """
        prompts = [f'a photo of a {cls}.' for cls in classes]
        tokenized_prompts = self.tokenizer(prompts).to(self.device)

        text_embedding = self.gem_model.model.encode_text(tokenized_prompts)
        text_embedding = F.normalize(text_embedding, dim=-1)
        return text_embedding.unsqueeze(0)

    def inference(self, image, text_embedding, mask_shape):
        """Predict per-pixel foreground classes for a batch of images.

        Args:
            image: Image batch ``(B, C, H, W)``; ``H`` and ``W`` are expected to be
                multiples of ``patch_size`` so the patch tokens form a grid.
            text_embedding: Output of ``_get_text_embedding`` for the foreground classes.
            mask_shape: Target ``(H_out, W_out)`` the logits are upsampled to.

        Returns:
            ``(pred, logits)``: ``pred`` is ``(B, H_out, W_out)`` with class indices
            shifted by +1 (index 0 is reserved for background); ``logits`` is
            ``(B, num_foreground_classes, H_out, W_out)``.
        """
        _, _, H, W = image.shape
        # visual() returns a pair; only the first (GEM) output is used here.
        feat_gem, _ = self.gem_model.model.visual(image)
        feat_gem = F.normalize(feat_gem, dim=-1)

        # Patch/text cosine similarity scaled by 100. Token 0 (presumably the class
        # token) is dropped so the remaining tokens map onto the patch grid.
        logits_gem = 100.0 * feat_gem[:, 1:] @ text_embedding.transpose(1, 2)
        logits_gem = rearrange(logits_gem, 'b (h w) c -> b c h w',
                               h=H // self.patch_size, w=W // self.patch_size)
        logits_gem = F.interpolate(logits_gem, size=mask_shape, mode='bilinear', align_corners=False)

        # +1 because the background class has no text prompt (see eval_dataset).
        pred_gem = logits_gem.argmax(1) + 1
        return pred_gem, logits_gem

    @torch.no_grad()
    def eval_dataset(self, dataloader, classes, device=None, threshold=0.85, log_every=20):
        """Compute the mIoU (in percent) of thresholded predictions over ``dataloader``.

        Args:
            dataloader: Yields ``(image, mask)`` batches; ``mask`` holds class
                indices with ``-1`` for ignored pixels.
            classes: All class names, background first (``classes[0]``).
            device: Device for the images; defaults to the evaluator's device.
            threshold: Pixels whose highest foreground softmax probability is below
                this value are labelled background (0).
            log_every: Print the running mIoU every ``log_every`` batches; 0/None
                disables the intermediate prints.

        Returns:
            The final mIoU, multiplied by 100.
        """
        device = self.device if device is None else device
        # Background gets no prompt; it is predicted only via the threshold below.
        text_embedding = self._get_text_embedding(classes=classes[1:])

        metric_iou = MulticlassJaccardIndex(num_classes=len(classes), ignore_index=-1).to('cpu')

        for i, (image, mask) in enumerate(tqdm(dataloader)):
            image = image.to(device)  # the mask stays on the CPU, next to the metric

            pred_gem, pred_logits_gem = self.inference(image, text_embedding, mask.shape[-2:])

            # Per-pixel confidence: highest softmax probability over foreground classes.
            confidence = pred_logits_gem.softmax(dim=1).max(dim=1).values  # (B, H, W)

            pred_th_gem = pred_gem.clone()
            pred_th_gem[confidence < threshold] = 0

            metric_iou.update(pred_th_gem.cpu(), mask)
            if log_every and i % log_every == 0:
                print(metric_iou.compute().item() * 100)

        metric_th_gem = 100 * metric_iou.compute().item()
        print(f'mIoU: {metric_th_gem}')
        return metric_th_gem

def main(model_name, device, pretrained, patch_size=16, root_path_voc='', batch_size=1, num_workers=8):
    """Evaluate a GEM model zero-shot on the Pascal VOC ``val`` split.

    Args:
        model_name: Identifier passed to ``gem.create_gem_model`` and ``gem.get_tokenizer``.
        device: Torch device the model and images are moved to.
        pretrained: Pretrained-weights tag passed to ``gem.create_gem_model``.
        patch_size: ViT patch size used to reshape patch tokens into a grid.
        root_path_voc: Directory that contains ``VOCdevkit/VOC2012/``.
        batch_size: DataLoader batch size; values > 1 also resize masks to 448x448.
        num_workers: Number of DataLoader worker processes.

    Returns:
        The mIoU (in percent) computed by ``ZeroShotSegmentation.eval_dataset``.
    """
    # Imported here instead of relying on the ``__main__`` block, so ``main`` also
    # works when this module is imported from elsewhere.
    from segmentation_datasets.pascal_voc import PascalVOC, SegmentationTransforms

    # Masks are resized only when batching, presumably so that differently sized
    # masks can be collated into a single tensor.
    resize_mask = batch_size > 1
    dataset = PascalVOC(root=root_path_voc, split='val',
                        transform=SegmentationTransforms((448, 448), resize_mask=resize_mask),
                        aug=False, only_image=False, only_mask=False, ignore_index=-1)
    test_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size,
                                              shuffle=False, num_workers=num_workers)

    model = gem.create_gem_model(model_name=model_name, pretrained=pretrained)
    tokenizer = gem.get_tokenizer(model_name=model_name)

    evaluator = ZeroShotSegmentation(model=model, device=device, patch_size=patch_size,
                                     model_name=model_name, tokenizer=tokenizer)
    miou = evaluator.eval_dataset(dataloader=test_loader, classes=list(dataset.CLASSES), device=device)
    return miou

if __name__ == '__main__':
    from segmentation_datasets.pascal_voc import PascalVOC, SegmentationTransforms

    patch_size = 16
    model_name = 'ViT-B-16-quickgelu'
    pretrained = 'metaclip_400m'
    # Directory containing VOCdevkit/VOC2012/ (plain ASCII quotes are required here).
    root_path_voc = '/path/to/PascalVOC/'

    print('########################################')
    print(f'model: {model_name} | pretrained: {pretrained} ')
    print('########################################')

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    main(model_name=model_name, pretrained=pretrained, device=device, patch_size=patch_size, root_path_voc=root_path_voc)

Here is the dataset implementation:

from os.path import join
from PIL import Image
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torchvision.transforms import transforms

# Per-channel (R, G, B) mean/std handed to transforms.Normalize in
# SegmentationTransforms; named after OpenAI's CLIP preprocessing statistics.
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)

class PascalVOC(Dataset):
    """Pascal VOC semantic-segmentation dataset.

    Label PNGs are read as RGB and decoded to class indices through ``PALETTE``.
    Pixels whose colour matches none of the first ``nclass`` palette entries (such
    as VOC's void/boundary colour) are set to ``ignore_index``.

    Args:
        root: Dataset root. Without ``split_file`` it must contain
            ``VOCdevkit/VOC2012/``; with ``split_file`` it is used as-is.
        split: One of ``'train'``, ``'trainval'``, ``'val'``.
        transform: Callable ``(PIL image, PIL mask) -> (image, mask_tensor)``, where
            ``mask_tensor`` is ``(3, H, W)``. Required by ``__getitem__``.
        only_image, only_mask: Stored but not used by this class.
        aug: Read ``SegmentationClassAug`` annotations; with ``split='train'`` the
            ``trainaug`` list is used instead of ``train``.
        nclass: Number of palette entries to decode (default: all 21).
        split_file: Optional list file relative to ``root``; its entries are then
            paths (without extension) relative to ``root``.
        ignore_index: Label for pixels that match no decoded palette colour.
        return_path: Also return a dict with the image and label paths.
    """

    CLASSES = ('background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow',
               'table', 'dog', 'horse', 'motorbike', 'person', 'plant', 'sheep', 'sofa', 'train', 'monitor')

    PALETTE = torch.tensor([[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
                           [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0],
                           [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128],
                           [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0],
                           [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]], dtype=torch.uint8)

    def __init__(self,
                 root,
                 split='train',
                 transform=None,
                 only_image=False,
                 aug=True,
                 nclass=None,
                 only_mask=False,
                 split_file=None,
                 ignore_index=-1,
                 return_path=False):
        super().__init__()
        self.nclass = nclass if nclass is not None else self.PALETTE.shape[0]
        self.only_image = only_image
        self.only_mask = only_mask
        self.return_path = return_path
        self.ignore_index = ignore_index
        assert split in ['train', 'trainval', 'val'], f'{split} must be in ["train", "trainval", "val"]'
        self.split = 'trainaug' if aug and split == 'train' else split
        self.root = join(root, 'VOCdevkit/VOC2012/') if split_file is None else root
        self.transform = transform

        self.anno_type = 'SegmentationClassAug' if aug else 'SegmentationClass'
        if split_file is not None:
            txt_file = join(self.root, split_file)
        else:
            txt_file = join(self.root, 'ImageSets', 'Segmentation', self.split + '.txt')

        # Skip blank lines (e.g. trailing newlines) so they do not turn into samples
        # pointing at non-existent files like 'JPEGImages/.jpg'.
        with open(txt_file) as f:
            self.samples = [line.strip() for line in f if line.strip()]

        self.image_files = []
        self.label_files = []
        for sample in self.samples:
            if split_file is not None:
                img, label = f'{sample}.jpg', f'{sample}.png'
            else:
                img, label = f'JPEGImages/{sample}.jpg', f'{self.anno_type}/{sample}.png'
            self.image_files.append(join(self.root, img))
            self.label_files.append(join(self.root, label))

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` (plus a path dict if ``return_path``) for sample ``idx``.

        ``mask`` is a ``(H, W)`` long tensor of class indices / ``ignore_index``.
        """
        image_path = self.image_files[idx]
        label_path = self.label_files[idx]

        # Close the underlying files explicitly; convert() returns an in-memory copy.
        with Image.open(image_path) as img_file:
            img = img_file.convert("RGB")
        with Image.open(label_path) as msk_file:
            msk = msk_file.convert("RGB")

        images, rgb_target = self.transform(img, msk)

        # Decode the RGB mask: a pixel gets class c when all three channels match PALETTE[c].
        h, w = rgb_target.shape[1:]
        seg_mask = torch.full((h, w), self.ignore_index, dtype=torch.long)
        for class_idx in range(self.nclass):
            color = self.PALETTE[class_idx].view(3, 1, 1)
            matches = (rgb_target == color).all(dim=0)
            seg_mask[matches] = class_idx

        if self.return_path:
            path_to_img_msk = {"img_path": image_path, "label_path": label_path}
            return images, seg_mask, path_to_img_msk

        return images, seg_mask

class ToTensorMask(nn.Module):
    """Convert a mask (PIL image or array) to an ``int64`` tensor in ``(C, H, W)`` layout.

    Multi-channel ``(H, W, C)`` masks (e.g. RGB) are transposed to ``(C, H, W)``.
    Single-channel ``(H, W)`` masks (e.g. PIL mode ``'L'``/``'P'``) get a leading
    channel axis, i.e. ``(1, H, W)``. Values are not rescaled.
    """

    def __init__(self):
        super().__init__()

    def forward(self, mask):
        array = np.array(mask)
        if array.ndim == 2:
            # permute(2, 0, 1) needs three axes; add the missing channel axis.
            array = array[..., None]
        return torch.as_tensor(array, dtype=torch.int64).permute(2, 0, 1)

class SegmentationTransforms(object):
    """Paired image/mask transform for segmentation.

    By default the image is bicubic-resized to ``size``, converted to a tensor and
    normalised with ``OPENAI_DATASET_MEAN``/``OPENAI_DATASET_STD``. The mask is
    converted with ``ToTensorMask``; if ``resize_mask`` is set, it is first resized
    to ``size`` with nearest-neighbour interpolation.

    Args:
        size: Target size passed to ``transforms.Resize``.
        img_transforms: Optional replacement for the default image pipeline.
        resize_mask: Whether to resize the mask to ``size`` as well.
    """

    def __init__(self, size, img_transforms=None, resize_mask=False):
        self.img_transforms = img_transforms if img_transforms is not None else transforms.Compose([
            transforms.Resize(size=size, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(OPENAI_DATASET_MEAN, OPENAI_DATASET_STD),
        ])
        # Masks are colour-coded label maps. Resize's default (bilinear) interpolation
        # would blend neighbouring colours into values that match no class (or the
        # wrong one), so nearest-neighbour must be used.
        mask_resize = (transforms.Resize(size=size, interpolation=transforms.InterpolationMode.NEAREST)
                       if resize_mask else nn.Identity())
        self.mask_transforms = transforms.Compose([
            mask_resize,
            ToTensorMask(),
        ])

    def __call__(self, image, mask):
        return self.img_transforms(image), self.mask_transforms(mask)

if __name__ == '__main__':
    # Smoke test: walk the training split once and print each batch's image and mask shapes.
    voc_root = '/path/to/PascalVOC/'
    train_transform = SegmentationTransforms((448, 448), resize_mask=False)
    train_set = PascalVOC(root=voc_root, split='train', transform=train_transform,
                          aug=False, only_image=False, only_mask=False)

    loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=1)

    for batch_images, batch_masks in loader:
        print(batch_images.shape, batch_masks.shape, sep='\n')

You will also need to install the torchmetrics library (`pip install torchmetrics`). Feel free to ask if you have any questions!

letitiabanana commented 4 months ago

Hi,

I am now able to reproduce your result for VOC. Thanks a lot for your reply!!

I am also interested in how models pre-trained with a single objective versus multiple objectives (e.g. CLIP vs. BLIP) behave differently. Would you mind sharing how your method can be implemented with BLIP as well?

Thanks again!