ChenDelong1999 / RemoteCLIP

🛰️ Official repository of paper "RemoteCLIP: A Vision Language Foundation Model for Remote Sensing" (IEEE TGRS)
https://arxiv.org/abs/2306.11029
Apache License 2.0
228 stars 13 forks source link

RET-3 Deduplication #13

Closed zilunzhang closed 7 months ago

zilunzhang commented 7 months ago

Hi,

Deduplication: Remove the (almost) same image presence in both the test set of RSICD (one of the evaluation sets) and the training set of RSITMD (part of RET-3), and vice versa.

In the paper you mentioned, "We generate p-Hash values for all images and used these values to detect duplicate images. If the number of different digits between two images is less than threshold 2, they are considered duplicates. Finally, the number of removed duplicated samples ranges from 40 to 3k in different datasets"

Would you mind providing a filename list of the de-duplicated images from RET-3 (i.e., those which do not appear in the test sets of RSICD and RSITMD)?

Thanks

zilunzhang commented 7 months ago

I have written one, but I am not sure if it is correct. Please have a look.

import numpy as np
from PIL import Image
import io
import random
import argparse
import pandas as pd
from tqdm import tqdm
import os
import time
import json
import pdb
import imagehash
import pickle as pkl

# Disable PIL's decompression-bomb guard: remote-sensing images can exceed
# the default pixel limit and would otherwise raise DecompressionBombError.
Image.MAX_IMAGE_PIXELS = None
# Maximum pHash distance at which two images are treated as duplicates.
# NOTE(review): the paper says "less than threshold 2" marks duplicates, but
# `check` keeps only distances strictly greater than this value, so a
# distance of exactly 2 is also removed — confirm which is intended.
HASH_DIST_THRESHOLD = 2.0

def hash_dist(hash1, hash2):
    """Return the absolute difference between two hash values.

    For imagehash.ImageHash operands, subtraction yields the number of
    differing bit positions (presumably the Hamming distance — per the
    imagehash library's `__sub__`; verify against its docs).
    """
    difference = hash1 - hash2
    return np.abs(difference)

def check(list1, val):
    """Return True iff every element of list1 is strictly greater than val."""
    for item in list1:
        if item <= val:
            return False
    return True

def _read_split_paths(dataset_dir, image_subdir, split_filename):
    """Return absolute image paths for one split file (one filename per line)."""
    with open(os.path.join(dataset_dir, "split_files", split_filename), "r") as f:
        return [os.path.join(dataset_dir, image_subdir, name.strip()) for name in f]


def _compute_hashes(image_paths):
    """Map each image path to its perceptual hash (pHash)."""
    return {path: imagehash.phash(Image.open(path)) for path in tqdm(image_paths)}


def _partition_train(train_paths, test_paths, train_hashes, test_hashes):
    """Split train paths into (near-duplicates-of-test, kept) by pHash distance.

    A training image is kept only when its hash distance to EVERY test image
    strictly exceeds HASH_DIST_THRESHOLD; otherwise it is considered a
    near-duplicate of some test image (train/test leakage) and removed.
    Paths that do not exist on disk are silently skipped, matching the
    original behavior.
    """
    duplicates, kept = [], []
    for train_path in tqdm(train_paths):
        if not os.path.exists(train_path):
            continue
        distances = [
            hash_dist(train_hashes[train_path], test_hashes[test_path])
            for test_path in test_paths
        ]
        if check(distances, HASH_DIST_THRESHOLD):
            kept.append(train_path)
        else:
            duplicates.append(train_path)
    return duplicates, kept


def _collect_pairs(image_infos, image_dir, keep_set):
    """Expand kept images into parallel (path, caption) lists.

    Each annotation record contributes one row per caption sentence, but only
    when its image survived de-duplication (i.e. its path is in keep_set).
    """
    paths, captions = [], []
    for info in image_infos:
        img_path = os.path.join(image_dir, info["filename"])
        if img_path not in keep_set:
            continue
        assert os.path.exists(img_path)
        for sentence in info["sentences"]:
            paths.append(img_path)
            captions.append(sentence["raw"])
    return paths, captions


