goel-shashank / CyCLIP

112 stars 14 forks source link

Reproducing zero-shot retrieval experiments with CyCLIP #9

Closed ytaek-oh closed 6 months ago

ytaek-oh commented 6 months ago

Dear authors, first of all, thank you for the wonderful project and for sharing the code and checkpoints.

While reproducing the zero-shot retrieval experiments from Table 6 of your main paper, I observed a large gap between the results reported in the paper and my reproduced results.

For clarity, here are the results of the zero-shot retrieval performance I obtained: COCO dataset

Flickr dataset

The results seem to align closely with the reported performance only in the case of Image-to-Text retrieval on Flickr30k.

I used the CyCLIP checkpoint provided via Google Drive and conducted tests using the Karpathy test split of the COCO and Flickr datasets.

I adapted the code from this repo for the retrieval experiments. I hope you can take a quick look at the code below and help me identify any potential issues.

Additionally, it would be immensely helpful if you could share the code you used for the zero-shot retrieval experiments.

Best regards,


Usage: python test_retrieval.py --dataset coco # or flickr, where

test_retrieval.py:

import argparse

import open_clip
import torch

from src.retrieval import get_loader_image, get_loader_text

def compute_retrieval(similarity_scores, txt2img, img2txt):
    """Compute Recall@1/5/10 for text-to-image and image-to-text retrieval.

    Args:
        similarity_scores: 2-D tensor whose rows are looked up in ``img2txt``
            (images) and whose columns are looked up in ``txt2img`` (texts).
        txt2img: mapping from a text index to the index of its ground-truth
            image.
        img2txt: mapping from an image index to an iterable of its
            ground-truth text indices.

    Returns:
        A pair ``(t2i_report_dict, i2t_report_dict)``. Each dict has the keys
        ``"r1"``, ``"r5"`` and ``"r10"`` holding recall as a percentage.
    """

    def recall_at(ranks, k):
        # Percentage of queries whose 0-based ground-truth rank is below k.
        return 100.0 * int((ranks < k).sum()) / len(ranks)

    # ---- text -> image: one query per column of the similarity matrix ----
    per_text_scores = similarity_scores.t()
    num_texts = per_text_scores.shape[0]
    t2i_ranks = torch.zeros(num_texts)

    for text_idx in range(num_texts):
        ordering = torch.argsort(per_text_scores[text_idx], descending=True)
        # Position of the ground-truth image in the sorted candidate list.
        t2i_ranks[text_idx] = torch.where(ordering == txt2img[text_idx])[0][0]
        print(
            'Evaluating batch {}/{}, {}'.format(
                text_idx, num_texts, t2i_ranks[text_idx]
            ),
            end="\r"
        )

    t2i_report_dict = {
        "r1": recall_at(t2i_ranks, 1),
        "r5": recall_at(t2i_ranks, 5),
        "r10": recall_at(t2i_ranks, 10),
    }

    # ---- image -> text: one query per row of the similarity matrix ----
    num_images = similarity_scores.shape[0]
    i2t_ranks = torch.zeros(num_images)

    for image_idx in range(num_images):
        print('Evaluating batch {}/{}'.format(image_idx, num_images), end="\r")
        ordering = torch.argsort(similarity_scores[image_idx], descending=True)
        # An image may have several captions; keep the best (lowest) rank.
        # The sentinel stays in place when the image has no captions listed.
        best_rank = 1e10
        for caption_idx in img2txt[image_idx]:
            position = torch.where(ordering == caption_idx)[0][0]
            best_rank = position if position < best_rank else best_rank
        i2t_ranks[image_idx] = best_rank

    i2t_report_dict = {
        "r1": recall_at(i2t_ranks, 1),
        "r5": recall_at(i2t_ranks, 5),
        "r10": recall_at(i2t_ranks, 10),
    }
    return t2i_report_dict, i2t_report_dict

