Hey @dandelin,

I just want to share the results I reproduced with my own recall implementation. Here is my `ViltModel`:
```python
from typing import List

import torch
from transformers import BertTokenizer
from vilt.modules import ViLTransformerSS


class ViltModel(ViLTransformerSS):

    def __init__(self, config, *args, **kwargs):
        super().__init__(config)
        self._config = config
        if torch.cuda.is_available():
            dev = 'cuda:0'
        else:
            dev = 'cpu'
        self._device = torch.device(dev)
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.eval()

    @property
    def in_cuda(self):
        return next(self.parameters()).is_cuda

    def rank_query_vs_images(self, query: str, images: List):
        rank_scores = []
        encoded_input = self.tokenizer(query, return_tensors='pt')
        # Truncate the query to the configured maximum text length.
        input_ids = encoded_input['input_ids'][:, :self._config['max_text_len']]
        mask = encoded_input['attention_mask'][:, :self._config['max_text_len']]
        in_cuda = self.in_cuda
        if in_cuda:
            input_ids = input_ids.to(self._device)
            mask = mask.to(self._device)
        # No masking: text_labels stays None, we only run ITM-style inference.
        batch = {'text_ids': input_ids, 'text_masks': mask, 'text_labels': None}
        with torch.no_grad():  # inference only, no gradients needed
            for image in images:
                if in_cuda:
                    image = image.to(self._device)
                batch['image'] = [image.unsqueeze(0)]
                # Score the (query, image) pair with the irtr rank head.
                score = self.rank_output(self.infer(batch)['cls_feats'])[:, 0]
                rank_scores.append(score.detach().cpu().item())
        return rank_scores
```
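For reference, this is roughly how I call it on a single image (a minimal sketch; the image path and caption are just placeholders, and `pixelbert_transform` is the same transform I use in the evaluation below):

```python
from PIL import Image
from vilt.transforms.pixelbert import pixelbert_transform

# Placeholder path: substitute any RGB image.
image = Image.open('/path/to/image.jpg').convert('RGB')
# Apply the same pixelBert transform the image data loader uses (size 384),
# yielding the (3, H, W) tensor that rank_query_vs_images expects.
image_tensor = pixelbert_transform(384)(image)

scores = vilt_model.rank_query_vs_images('a dog playing in the snow', [image_tensor])
print(scores)  # one matching score (irtr rank-head logit) per image
```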
The compute recall method:
```python
def compute_recall():
    import copy
    import time

    from vilt import config
    from vilt.transforms.pixelbert import pixelbert_transform
    from src.dataset.dataset import get_image_data_loader, get_captions_data_loader
    from src.evaluate import evaluate

    # The Sacred config is an immutable object, so you need to deepcopy it.
    conf = copy.deepcopy(config.config())
    conf['load_path'] = VILT_BASE_MODEL_LOAD_PATH
    conf['test_only'] = True
    conf['max_text_len'] = 40
    conf['data_root'] = '/hdd/master/tfm/arrow'
    conf['datasets'] = ['f30k']
    conf['batch_size'] = 1
    conf['per_gpu_batchsize'] = 1
    conf['draw_false_image'] = 0
    conf['num_workers'] = 1
    # You need to configure loss_names properly to initialize the heads
    # (0.5 means the head is initialized, but its loss is ignored during training).
    loss_names = {
        'itm': 0.5,
        'mlm': 0,
        'mpp': 0,
        'vqa': 0,
        'imgcls': 0,
        'nlvr2': 0,
        'irtr': 1,
        'arc': 0,
    }
    conf['loss_names'] = loss_names

    if torch.cuda.is_available():
        dev = 'cuda:0'
    else:
        dev = 'cpu'
    device = torch.device(dev)

    print(f' conf for ViltModel {conf}')
    vilt_model = ViltModel(conf)
    vilt_model.to(device)

    # Load the images with the pixelBert transformation.
    image_dataset = get_image_data_loader(root=DATASET_ROOT_PATH,
                                          split_root=DATASET_SPLIT_ROOT_PATH,
                                          split='test',
                                          transform=pixelbert_transform(384),
                                          batch_size=1)
    # Load the captions; the text side needs no image transform.
    text_dataset = get_captions_data_loader(root=DATASET_ROOT_PATH,
                                            split_root=DATASET_SPLIT_ROOT_PATH,
                                            split='test',
                                            batch_size=1)

    images = []
    filenames = []
    for filenames_batch, images_batch in image_dataset:
        filenames.extend(filenames_batch)
        images.extend(images_batch)

    retrieved_image_filenames = []
    groundtruth_expected_image_filenames = []
    # This amounts to 5000 captions ranked against 1000 images.
    print(f' number of queries {len(text_dataset)}, against {len(images)}')
    for matching_filename, query in text_dataset:
        filename = matching_filename[0]
        groundtruth_expected_image_filenames.append([filename])
        q = query[0]
        start = time.time()
        scores = vilt_model.rank_query_vs_images(q, images)
        print(f' time to rank a single query {time.time() - start}s')
        # Sort the filenames by score, highest first.
        retrieved_image_filenames.append(
            [f for _, f in sorted(zip(scores, filenames), reverse=True)])

    evaluate(['recall', 'reciprocal_rank'], retrieved_image_filenames,
             groundtruth_expected_image_filenames,
             [1, 5, 10, 20, 100, 200, 500, None],
             {}, print_results=True)
```
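To be explicit about what is being measured: `evaluate` receives, per query, the full ranked list of filenames plus a singleton ground-truth list, with `[1, 5, 10, 20, 100, 200, 500, None]` as the cut-offs k (None meaning the whole list). This is roughly what I assume the two metrics boil down to (a simplified sketch, not the actual `src.evaluate` code):

```python
def recall_at_k(retrieved, groundtruth, k):
    # Fraction of queries whose ground-truth image shows up in the top-k results
    # (k=None considers the entire ranked list).
    hits = sum(1 for ranked, gt in zip(retrieved, groundtruth)
               if set(gt) & set(ranked[:k] if k else ranked))
    return hits / len(retrieved)


def mean_reciprocal_rank(retrieved, groundtruth):
    # Mean of 1 / (rank of the first relevant image), over all queries.
    reciprocal_ranks = []
    for ranked, gt in zip(retrieved, groundtruth):
        rank = next((i + 1 for i, f in enumerate(ranked) if f in set(gt)), None)
        reciprocal_ranks.append(1.0 / rank if rank else 0.0)
    return sum(reciprocal_ranks) / len(reciprocal_ranks)
```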
The obtained results are:
```
Mean Recall@1        0.7584
Mean Recall@5        0.9554
Mean Recall@10       0.9826
Mean Recall@20       0.9932
Mean Recall@100      0.9998
Mean Recall@200      1.0
Mean Recall@500      1.0
Mean Recall@None     1.0
Mean Reciprocal rank 0.8449181004476226
```
I know the results can differ from those in the paper, but these seem extremely good. Is there something I am obviously doing wrong?
Thanks in advance