def make_ft_csv(args):
    """Build a de-duplicated fine-tuning CSV from RSICD + RSITMD train splits.

    Training images whose pHash distance to any image in the combined RSICD +
    RSITMD test splits is <= HASH_DIST_THRESHOLD are removed; surviving
    (image path, caption) pairs are written to ``args.save_path`` as a CSV
    with columns ``filepath`` and ``title``. Hash results are cached in two
    pickle files in the working directory.
    """
    with open(os.path.join(args.rsicd_dir, "dataset_rsicd.json"), "r") as f:
        rsicd_json = json.load(f)["images"]
    with open(os.path.join(args.rsitmd_dir, "dataset_RSITMD.json"), "r") as f:
        rsitmd_json = json.load(f)["images"]

    rsicd_train_filenames = _read_split_paths(args.rsicd_dir, "RSICD_images", "train_filename.txt")
    rsicd_test_filenames = _read_split_paths(args.rsicd_dir, "RSICD_images", "test_filename.txt")
    rsitmd_train_filenames = _read_split_paths(args.rsitmd_dir, "images", "train_filename.txt")
    rsitmd_test_filenames = _read_split_paths(args.rsitmd_dir, "images", "test_filename.txt")

    all_img_name_list_train = list(set(rsicd_train_filenames + rsitmd_train_filenames))
    all_img_name_list_test = list(set(rsicd_test_filenames + rsitmd_test_filenames))
    print(len(all_img_name_list_train), len(all_img_name_list_test))

    # BUG FIX: the original tested "train_in_test_name_list.pkl" twice, so a
    # missing keep_train_path_list.pkl made the load branch crash.
    if os.path.exists("train_in_test_name_list.pkl") and os.path.exists("keep_train_path_list.pkl"):
        with open("train_in_test_name_list.pkl", "rb") as f:
            train_in_test_name_list = pkl.load(f)
        with open("keep_train_path_list.pkl", "rb") as f:
            keep_train_path_list = pkl.load(f)
    else:
        train_img_hashs = _compute_hashes(all_img_name_list_train)
        test_img_hashs = _compute_hashes(all_img_name_list_test)
        train_in_test_name_list, keep_train_path_list = _partition_train(
            all_img_name_list_train, all_img_name_list_test, train_img_hashs, test_img_hashs
        )
        print(len(train_in_test_name_list), len(keep_train_path_list))
        with open("train_in_test_name_list.pkl", "wb") as f:
            pkl.dump(train_in_test_name_list, f)
        with open("keep_train_path_list.pkl", "wb") as f:
            pkl.dump(keep_train_path_list, f)

    # O(1) membership tests instead of an O(n) list scan per annotation.
    keep_set = set(keep_train_path_list)

    rsicd_img_name_list, rsicd_caption_list = _collect_pairs(
        rsicd_json, os.path.join(args.rsicd_dir, "RSICD_images"), keep_set
    )
    # BUG FIX: the original built the RSITMD lookup path with args.rsicd_dir,
    # so no RSITMD image ever matched keep_train_path_list and every RSITMD
    # caption was silently dropped from the output CSV.
    rsitmd_img_name_list, rsitmd_caption_list = _collect_pairs(
        rsitmd_json, os.path.join(args.rsitmd_dir, "images"), keep_set
    )

    all_img_name_list = rsitmd_img_name_list + rsicd_img_name_list
    all_caption_list = rsitmd_caption_list + rsicd_caption_list

    print(len(all_img_name_list), len(set(all_img_name_list)))
    print(len(all_caption_list), len(set(all_caption_list)))

    ft_df = pd.DataFrame(
        {
            "filepath": all_img_name_list,
            "title": all_caption_list
        }
    )

    ft_df.to_csv(args.save_path, index=False)
    print(train_in_test_name_list)

def main():
    """CLI entry point: parse dataset locations, then build the filtered CSV."""
    random.seed(2023)

    parser = argparse.ArgumentParser()
    # (flag, default, help) triples for the three string options.
    option_specs = [
        ("--rsitmd_dir", "/home/zilun/RS5M_v5/data/rs5m_test_data/rsitmd", "RSITMD dir"),
        ("--rsicd_dir", "/home/zilun/RS5M_v5/data/rs5m_test_data/rsicd", "RSICD dir"),
        ("--save_path", "./ft_data_ret2.csv", "gt save path"),
    ]
    for flag, default_value, help_text in option_specs:
        parser.add_argument(flag, type=str, default=default_value, help=help_text)

    make_ft_csv(parser.parse_args())


if __name__ == "__main__":
    main()
gzqy1026 commented 7 months ago

I am sorry for any misunderstanding that may have occurred due to the content in the paper. The distance between the pHash values of two images is not computed by simple subtraction. Specifically, you can refer to method 1 (pHash block local probe) in this deduplication code: https://github.com/xuehuachunsheng/DupImageDetection.

These are three lists containing the image names that were removed from the training set. rsicd_duplicate.txt rsitmd_duplicate.txt ucm_duplicate.txt

YiguoHe commented 2 months ago

The RSITMD and RSICD datasets have a data leakage issue: they may share some common images and descriptions. This problem persists even after removing the images in the lists you provided.