dandelin / ViLT

Code for the ICML 2021 (long talk) paper: "ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision"
Apache License 2.0

Got better results than in the paper: #51

Open JoanFM opened 2 years ago

JoanFM commented 2 years ago

Hey @dandelin ,

I just want to share the results I got when reproducing the recall evaluation with my own implementation. Here is my ViltModel:

from typing import List

import torch
from transformers import BertTokenizer

from vilt.modules import ViLTransformerSS

class ViltModel(ViLTransformerSS):
    def __init__(
            self,
            config,
            *args,
            **kwargs,
    ):
        super().__init__(config)
        self._config = config
        if torch.cuda.is_available():
            dev = "cuda:0"
        else:
            dev = "cpu"
        self._device = torch.device(dev)
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.eval()

    @property
    def in_cuda(self):
        return next(self.parameters()).is_cuda

    def rank_query_vs_images(self, query: str, images: List):
        rank_scores = []
        encoded_input = self.tokenizer(query, return_tensors='pt')
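        # truncate the tokenized query to the configured max_text_len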
        input_ids = encoded_input['input_ids'][:, :self._config['max_text_len']]
        mask = encoded_input['attention_mask'][:, :self._config['max_text_len']]
        in_cuda = self.in_cuda
        if in_cuda:
            input_ids = input_ids.to(self._device)
            mask = mask.to(self._device)
        batch = {'text_ids': input_ids, 'text_masks': mask, 'text_labels': None}
        # no masking
        for image in images:
            if in_cuda:
                image = image.to(self._device)
            batch['image'] = [image.unsqueeze(0)]
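            # rank_output is the linear IRTR ranking head applied to the fused cls_feats; [:, 0] extracts the single matching score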
            score = self.rank_output(self.infer(batch)['cls_feats'])[:, 0]
            rank_scores.append(score.detach().cpu().item())
        return rank_scores
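
A minimal usage sketch of the class above, in case it helps (the image path here is just a placeholder, and conf is built exactly as in the compute_recall method below):

from PIL import Image

from vilt.transforms.pixelbert import pixelbert_transform

vilt_model = ViltModel(conf)

# pixelbert_transform(384) resizes and normalizes a PIL image into the tensor the model expects
image = pixelbert_transform(384)(Image.open('some_image.jpg').convert('RGB'))
scores = vilt_model.rank_query_vs_images('a dog playing in the snow', [image])
print(scores)  # one ITM-style matching score per candidate image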

The compute_recall method:

def compute_recall():
    import copy
    import time
    from vilt import config
    from vilt.transforms.pixelbert import pixelbert_transform
    from src.dataset.dataset import get_image_data_loader, get_captions_data_loader
    from src.evaluate import evaluate

    # Sacred config is an immutable object, so you need to deepcopy it.
    conf = copy.deepcopy(config.config())
    conf['load_path'] = VILT_BASE_MODEL_LOAD_PATH
    conf['test_only'] = True
    conf['max_text_len'] = 40
    conf['data_root'] = '/hdd/master/tfm/arrow'
    conf['datasets'] = ['f30k']
    conf['batch_size'] = 1
    conf['per_gpu_batchsize'] = 1
    conf['draw_false_image'] = 0
    conf['num_workers'] = 1

    # You need to properly configure loss_names so that the right heads are initialized (0.5 means the head is
    # initialized, but its loss is ignored during training)
    loss_names = {
        'itm': 0.5,
        'mlm': 0,
        'mpp': 0,
        'vqa': 0,
        'imgcls': 0,
        'nlvr2': 0,
        'irtr': 1,
        'arc': 0,
    }
    conf['loss_names'] = loss_names

    if torch.cuda.is_available():
        dev = 'cuda:0'
    else:
        dev = 'cpu'
    device = torch.device(dev)

    print(f' conf for ViltModel {conf}')

    vilt_model = ViltModel(conf)
    vilt_model.to(device)

    image_dataset = get_image_data_loader(root=DATASET_ROOT_PATH,
                                          split_root=DATASET_SPLIT_ROOT_PATH,
                                          split='test',
                                          transform=pixelbert_transform(384),
                                          batch_size=1) # loading the images with the pixelBert transformation

    text_dataset = get_captions_data_loader(root=DATASET_ROOT_PATH,
                                            split_root=DATASET_SPLIT_ROOT_PATH,
                                            split='test',
                                            batch_size=1) # loading the captions; no image transform is needed here

    images = []
    filenames = []
    for filenames_batch, images_batch in image_dataset:
        filenames.extend(filenames_batch)
        images.extend(images_batch)

    retrieved_image_filenames = []
    groundtruth_expected_image_filenames = []
    print(f' number of queries {len(text_dataset)}, against {len(images)}') # this leads to 5000 captions against 1000 images
    for matching_filename, query in text_dataset:
        filename = matching_filename[0]
        groundtruth_expected_image_filenames.append([filename])
        q = query[0]
        start = time.time()
        scores = vilt_model.rank_query_vs_images(q, images)
        print(f' time to rank a single query {time.time() - start}s')
        retrieved_image_filenames.append([f for _, f in sorted(zip(scores, filenames), reverse=True)])

    evaluate(['recall', 'reciprocal_rank'], retrieved_image_filenames,
             groundtruth_expected_image_filenames,
             [1, 5, 10, 20, 100, 200, 500, None],
             {}, print_results=True)
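
The evaluate helper comes from my own src.evaluate module, so it is not part of this repo; roughly, it computes Recall@K and the mean reciprocal rank from the ranked filename lists, along the lines of this sketch (the function below is illustrative, not the actual implementation):

from typing import List, Optional

def recall_and_mrr(retrieved: List[List[str]], expected: List[List[str]], ks: List[Optional[int]]):
    for k in ks:
        # Recall@K: fraction of queries whose expected image appears in the top K results (K=None means the full list)
        hits = sum(1 for ranked, gt in zip(retrieved, expected) if set(gt) & set(ranked[:k]))
        print(f' Mean Recall@{k} {hits / len(retrieved)}')
    # mean reciprocal rank: average of 1 / (1 + position of the first expected image in the ranking)
    rranks = [1.0 / (1 + min(ranked.index(g) for g in gt)) for ranked, gt in zip(retrieved, expected)]
    print(f' Mean Reciprocal rank {sum(rranks) / len(rranks)}')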

The obtained results are:

 Mean Recall@1 0.7584
 Mean Recall@5 0.9554
 Mean Recall@10 0.9826
 Mean Recall@20 0.9932
 Mean Recall@100 0.9998
 Mean Recall@200 1.0
 Mean Recall@500 1.0
 Mean Recall@None 1.0
 Mean Reciprocal rank 0.8449181004476226

I know that the results could differ from those in the paper, but these seem extremely good. Is there something I am obviously doing wrong?

Thanks in advance

ahustr commented 2 years ago

Haha, I got the same result. Have you figured it out?

JoanFM commented 2 years ago

Not really, I did not check further. I guess the author is just very humble! ;)