lxn96 / ICPE

The official code for the paper "Breaking Immutable: Information-Coupled Prototype Elaboration for Few-Shot Object Detection"
Apache License 2.0
29 stars 2 forks source link

How to reproduce the results? #2

Open Alan-lab opened 1 year ago

Alan-lab commented 1 year ago

This is excellent work, but how can the results in the paper be reproduced? I followed the instructions in the README, and the environment, dataset, code, and config files are all consistent. However, on the VOC dataset the results I actually reproduce are at least 8 points worse than those reported in the paper. Looking forward to your reply!

lxn96 commented 1 year ago

Thank you for your interest in our work. For FSOD, the choice of support images has a large impact on model performance. The results presented in the paper are averaged over multiple randomized experiments. You can try generating different fewlist files for your experiments. Below is my code for generating the fewlists; the seed can be changed to generate different support sets.

import argparse
import random
import os
import numpy as np
import mmcv
import xml.etree.ElementTree as ET

# PASCAL VOC class names (20 classes); list position is used as the class index
# when counting boxes per class in get_bbox_fewlist.
VOC = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
       'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']
# Maps the --dataset command-line choice to its class-name list.
dataset_classes = {'voc': VOC}
# Shot counts for which fewlist files are written; max(few_nums) is also the
# number of images drawn per class in gen_image_fewlist.
few_nums = [1, 2, 3, 5, 10]

def is_valid(xmlpath, cls_name):
    """Return True if the VOC annotation at ``xmlpath`` labels an object ``cls_name``.

    Args:
        xmlpath: path (or file object) of a VOC-style annotation XML.
        cls_name: class name to look for in the ``<object><name>`` fields.

    Returns:
        True when at least one ``<object>`` carries that name, else False.

    Raises:
        AttributeError: if some ``<object>`` element has no ``<name>`` child.
    """
    annotation = ET.parse(xmlpath).getroot()
    # Read every object's label (no short-circuit), so a malformed <object>
    # anywhere in the file is still reported.
    labels = [node.find('name').text for node in annotation.findall('object')]
    return cls_name in labels

def gen_cls_txt(classes, root):
    """Write ``<root>/ClassList/<class>.txt`` for every class in ``classes``.

    Each output file lists, in trainval order and one per line, the ids from
    ``<root>/ImageSets/Main/trainval.txt`` whose annotation XML contains at
    least one object of that class.

    Args:
        classes: class names to generate lists for.
        root: VOC root directory containing ``ImageSets/Main/trainval.txt``
            and ``Annotations/<id>.xml``.
    """
    print('-----------------------------------------------------------')
    print('--------------- Generating class list ---------------------')
    os.makedirs(os.path.join(root, 'ClassList'), exist_ok=True)
    with open(os.path.join(root, 'ImageSets/Main/trainval.txt'), encoding='utf-8') as f:
        # Blank lines would otherwise be parsed as 'Annotations/.xml'.
        image_ids = [line.strip() for line in f if line.strip()]

    # Parse each annotation once and remember its set of labels, instead of
    # re-parsing every XML file once per class.
    labels_per_image = []
    for image_id in image_ids:
        xml_path = os.path.join(root, 'Annotations/{}.xml'.format(image_id))
        tree_root = ET.parse(xml_path).getroot()
        labels_per_image.append({obj.find('name').text for obj in tree_root.findall('object')})

    for class_name in classes:
        print('parsing {}'.format(class_name))
        list_file = os.path.join(root, 'ClassList/{}.txt'.format(class_name))
        with open(list_file, 'w') as out_f:
            for image_id, labels in zip(image_ids, labels_per_image):
                if class_name in labels:
                    out_f.write('{}\n'.format(image_id))

