Tung-I / Dual-awareness-Attention-for-Few-shot-Object-Detection

How to reproduce 18.6 COCO AP #29

Open ChengMeng-CM opened 2 years ago

ChengMeng-CM commented 2 years ago

In Tab. 1 of the paper, the proposed DANA is compared with other works in the general detection evaluation setting, and its 18.6 COCO AP is a great result. However, the provided evaluation code seems to contain only the 1-way evaluation mode, which does not reflect how object detection is used in real applications. In addition, the BA block and the CISA block only support one-class inference. I tried modifying the code to infer each image multiple times and merge the results into a multi-class result, as Attention-RPN does, but the performance was poor. Are there any references for this? Thanks.

Tung-I commented 2 years ago

The multi-class inference you mention belongs to DML-based methods, such as prototypical networks, which can compare the target with multiple support feature prototypes simultaneously and thus perform multi-class inference at once. This work, on the other hand, is a matching-based method, which usually performs one-class inference at a time. If you want multi-class inference, you can simply infer the results class by class, and that is exactly how we conduct the evaluation. Since you can reproduce high performance for each class individually, I don't see why it would degrade after merging the per-class results. I suspect there is a problem in your modified AP calculation code. Thanks.
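A minimal sketch of the class-by-class procedure described above. `detect_one_way` is a hypothetical stand-in for one forward pass of a matching-based detector given a single class's supports; it is assumed to return post-NMS detections as an (n, 5) array `[x1, y1, x2, y2, score]`, the same layout assumed for `all_boxes[cls][img]`.

```python
from typing import Callable, Dict, List

import numpy as np


def infer_class_by_class(
    images: List[np.ndarray],
    supports: Dict[int, np.ndarray],
    detect_one_way: Callable[[np.ndarray, np.ndarray], np.ndarray],
    num_classes: int,
) -> List[List[np.ndarray]]:
    """Run a hypothetical one-way detector once per class and merge the results.

    Returns all_boxes[cls][img] -> (n, 5) array [x1, y1, x2, y2, score].
    Index 0 is background and stays empty.
    """
    empty = np.zeros((0, 5), dtype=np.float32)
    all_boxes = [[empty for _ in images] for _ in range(num_classes)]
    for img_idx, image in enumerate(images):
        # one forward pass per class, each with only that class's supports
        for cls_id, support in supports.items():
            dets = detect_one_way(image, support)
            all_boxes[cls_id][img_idx] = dets if len(dets) else empty
    return all_boxes
```

Every class is queried on every image, so detections for classes that are absent from an image are kept and scored as false positives by the evaluator.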

ChengMeng-CM commented 2 years ago

Thank you very much for your reply! My modified inference.py is as follows:

```python
import os
import numpy as np
import argparse
import time
import pickle
import cv2
import sys
import torch
import torch.nn as nn
from torch.autograd import Variable
from tqdm import tqdm
from matplotlib import pyplot as plt
import random
from pathlib import Path
from scipy.misc import imread  # removed in SciPy >= 1.2, so an older SciPy is required

ROOT = os.getcwd()
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))
LIB = os.path.join(ROOT,'lib')
if str(LIB) not in sys.path:
    sys.path.append(str(LIB))

from roi_data_layer.roidb import combined_roidb
from roi_data_layer.inference_loader import InferenceLoader
from roi_data_layer.general_test_loader import GeneralTestLoader
from model.utils.config import cfg, cfg_from_file, cfg_from_list, get_output_dir
from model.rpn.bbox_transform import clip_boxes
from model.roi_layers import nms
from model.rpn.bbox_transform import bbox_transform_inv
from model.utils.net_utils import save_net, load_net, vis_detections
from model.utils.fsod_logger import FSODInferenceLogger
from model.utils.blob import prep_im_for_blob
from utils import *

def get_supports(imdb, sup_dir, num_shot=3):
    """Load num_shot fixed support images per class into an [N, K, 3, H, W] tensor."""
    support_dir = os.path.join(os.getcwd(), 'data/supports', sup_dir)
    support_im_size = 320

    num_classes = len(imdb.classes)  # includes background at index 0

    support_pool = [[] for i in range(num_classes)]
    label_to_cls_name = dict(list(zip(list(range(num_classes)), imdb.classes)))
    for _label in range(1, num_classes):
        cls_name = label_to_cls_name[_label]
        cls_dir = os.path.join(support_dir, cls_name)
        # sort so the seeded sample does not depend on filesystem order
        support_im_paths = sorted(str(_p) for _p in Path(cls_dir).glob('*.jpg'))
        if len(support_im_paths) == 0:
            raise Exception(f'support data not found in {cls_dir}')
        random.seed(0)  # fix the shots
        support_im_paths = random.sample(support_im_paths, k=num_shot)
        support_pool[_label].extend(support_im_paths)
    support_data_all = np.zeros((num_classes-1, num_shot, 3, support_im_size, support_im_size), dtype=np.float32)  # [N,K,3,H,W]

    # support_pool[0] (background) is empty, so row class_ind-1 holds class class_ind
    for class_ind, supports_per_class in enumerate(support_pool):
        for i, _path in enumerate(supports_per_class):
            support_im = imread(_path)[:,:,::-1]  # rgb -> bgr
            target_size = np.min(support_im.shape[0:2])  # don't change the size
            support_im, _ = prep_im_for_blob(support_im, cfg.PIXEL_MEANS, target_size, cfg.TRAIN.MAX_SIZE)
            _h, _w = support_im.shape[0], support_im.shape[1]
            if _h > _w:
                resize_scale = float(support_im_size) / float(_h)
                unfit_size = int(_w * resize_scale)
                support_im = cv2.resize(support_im, (unfit_size, support_im_size), interpolation=cv2.INTER_LINEAR)
            else:
                resize_scale = float(support_im_size) / float(_w)
                unfit_size = int(_h * resize_scale)
                support_im = cv2.resize(support_im, (support_im_size, unfit_size), interpolation=cv2.INTER_LINEAR)
            h, w = support_im.shape[0], support_im.shape[1]
            support_data_all[class_ind-1, i, :, :h, :w] = np.transpose(support_im, (2, 0, 1))
    supports = torch.from_numpy(support_data_all)
    return supports

if __name__ == '__main__':

    args = parse_args()
    print(args)
    cfg_from_file(args.cfg_file)
    cfg_from_list(args.set_cfgs)

    # prepare roidb
    cfg.TRAIN.USE_FLIPPED = False
    imdb, roidb, ratio_list, ratio_index = combined_roidb(args.imdbval_name, False)
    imdb.competition_mode(on=True)
    CWD = os.getcwd()
    support_dir = os.path.join(CWD, 'data/supports', args.sup_dir)

    # load dir
    input_dir = os.path.join(args.load_dir, "train/checkpoints")
    if not os.path.exists(input_dir):
        raise Exception('There is no input directory for loading network from ' + input_dir)
    load_name = os.path.join(input_dir,
        'model_{}_{}.pth'.format(args.checkepoch, args.checkpoint))

    # initialize the network
    classes = ['fg', 'bg']
    model = get_model(args.net, pretrained=False, way=args.way, shot=args.shot, classes=classes)
    print("load checkpoint %s" % (load_name))
    checkpoint = torch.load(load_name)
    model.load_state_dict(checkpoint['model'])
    if len(args.device.split(',')) > 1:
        model = model.module
    if 'pooling_mode' in checkpoint.keys():
        cfg.POOLING_MODE = checkpoint['pooling_mode']
    print('load model successfully!')
    cfg.CUDA = True
    model.cuda()
    model.eval()

    # initialize the tensor holders
    holders = prepare_var(support=True)
    im_data = holders[0]
    im_info = holders[1]
    num_boxes = holders[2]
    gt_boxes = holders[3]
    support_ims_all = holders[4]

    # prepare holder for predicted boxes
    max_per_image = 100  # defined but not applied in this script
    thresh = 0.05
    num_images = len(imdb.image_index)
    all_boxes = [[[] for _ in range(num_images)]
                 for _ in range(imdb.num_classes)]
    empty_array = np.transpose(np.array([[],[],[],[],[]]), (1,0))  # shape (0, 5)

    dataset = InferenceLoader(0, imdb, roidb, ratio_list, ratio_index, support_dir,
                            1, len(imdb._classes), num_shot=args.shot, training=False, normalize=False)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0, pin_memory=True)
    data_iter = iter(dataloader)

    # one fixed set of supports for the whole run
    supports = get_supports(imdb, args.sup_dir, args.shot)

    for i in tqdm(range(num_images)):
        scores_all = {}
        boxes_all = {}
        box_deltas_all = {}
        pred_boxes_all = {}
        data = next(data_iter)
        with torch.no_grad():
            im_data.resize_(data[0].size()).copy_(data[0])
            im_info.resize_(data[1].size()).copy_(data[1])
            gt_boxes.resize_(data[2].size()).copy_(data[2])
            num_boxes.resize_(data[3].size()).copy_(data[3])
            # use the fixed supports instead of the loader's per-image samples (data[4])
            support_ims_all.resize_(supports.size()).copy_(supports)

        det_tic = time.time()
        with torch.no_grad():
            # one 1-way forward pass per novel class
            for cls_ind, support_ims in enumerate(support_ims_all):
                support_ims = support_ims.unsqueeze(0)  # [1, K, 3, H, W]
                gt_class = cls_ind + 1  # label 0 is background
                rois, cls_prob, bbox_pred, \
                rpn_loss_cls, rpn_loss_box, \
                RCNN_loss_cls, RCNN_loss_bbox, \
                rois_label = model(im_data, im_info, gt_boxes, num_boxes, support_ims)

                scores_all[gt_class] = cls_prob.data
                boxes_all[gt_class] = rois.data[:, :, 1:5]
                box_deltas_all[gt_class] = bbox_pred.data
        det_toc = time.time()
        detect_time = det_toc - det_tic
        misc_tic = time.time()

        # Apply bounding-box regression deltas
        if cfg.TRAIN.BBOX_NORMALIZE_TARGETS_PRECOMPUTED:
            # Optionally normalize targets by a precomputed mean and stdev
            for cls_id, box_deltas in box_deltas_all.items():
                box_deltas = box_deltas.view(-1, 4) * torch.FloatTensor(cfg.TRAIN.BBOX_NORMALIZE_STDS).cuda() \
                        + torch.FloatTensor(cfg.TRAIN.BBOX_NORMALIZE_MEANS).cuda()
                box_deltas_all[cls_id] = box_deltas.view(1, -1, 4)

        for cls_id, boxes in boxes_all.items():
            box_deltas = box_deltas_all[cls_id]

            pred_boxes = bbox_transform_inv(boxes, box_deltas, 1)
            pred_boxes = clip_boxes(pred_boxes, im_info.data, 1)

            # re-scale boxes to the original image scale
            pred_boxes /= data[1][0][2].item()

            # drop the batch dimension and keep the result, so that
            # scores[:, 1] below selects the foreground column for every RoI
            scores_all[cls_id] = scores_all[cls_id].squeeze()
            pred_boxes_all[cls_id] = pred_boxes.squeeze()

        for cls_id, scores in scores_all.items():
            # 1-way mode: skip every class except that of the first GT box
            # if cls_id != gt_boxes[0, 0, 4]:
            #     all_boxes[cls_id][i] = empty_array
            #     continue
            pred_boxes = pred_boxes_all[cls_id]

            inds = torch.nonzero(scores[:, 1] > thresh).view(-1)
            if inds.numel() > 0:
                cls_scores = scores[:, 1][inds]
                cls_boxes = pred_boxes[inds, :]
                cls_dets = NMS(cls_boxes, cls_scores)
                all_boxes[cls_id][i] = cls_dets.cpu().numpy()
            else:
                all_boxes[cls_id][i] = empty_array

        misc_toc = time.time()
        nms_time = misc_toc - misc_tic

    sys.stdout.write('im_detect: {:d}/{:d} {:.3f}s {:.3f}s   \r' \
            .format(i + 1, num_images, detect_time, nms_time))
    sys.stdout.flush()

    output_dir = os.path.join(CWD, 'inference_output', args.eval_dir)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    det_file = os.path.join(output_dir, 'detections.pkl')
    with open(det_file, 'wb') as f:
        pickle.dump(all_boxes, f, pickle.HIGHEST_PROTOCOL)
    print('Evaluating detections')
    imdb.evaluate_detections(all_boxes, output_dir)
```
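In the script above, `max_per_image = 100` is defined but never applied. The script appears to follow jwyang's faster-rcnn.pytorch `test_net.py`, which caps the detections kept per image across all classes after per-class NMS. pycocotools applies `maxDets` per image and per category, so without such a cap each of the 20 queried classes can contribute up to 100 detections per image. A minimal numpy sketch of the cap; `cap_detections_per_image` is a hypothetical helper, and it assumes each `all_boxes[j][i]` is an (n, 5) array with the score in the last column and class 0 as background.

```python
import numpy as np


def cap_detections_per_image(all_boxes, img_idx, max_per_image=100):
    """Keep roughly the top max_per_image detections of one image across classes.

    Assumes all_boxes[cls][img] is an (n, 5) array, score in column 4,
    and that class 0 (background) is skipped. Ties at the cut-off score
    may keep slightly more than max_per_image boxes.
    """
    num_classes = len(all_boxes)
    image_scores = np.hstack(
        [all_boxes[j][img_idx][:, -1] for j in range(1, num_classes)]
    )
    if len(image_scores) > max_per_image:
        # score of the max_per_image-th best detection becomes the cut-off
        image_thresh = np.sort(image_scores)[-max_per_image]
        for j in range(1, num_classes):
            keep = np.where(all_boxes[j][img_idx][:, -1] >= image_thresh)[0]
            all_boxes[j][img_idx] = all_boxes[j][img_idx][keep, :]
    return all_boxes
```

Called once per image index after the per-class loop, this keeps only the highest-scoring boxes regardless of which class query produced them.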

The following changes have been made:

  1. The evaluation set has been changed to contain all 5000 images of the COCO 2017 val set, and base-class annotations have been removed.
  2. Only k images per class are randomly selected as supports for the entire inference stage, rather than randomly sampling k of the 24 candidate support images at every inference iteration. In other words, all support images are loaded once before inference.
  3. Each image in the val set is inferred 20 times, once per novel class.
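A sketch of change 2 (one fixed set of k supports per class, chosen once before inference). `sample_fixed_supports` is a hypothetical helper; the `<support_root>/<class_name>/*.jpg` layout mirrors the `get_supports` function in the script. Paths are sorted before sampling because `Path.glob` returns entries in filesystem order, so a fixed seed alone does not guarantee the same shots on another machine, and a dedicated `random.Random` instance avoids touching the global random state.

```python
import random
from pathlib import Path
from typing import Dict, List


def sample_fixed_supports(support_root: str, class_names: List[str],
                          num_shot: int = 3, seed: int = 0) -> Dict[str, List[str]]:
    """Pick num_shot support image paths per class, reproducibly."""
    rng = random.Random(seed)  # private RNG, independent of random.seed()
    chosen = {}
    for name in class_names:
        # sort first: glob order depends on the filesystem
        paths = sorted(str(p) for p in Path(support_root, name).glob("*.jpg"))
        if len(paths) < num_shot:
            raise ValueError(
                f"need {num_shot} support images for {name}, found {len(paths)}")
        chosen[name] = rng.sample(paths, k=num_shot)
    return chosen
```

The same seed and directory contents always yield the same shots, which makes 3-shot results comparable across runs.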

I tested the 3-shot setting without fine-tuning, which is the setting I am interested in. The results are shown below:

```
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.025
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.050
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.022
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.018
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.031
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.038
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.138
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.275
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.293
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.151
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.300
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.440
```

AR looks reasonable, but AP is much lower than the 1-way result I measured earlier. Is this expected?
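A toy illustration, with made-up numbers rather than results from this experiment, of how recall can stay high while AP drops: when a class is also queried on images that do not contain it, any high-scoring false positives are ranked among its true positives. The sketch uses non-interpolated AP (mean precision at each true-positive rank); COCO's metric uses 101-point interpolation over several IoU thresholds, so these numbers are only indicative.

```python
import numpy as np


def average_precision(scores, is_tp, num_gt):
    """Non-interpolated AP and final recall for one class.

    scores: detection confidences; is_tp: whether each detection matched a GT box.
    """
    if len(scores) == 0:
        return 0.0, 0.0
    order = np.argsort(-np.asarray(scores), kind="mergesort")  # highest score first
    tp = np.asarray(is_tp, dtype=float)[order]
    cum_tp = np.cumsum(tp)
    precision = cum_tp / np.arange(1, len(tp) + 1)
    recall = cum_tp[-1] / num_gt
    # average the precision values at the ranks where a true positive occurs
    ap = float((precision * tp).sum() / num_gt)
    return ap, float(recall)
```

For two ground-truth dogs, two correct detections scored 0.9 and 0.8 give AP = 1.0 and recall = 1.0. Adding two false "dog" detections scored 0.95 and 0.85, coming from images without dogs, leaves recall at 1.0 but drops AP to 0.5.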