def get_image_feature(model, data_loader):
    """Encode every image batch and return the normalized embeddings.

    Args:
        model: object exposing ``encode_image(images)``. If it also exposes
            ``parameters()``, inputs are moved to the device of its first
            parameter; otherwise they are moved to CUDA.
        data_loader: sized iterable yielding ``(images, _)`` pairs; the second
            element of each pair is ignored.

    Returns:
        A CPU tensor holding the per-batch embeddings concatenated along
        dim 0 and divided by their L2 norm along the last dimension.
    """
    # Follow the model's device instead of assuming the current CUDA device;
    # fall back to CUDA (the previous hard-coded choice) if it is unknown.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device("cuda")

    image_features = []
    # Pure inference: disabling autograd avoids building a graph per batch.
    with torch.no_grad():
        for batch_idx, batch in enumerate(data_loader):
            print('Evaluating batch {}/{}'.format(batch_idx, len(data_loader)), end="\r")
            images, _ = batch
            image_emb = model.encode_image(images.to(device))  # embed with image encoder
            # Move each batch to the CPU right away to keep device memory flat.
            image_features.append(image_emb.detach().cpu())
    image_features = torch.cat(image_features, 0)

    print('Done image feature extract.')
    print(image_features.shape)

    # normalized features
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    return image_features

def get_text_feature(model, data_loader):
    """Encode every token batch and return the normalized embeddings.

    Args:
        model: object exposing ``encode_text(tokens)``. If it also exposes
            ``parameters()``, inputs are moved to the device of its first
            parameter; otherwise they are moved to CUDA.
        data_loader: sized iterable yielding token tensors. Any leading
            dimensions are flattened so the encoder always receives a 2-D
            ``(num_captions, context_length)`` tensor.

    Returns:
        A CPU tensor holding the per-batch embeddings concatenated along
        dim 0 and divided by their L2 norm along the last dimension.
    """
    # Follow the model's device instead of assuming the current CUDA device;
    # fall back to CUDA (the previous hard-coded choice) if it is unknown.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device("cuda")

    text_features = []
    # Pure inference: disabling autograd avoids building a graph per batch.
    with torch.no_grad():
        for batch_idx, batch in enumerate(data_loader):
            print('Evaluating batch {}/{}'.format(batch_idx, len(data_loader)), end="\r")
            # A bare squeeze() would also drop the batch dimension whenever a
            # batch holds a single caption (e.g. the last partial batch), so
            # flatten everything except the token dimension instead.
            text = batch.reshape(-1, batch.shape[-1])
            text_emb = model.encode_text(text.to(device))
            text_features.append(text_emb.detach().cpu())

    text_features = torch.cat(text_features, 0)
    print('Done text feature extract.')
    print(text_features.shape)

    # normalized features
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    return text_features

def main(args):
    """Run the zero-shot retrieval evaluation and print both recall reports.

    Args:
        args: namespace providing ``dataset`` (``"coco"`` selects COCO, any
            other value selects Flickr30k) and ``batch_size``.
    """
    checkpoint_path = "/home/appuser/.cache/torch/hub/CyCLIP/cc3m/CyCLIP.pt"
    model, _, transform = open_clip.create_model_and_transforms(
        "RN50", pretrained=checkpoint_path, device="cuda"
    )
    model = model.eval().cuda()

    # Both configurations point at the karpathy test split.
    ann_file, data_root, image_root = (
        (
            "/home/appuser/datasets/coco/coco_karpathy_test.json",
            "/home/appuser/datasets/coco/",
            "images/val2014",
        )
        if args.dataset == "coco"
        else (
            "/home/appuser/datasets/flickr30k/annotations/flickr30k_test.json",
            "/home/appuser/datasets/flickr30k/",
            "images/flickr30k-images",
        )
    )

    # Text side first, then image side (same order as the progress output).
    text_features = get_text_feature(
        model,
        get_loader_text(ann_file, data_root, image_root, args.batch_size, transform),
    )

    image_loader, txt2img, img2txt = get_loader_image(
        ann_file, data_root, image_root, args.batch_size, transform
    )
    image_features = get_image_feature(model, image_loader)

    # Rows index images, columns index texts.
    similarity_scores = image_features.cuda() @ text_features.cuda().t()
    t2i_dict, i2t_dict = compute_retrieval(similarity_scores, txt2img, img2txt)
    print('Image-to-Text retrieval', i2t_dict)
    print('Text-to-Image retrieval', t2i_dict)

if __name__ == "__main__":
    # Command-line entry point: parse the options and run the evaluation.
    cli = argparse.ArgumentParser(description="ZeroShot")
    cli.add_argument("--batch-size", type=int, default=64)
    cli.add_argument("--dataset", type=str, default="coco", help="coco or flickr")
    args = cli.parse_args()
    main(args)