def gen_image_fewlist(classes, root, seed=0):
    """Write image-level fewlists for every class and every shot in ``few_nums``.

    First rebuilds ``<root>/ClassList`` via ``gen_cls_txt``. Then, for each
    class, draws ``max(few_nums)`` distinct images from its class list and
    writes the first ``n`` of them to
    ``<root>/benchmark_<n>shot/<n>shot_<class>_train.txt``. Smaller-shot lists
    are therefore prefixes of larger ones.

    Args:
        classes: class names to process.
        root: VOC root directory.
        seed: base seed. Class ``i`` is sampled after ``random.seed(seed + i)``.
            The default 0 reproduces the original ``random.seed(i)`` behaviour.

    Raises:
        ValueError: if a class has fewer than ``max(few_nums)`` distinct images
            containing it.
    """
    gen_cls_txt(classes, root)
    print('-----------------------------------------------------------')
    print('----------- Generating fewlist  (images) ------------------')
    num = max(few_nums)
    for i, clsname in enumerate(classes):
        print('===> Processing class: {}'.format(clsname))
        with open(os.path.join(root, 'ClassList/{}.txt'.format(clsname)), 'r') as f:
            name_list = [line.strip() for line in f if line.strip()]
        candidates = set(name_list)
        if len(candidates) < num:
            raise ValueError('class {!r} has only {} images, {} needed'.format(
                clsname, len(candidates), num))
        random.seed(seed + i)
        selected_list = []
        rejected = set()
        while len(selected_list) < num:
            # Every candidate has been tried: stop instead of looping forever.
            if len(selected_list) + len(rejected) >= len(candidates):
                raise ValueError('class {!r}: fewer than {} images contain it'.format(
                    clsname, num))
            # Draw exactly as the original script did (first element of a
            # num-sized sample) so that seeded runs pick the same images.
            x = random.sample(name_list, num)[0]
            # Each draw is independent (with replacement); skip repeats so a
            # list never contains the same image twice.
            if x in selected_list or x in rejected:
                continue
            xmlpath = os.path.join(root, 'Annotations/{}.xml'.format(x))
            if not is_valid(xmlpath, clsname):
                rejected.add(x)
                continue
            selected_list.append(x)
        for n in few_nums:
            out_path = os.path.join(root, 'benchmark_{}shot'.format(n))
            os.makedirs(out_path, exist_ok=True)
            with open(os.path.join(out_path, '{}shot_{}_train.txt'.format(n, clsname)), 'w') as f:
                f.writelines('{}\n'.format(name) for name in selected_list[:n])

def get_bbox_fewlist(root, shot, classes, seed=2018):
    """Randomly pick trainval images until each class has ``shot`` counted boxes.

    Images are drawn one at a time without replacement. Each object in a drawn
    image increments its class's count while that count is below ``shot``. The
    image is recorded for every class it contributed a count to.

    Args:
        root: VOC root directory containing ``ImageSets/Main/trainval.txt``
            and ``Annotations/<id>.xml``.
        shot: number of boxes required per class.
        classes: class names. Objects whose label is not listed are ignored,
            so a subset (e.g. novel classes only) can be passed.
        seed: value passed to ``random.seed`` (reseeds the global generator).
            The default 2018 reproduces the original lists.

    Returns:
        One list per entry of ``classes`` (same order), holding image ids that
        each end with exactly one ``'\\n'``.

    Raises:
        ValueError: if every image has been drawn before all classes reach
            ``shot``.
    """
    rootfile = os.path.join(root, 'ImageSets/Main/trainval.txt')
    with open(rootfile) as f:
        # Normalise to one trailing newline so callers can write ids verbatim,
        # even when the file's last line has no newline.
        names = [line.strip() + '\n' for line in f if line.strip()]
    random.seed(seed)
    cls_lists = [[] for _ in classes]
    cls_counts = [0] * len(classes)
    cls2id = {cls: i for i, cls in enumerate(classes)}
    while cls_counts and min(cls_counts) < shot:
        if not names:
            raise ValueError(
                'trainval images exhausted before every class reached {} boxes'.format(shot))
        # Same call as the original script, to keep seeded draws reproducible.
        imgid = random.sample(names, 1)[0]
        xmlpath = os.path.join(root, 'Annotations/{}.xml'.format(imgid.strip()))
        tree_root = ET.parse(xmlpath).getroot()
        for obj in tree_root.findall('object'):
            ci = cls2id.get(obj.find('name').text)
            if ci is None:
                continue  # label not among the requested classes
            if cls_counts[ci] < shot:
                cls_counts[ci] += 1
                if imgid not in cls_lists[ci]:
                    cls_lists[ci].append(imgid)
        names.remove(imgid)
    return cls_lists

