ChenDelong1999 / RemoteCLIP

🛰️ Official repository of paper "RemoteCLIP: A Vision Language Foundation Model for Remote Sensing" (IEEE TGRS)
https://arxiv.org/abs/2306.11029
Apache License 2.0

The retrieval evaluation code does not produce accurate results on the RSICD dataset. #7

Closed chagmgang closed 1 year ago

chagmgang commented 1 year ago

My evaluation code with your model (ViT-L-14) is below.

from huggingface_hub import hf_hub_download
import open_clip
import numpy as np
import torchvision
import os
import json
import torch
from PIL import Image
from tqdm import tqdm
import torch.nn.functional as F
from clip_benchmark.metrics.zeroshot_retrieval import recall_at_k, batchify, dataloader_with_indices

class Dataset(torch.utils.data.Dataset):

    def __init__(self, base_path, transforms, filename):

        self.base_path = base_path
        self.transforms = transforms
        self.filename = filename

        self.data = self.load_annotation()

    def __getitem__(self, idx):
        image = self.data[idx]
        filename = os.path.join(self.base_path, image['filename'])
        raws = [i['raw'].replace(' .', '') for i in image['sentences']]
        return self.transforms(Image.open(filename)), raws

    def __len__(self):
        return len(self.data)

    def load_annotation(self):
        with open(self.filename, 'r') as f:
            data = json.load(f)
        images = data['images']
        test_images = list()
        for image in images:
            if image['split'] == 'test':
                test_images.append(image)
        return test_images

def main():

    device = torch.device('cuda')
    model_name = 'ViT-L-14'
    checkpoint_path = hf_hub_download(
        "chendelong/RemoteCLIP",
        f"RemoteCLIP-{model_name}.pt",
        cache_dir='checkpoints',
    )
    print(f'{model_name} is downloaded to {checkpoint_path}.')
    model, _, preprocess = open_clip.create_model_and_transforms(model_name)
    tokenizer = open_clip.get_tokenizer(model_name)
    # load the checkpoint from the path returned by hf_hub_download
    ckpt = torch.load(checkpoint_path, map_location="cpu")
    message = model.load_state_dict(ckpt)
    print(message)
    model = model.cuda().eval()

    dataset = Dataset(
        base_path='rsicd/RSICD_images',
        transforms=preprocess,
        filename='rsicd/RSICD_optimal/dataset_rsicd.json',
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=128,
        num_workers=4,
        drop_last=False,
        shuffle=False,
    )
    n_batches = len(dataloader)

    # list of batch of images embedding
    batch_images_emb_list = []
    # list of batch of text embedding
    batch_texts_emb_list = []
    # for each text, we collect the corresponding image index, as each image can have multiple corresponding texts
    texts_image_index = []

    dataloader = dataloader_with_indices(dataloader)
    for batch_images, batch_texts, inds in tqdm(dataloader, total=n_batches):
        batch_images = batch_images.to(device)
        batch_texts_tok = [tokenizer(text) for texts in batch_texts for text in texts]
        batch_texts_image_index = [ind for ind, texts in zip(inds, batch_texts) for text in texts]

        with torch.no_grad():
            batch_image_features = model.encode_image(batch_images)
            batch_text_features = [model.encode_text(t.to(device)) for t in batch_texts_tok]
            batch_text_features = torch.cat(batch_text_features)

        batch_images_emb = F.normalize(batch_image_features, dim=-1)
        batch_texts_emb = F.normalize(batch_text_features, dim=-1)

        batch_images_emb_list.append(batch_images_emb.cpu())
        batch_texts_emb_list.append(batch_texts_emb.cpu())
        texts_image_index.extend(batch_texts_image_index)

    batch_size = len(batch_images_emb_list[0])

    images_emb = torch.cat(batch_images_emb_list)
    texts_emb = torch.cat(batch_texts_emb_list)

    # get the score for each text and image pair
    scores = texts_emb @ images_emb.t()

    positive_pairs = torch.zeros_like(scores, dtype=bool)
    positive_pairs[torch.arange(len(scores)), texts_image_index] = True
    metrics = {}
    recall_k_list = [1, 5, 10]
    for recall_k in recall_k_list:
        metrics[f"retrieval-image2text-R@{recall_k}"] = (batchify(recall_at_k, scores.T, positive_pairs.T, batch_size, device, k=recall_k)>0).float().mean().item() * 100

    for recall_k in recall_k_list:
        metrics[f"retrieval-text2image-R@{recall_k}"] = (batchify(recall_at_k, scores, positive_pairs, batch_size, device, k=recall_k)>0).float().mean().item() * 100

    metrics["retrieval-mean-recall"] = np.mean(list(metrics.values()))

    # round for display and print the final metrics
    for key, value in metrics.items():
        metrics[key] = round(float(value), 2)
        print(key, metrics[key])


if __name__ == '__main__':
    main()

The evaluation numbers printed are as follows:

retrieval-image2text-R@1 0.37
retrieval-image2text-R@5 1.56
retrieval-image2text-R@10 2.29
retrieval-text2image-R@1 0.6
retrieval-text2image-R@5 2.76
retrieval-text2image-R@10 4.68
retrieval-mean-recall 2.04
gzqy1026 commented 1 year ago

I think there may be a problem with the dataloader, so "batch_texts" may affect the subsequent calculations. Your current dataloader is:

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=128,
    num_workers=4,
    drop_last=False,
    shuffle=False,
)

You can try changing it as follows:

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=128,
    num_workers=4,
    collate_fn=get_dataset_collate_fn('mscoco_captions'),
    drop_last=False,
    shuffle=False,
)

Please refer to this code for details: https://github.com/ChenDelong1999/ITRA/blob/ccf87ed79f4556b2bf0b1534d4e4507722a8b186/itra/evaluation/retrieval.py#L139-L144
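
To see why the default collate function causes the misalignment, here is a minimal, self-contained sketch (the ToyDataset is hypothetical, not from the thread). Each dataset item returns (image, list-of-5-captions), and torch's default collate transposes the caption lists across the batch, so zip(inds, batch_texts) in the evaluation loop pairs image indices with the wrong caption groups when building texts_image_index:

import torch
from torch.utils.data import DataLoader, Dataset

class ToyDataset(Dataset):
    # hypothetical stand-in for the RSICD dataset: each item is
    # (image_tensor, list of 5 caption strings), like Dataset.__getitem__ above
    def __len__(self):
        return 2

    def __getitem__(self, idx):
        image = torch.zeros(3)
        captions = [f'caption {j} of image {idx}' for j in range(5)]
        return image, captions

loader = DataLoader(ToyDataset(), batch_size=2)
_, batch_texts = next(iter(loader))
# batch_texts comes out (roughly) as 5 groups of 2 captions, i.e. transposed:
#   [('caption 0 of image 0', 'caption 0 of image 1'), ...]
# rather than 2 lists of 5 captions per image, so downstream
# texts_image_index no longer matches the order of the tokenized texts.
print(batch_texts)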

ChenDelong1999 commented 1 year ago

@gzqy1026 For RSICD, there are five ground truth captions for each image, right?

If the retrieval model hits one of them, it would be considered a successful retrieval.
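
For illustration, a minimal sketch (not the thread's code, with made-up scores) of this rule: the (recall_at_k(...) > 0) pattern in the script above counts a query image as a hit if any of its positive captions lands in the top-k.

import torch

def hit_at_k(scores, positive_pairs, k):
    # scores: (n_queries, n_candidates); positive_pairs: bool, same shape
    topk = scores.topk(k, dim=1).indices            # indices of top-k candidates
    hits = positive_pairs.gather(1, topk).any(dim=1)  # any positive retrieved?
    return hits.float()

# toy example: 2 query images, 10 candidate captions, 5 positives each
scores = torch.randn(2, 10)
positive_pairs = torch.zeros(2, 10, dtype=torch.bool)
positive_pairs[0, :5] = True   # captions 0-4 describe image 0
positive_pairs[1, 5:] = True   # captions 5-9 describe image 1
print(hit_at_k(scores, positive_pairs, k=5).mean().item())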

chagmgang commented 1 year ago

@ChenDelong1999 Right. @gzqy1026 When applying get_dataset_collate_fn('mscoco_captions'), the evaluation results are printed correctly. Thank you!
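
For anyone landing here later, the working configuration looks like the sketch below; the import path assumes get_dataset_collate_fn lives in clip_benchmark's datasets.builder module, so adjust it if your clip_benchmark version differs:

from clip_benchmark.datasets.builder import get_dataset_collate_fn  # assumed import path

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=128,
    num_workers=4,
    # keeps each image aligned with its own list of five captions
    collate_fn=get_dataset_collate_fn('mscoco_captions'),
    drop_last=False,
    shuffle=False,
)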

gzqy1026 commented 1 year ago

@ChenDelong1999 Yes, that's right.