src/retrieval.py:

import json
import os

from open_clip import tokenize
from PIL import Image
from torch.utils.data import DataLoader, Dataset

class TextDataset(Dataset):
    """Dataset that returns the tokenized form of one caption per item."""

    def __init__(self, text_data, tokenizer):
        # text_data: indexable, sized collection of captions.
        # tokenizer: callable applied to a single caption in __getitem__.
        self.caption = text_data
        self.tokenizer = tokenizer

    def __len__(self):
        """Return the number of captions."""
        return len(self.caption)

    def __getitem__(self, index):
        """Return the tokenizer output for the caption at ``index``."""
        # A prompt prefix such as 'a photo of ' could optionally be prepended
        # to the caption here; it is not applied.
        return self.tokenizer(self.caption[index])

class CaptionsDataset(Dataset):
    """Image dataset built from a retrieval annotation JSON file.

    The annotation file must contain a list of records, each with an
    ``'image'`` path and a ``'caption'`` entry (a list of strings, or a single
    string which is treated as a one-element list).
    """

    def __init__(self, ann_file, transform, data_root, image_root):
        """Parse ``ann_file`` and build the image/caption index tables.

        Args:
            ann_file: path to the annotation JSON file.
            transform: callable applied to each PIL image in ``__getitem__``.
            data_root: dataset root directory.
            image_root: image directory, joined onto ``data_root``.
        """
        # Use a context manager so the annotation file handle is closed
        # deterministically; JSON text is UTF-8 encoded.
        with open(ann_file, 'r', encoding='utf-8') as ann_fp:
            # Note: despite its name this attribute holds the parsed records.
            self.ann_file = json.load(ann_fp)
        self.transform = transform
        self.image_root = image_root
        self.caption = []  # all captions, flattened in file order
        self.image = []  # full image path for every record
        self.txt2img = {}  # caption index -> image index
        self.img2txt = {}  # image index -> list of caption indices

        txt_num = 0
        for num, line in enumerate(self.ann_file):
            # Keep only the file name: works for 'dir/name.jpg' as before and
            # also for bare file names or deeper relative paths.
            image_name = os.path.basename(line['image'])
            self.image.append(os.path.join(data_root, image_root, image_name))

            captions = line['caption']
            if isinstance(captions, str):
                # Extending a list with a str would add one entry per character.
                captions = [captions]
            self.caption.extend(captions)

            text_ids = list(range(txt_num, txt_num + len(captions)))
            for i in text_ids:
                self.txt2img[i] = num
            # Every image gets an entry, even when it has no captions, so
            # lookups by image index never raise KeyError.
            self.img2txt[num] = text_ids
            txt_num += len(captions)

    def __len__(self):
        """Return the number of images."""
        return len(self.image)

    def __getitem__(self, index):
        """Return ``(transformed RGB image, index)`` for the given index."""
        # convert() yields a fully loaded copy, so the underlying file can be
        # closed as soon as the with-block exits.
        with Image.open(self.image[index]) as raw_image:
            image = raw_image.convert('RGB')
        image = self.transform(image)
        return image, index

def get_loader_image(ann_file, data_root, image_root, batch_size, preprocess):
    """Build the image loader together with the caption/image index maps.

    Returns:
        ``(loader, txt2img, img2txt)``: an unshuffled ``DataLoader`` over a
        ``CaptionsDataset`` plus that dataset's two index dictionaries.
    """
    dataset = CaptionsDataset(ann_file, preprocess, data_root, image_root)
    # Order must stay fixed so row i of the features matches image index i.
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    return loader, dataset.txt2img, dataset.img2txt

def get_loader_text(ann_file, data_root, image_root, batch_size, preprocess):
    """Build an unshuffled loader over the tokenized captions of ``ann_file``.

    A ``CaptionsDataset`` is constructed only to collect the flattened caption
    list; each caption is tokenized with ``open_clip.tokenize``.
    """
    captions = CaptionsDataset(ann_file, preprocess, data_root, image_root).caption
    return DataLoader(
        TextDataset(captions, tokenize),
        batch_size=batch_size,
        shuffle=False,
    )
Hritikbansal commented 6 months ago

Hi,

Thank you for your interest in our work. Here is the code to perform zero-shot retrieval from our repo. It is adapted from the ALBEF repo.