def gen_bbox_fewlist(classes, root):
    """Write box-level fewlists for every shot count in ``few_nums``.

    For each shot ``n``, calls ``get_bbox_fewlist`` and writes one file per
    class, ``<root>/benchmark_<n>shot/box_<n>shot_<class>_train.txt``, with one
    image id per line.

    Args:
        classes: class names, in the order ``get_bbox_fewlist`` returns lists.
        root: VOC root directory.
    """
    print('-----------------------------------------------------------')
    print('----------- Generating fewlist  (bboxes) ------------------')
    for n in few_nums:
        print('===> On {} shot ...'.format(n))
        filelists = get_bbox_fewlist(root, n, classes)
        out_path = os.path.join(root, 'benchmark_{}shot'.format(n))
        os.makedirs(out_path, exist_ok=True)
        for clsname, names in zip(classes, filelists):
            print('   | Processing class: {}'.format(clsname))
            with open(os.path.join(out_path, 'box_{}shot_{}_train.txt'.format(n, clsname)), 'w') as f:
                # Re-terminate each id: an id read from a last line without a
                # newline would otherwise merge with the next one.
                f.writelines('{}\n'.format(name.strip()) for name in names)

def main():
    """Parse command-line options and generate the requested fewlist files.

    ``--type`` selects box-level lists (``box``), image-level lists (``img``),
    or both (image lists first).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--type', type=str, default='box', choices=['box', 'img', 'both'])
    parser.add_argument('--dataset', type=str, default='voc', choices=['voc'])
    parser.add_argument('--dataroot', type=str, default='/workspace/data/voc')
    args = parser.parse_args()

    classes = dataset_classes[args.dataset]
    # argparse's default/choices guarantee args.type is 'box', 'img' or 'both'.
    if args.type in ('img', 'both'):
        gen_image_fewlist(classes, args.dataroot)
    if args.type in ('box', 'both'):
        gen_bbox_fewlist(classes, args.dataroot)


# Without this guard, running the script only defines functions and does nothing.
if __name__ == '__main__':
    main()
Alan-lab commented 1 year ago

Thank you for your interest in our work. For FSOD, the choice of support images has a large impact on model performance. The results presented in the paper are averaged over multiple randomized experiments. You can try generating different fewlist files for your experiments. Below is my code for generating the fewlists; the seed can be changed to generate different support sets.

import argparse
import random
import os
import numpy as np
import mmcv
import xml.etree.ElementTree as ET

# PASCAL VOC class names (20 classes); list position is used as the class index
# when counting boxes per class in get_bbox_fewlist.
VOC = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
       'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']
# Maps the --dataset command-line choice to its class-name list.
dataset_classes = {'voc': VOC}
# Shot counts for which fewlist files are written; max(few_nums) is also the
# number of images drawn per class in gen_image_fewlist.
few_nums = [1, 2, 3, 5, 10]

def is_valid(xmlpath, cls_name):
    """Return True if the VOC annotation at ``xmlpath`` labels an object ``cls_name``.

    Args:
        xmlpath: path (or file object) of a VOC-style annotation XML.
        cls_name: class name to look for in the ``<object><name>`` fields.

    Returns:
        True when at least one ``<object>`` carries that name, else False.

    Raises:
        AttributeError: if some ``<object>`` element has no ``<name>`` child.
    """
    annotation = ET.parse(xmlpath).getroot()
    # Read every object's label (no short-circuit), so a malformed <object>
    # anywhere in the file is still reported.
    labels = [node.find('name').text for node in annotation.findall('object')]
    return cls_name in labels

def gen_cls_txt(classes, root):
    """Write ``<root>/ClassList/<class>.txt`` for every class in ``classes``.

    Each output file lists, in trainval order and one per line, the ids from
    ``<root>/ImageSets/Main/trainval.txt`` whose annotation XML contains at
    least one object of that class.

    Args:
        classes: class names to generate lists for.
        root: VOC root directory containing ``ImageSets/Main/trainval.txt``
            and ``Annotations/<id>.xml``.
    """
    print('-----------------------------------------------------------')
    print('--------------- Generating class list ---------------------')
    os.makedirs(os.path.join(root, 'ClassList'), exist_ok=True)
    with open(os.path.join(root, 'ImageSets/Main/trainval.txt'), encoding='utf-8') as f:
        # Blank lines would otherwise be parsed as 'Annotations/.xml'.
        image_ids = [line.strip() for line in f if line.strip()]

    # Parse each annotation once and remember its set of labels, instead of
    # re-parsing every XML file once per class.
    labels_per_image = []
    for image_id in image_ids:
        xml_path = os.path.join(root, 'Annotations/{}.xml'.format(image_id))
        tree_root = ET.parse(xml_path).getroot()
        labels_per_image.append({obj.find('name').text for obj in tree_root.findall('object')})

    for class_name in classes:
        print('parsing {}'.format(class_name))
        list_file = os.path.join(root, 'ClassList/{}.txt'.format(class_name))
        with open(list_file, 'w') as out_f:
            for image_id, labels in zip(image_ids, labels_per_image):
                if class_name in labels:
                    out_f.write('{}\n'.format(image_id))

def gen_image_fewlist(classes, root, seed=0):
    """Write image-level fewlists for every class and every shot in ``few_nums``.

    First rebuilds ``<root>/ClassList`` via ``gen_cls_txt``. Then, for each
    class, draws ``max(few_nums)`` distinct images from its class list and
    writes the first ``n`` of them to
    ``<root>/benchmark_<n>shot/<n>shot_<class>_train.txt``. Smaller-shot lists
    are therefore prefixes of larger ones.

    Args:
        classes: class names to process.
        root: VOC root directory.
        seed: base seed. Class ``i`` is sampled after ``random.seed(seed + i)``.
            The default 0 reproduces the original ``random.seed(i)`` behaviour.

    Raises:
        ValueError: if a class has fewer than ``max(few_nums)`` distinct images
            containing it.
    """
    gen_cls_txt(classes, root)
    print('-----------------------------------------------------------')
    print('----------- Generating fewlist  (images) ------------------')
    num = max(few_nums)
    for i, clsname in enumerate(classes):
        print('===> Processing class: {}'.format(clsname))
        with open(os.path.join(root, 'ClassList/{}.txt'.format(clsname)), 'r') as f:
            name_list = [line.strip() for line in f if line.strip()]
        candidates = set(name_list)
        if len(candidates) < num:
            raise ValueError('class {!r} has only {} images, {} needed'.format(
                clsname, len(candidates), num))
        random.seed(seed + i)
        selected_list = []
        rejected = set()
        while len(selected_list) < num:
            # Every candidate has been tried: stop instead of looping forever.
            if len(selected_list) + len(rejected) >= len(candidates):
                raise ValueError('class {!r}: fewer than {} images contain it'.format(
                    clsname, num))
            # Draw exactly as the original script did (first element of a
            # num-sized sample) so that seeded runs pick the same images.
            x = random.sample(name_list, num)[0]
            # Each draw is independent (with replacement); skip repeats so a
            # list never contains the same image twice.
            if x in selected_list or x in rejected:
                continue
            xmlpath = os.path.join(root, 'Annotations/{}.xml'.format(x))
            if not is_valid(xmlpath, clsname):
                rejected.add(x)
                continue
            selected_list.append(x)
        for n in few_nums:
            out_path = os.path.join(root, 'benchmark_{}shot'.format(n))
            os.makedirs(out_path, exist_ok=True)
            with open(os.path.join(out_path, '{}shot_{}_train.txt'.format(n, clsname)), 'w') as f:
                f.writelines('{}\n'.format(name) for name in selected_list[:n])

def get_bbox_fewlist(root, shot, classes, seed=2018):
    """Randomly pick trainval images until each class has ``shot`` counted boxes.

    Images are drawn one at a time without replacement. Each object in a drawn
    image increments its class's count while that count is below ``shot``. The
    image is recorded for every class it contributed a count to.

    Args:
        root: VOC root directory containing ``ImageSets/Main/trainval.txt``
            and ``Annotations/<id>.xml``.
        shot: number of boxes required per class.
        classes: class names. Objects whose label is not listed are ignored,
            so a subset (e.g. novel classes only) can be passed.
        seed: value passed to ``random.seed`` (reseeds the global generator).
            The default 2018 reproduces the original lists.

    Returns:
        One list per entry of ``classes`` (same order), holding image ids that
        each end with exactly one ``'\\n'``.

    Raises:
        ValueError: if every image has been drawn before all classes reach
            ``shot``.
    """
    rootfile = os.path.join(root, 'ImageSets/Main/trainval.txt')
    with open(rootfile) as f:
        # Normalise to one trailing newline so callers can write ids verbatim,
        # even when the file's last line has no newline.
        names = [line.strip() + '\n' for line in f if line.strip()]
    random.seed(seed)
    cls_lists = [[] for _ in classes]
    cls_counts = [0] * len(classes)
    cls2id = {cls: i for i, cls in enumerate(classes)}
    while cls_counts and min(cls_counts) < shot:
        if not names:
            raise ValueError(
                'trainval images exhausted before every class reached {} boxes'.format(shot))
        # Same call as the original script, to keep seeded draws reproducible.
        imgid = random.sample(names, 1)[0]
        xmlpath = os.path.join(root, 'Annotations/{}.xml'.format(imgid.strip()))
        tree_root = ET.parse(xmlpath).getroot()
        for obj in tree_root.findall('object'):
            ci = cls2id.get(obj.find('name').text)
            if ci is None:
                continue  # label not among the requested classes
            if cls_counts[ci] < shot:
                cls_counts[ci] += 1
                if imgid not in cls_lists[ci]:
                    cls_lists[ci].append(imgid)
        names.remove(imgid)
    return cls_lists

def gen_bbox_fewlist(classes, root):
    """Write box-level fewlists for every shot count in ``few_nums``.

    For each shot ``n``, calls ``get_bbox_fewlist`` and writes one file per
    class, ``<root>/benchmark_<n>shot/box_<n>shot_<class>_train.txt``, with one
    image id per line.

    Args:
        classes: class names, in the order ``get_bbox_fewlist`` returns lists.
        root: VOC root directory.
    """
    print('-----------------------------------------------------------')
    print('----------- Generating fewlist  (bboxes) ------------------')
    for n in few_nums:
        print('===> On {} shot ...'.format(n))
        filelists = get_bbox_fewlist(root, n, classes)
        out_path = os.path.join(root, 'benchmark_{}shot'.format(n))
        os.makedirs(out_path, exist_ok=True)
        for clsname, names in zip(classes, filelists):
            print('   | Processing class: {}'.format(clsname))
            with open(os.path.join(out_path, 'box_{}shot_{}_train.txt'.format(n, clsname)), 'w') as f:
                # Re-terminate each id: an id read from a last line without a
                # newline would otherwise merge with the next one.
                f.writelines('{}\n'.format(name.strip()) for name in names)

def main():
    """Parse command-line options and generate the requested fewlist files.

    ``--type`` selects box-level lists (``box``), image-level lists (``img``),
    or both (image lists first).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--type', type=str, default='box', choices=['box', 'img', 'both'])
    parser.add_argument('--dataset', type=str, default='voc', choices=['voc'])
    parser.add_argument('--dataroot', type=str, default='/workspace/data/voc')
    args = parser.parse_args()

    classes = dataset_classes[args.dataset]
    # argparse's default/choices guarantee args.type is 'box', 'img' or 'both'.
    if args.type in ('img', 'both'):
        gen_image_fewlist(classes, args.dataroot)
    if args.type in ('box', 'both'):
        gen_bbox_fewlist(classes, args.dataroot)


# Without this guard, running the script only defines functions and does nothing.
if __name__ == '__main__':
    main()

Oh, if I understand correctly, this means the results in the paper were not obtained by fine-tuning on TFA's VOC splits; instead, the fine-tuning support sets were randomly selected multiple times to show the best results. Is that right?

Alan-lab commented 1 year ago

[Screenshot: 1677918913085] TFA code: https://github.com/ucbdrive/few-shot-object-detection (TFA: "Frustratingly Simple Few-Shot Object Detection", ICML 2020)

lxn96 commented 1 year ago

We generate the fewlists following the work "Few-shot Object Detection via Feature Reweighting" (ICCV 2019); code: https://github.com/bingykang/Fewshot_Detection

666wzl666 commented 1 week ago

This is excellent work, but how can the results in the paper be reproduced? I followed the instructions in the README, and the environment, dataset, code, and config files are all consistent. However, on the VOC dataset the results I actually reproduce are at least 8 points worse than those reported in the paper. Looking forward to your reply!

When I run training on VOC with `python tools/detection/train.py configs/detection/icpe/voc/ICPE_voc-split1_base-training.py`, I get an error. Do you know how to fix it?