Open ChengMeng-CM opened 2 years ago
The multi-class inference you mentioned belongs to the DML-based methods, such as prototypical networks, which can compare the target with multiple support feature prototypes simultaneously and thus perform multi-class inference at once. On the other hand, this work belongs to the matching-based methods, which usually perform one-class inference at a time. If you want to perform multi-class inference, you can simply infer the results class by class, and that is exactly how we conduct the evaluation. Since you can reproduce high performance for each class individually, I can't see why it would degrade after merging the results of each class. I assume there might be some problem in your modified AP calculation code. Thanks.
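To make the class-by-class evaluation concrete, here is a minimal sketch of the flow; the names (run_one_class, the toy sizes) are only illustrative placeholders, not the actual functions in this repo:

import numpy as np

# Minimal sketch of class-by-class inference and merging (illustrative only;
# run_one_class stands in for a single 1-way forward pass of the detector).
num_images, num_classes = 4, 3                      # toy sizes, class 0 = background

def run_one_class(img_id, cls):
    # placeholder: would return an (N, 5) array of [x1, y1, x2, y2, score]
    return np.zeros((0, 5), dtype=np.float32)

# all_boxes[cls][img] is the structure that evaluate_detections() in this kind
# of Faster R-CNN codebase usually expects.
all_boxes = [[[] for _ in range(num_images)] for _ in range(num_classes)]
for img_id in range(num_images):
    for cls in range(1, num_classes):
        all_boxes[cls][img_id] = run_one_class(img_id, cls)

# AP is computed per class and then averaged, so merging the per-class results
# into all_boxes should not by itself lower the per-class numbers.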
I really appreciate your reply! The modified inference.py is as follows:
import os
import numpy as np
import argparse
import time
import pickle
import cv2
import sys
import torch
import torch.nn as nn
from torch.autograd import Variable
from tqdm import tqdm
from matplotlib import pyplot as plt
import random
from pathlib import Path  # needed for Path(cls_dir).glob() in get_supports()
from scipy.misc import imread  # requires an older SciPy; imageio.imread is a drop-in alternative

ROOT = os.getcwd()
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))
LIB = os.path.join(ROOT, 'lib')
if str(LIB) not in sys.path:
    sys.path.append(str(LIB))

from roi_data_layer.roidb import combined_roidb
from roi_data_layer.inference_loader import InferenceLoader
from roi_data_layer.general_test_loader import GeneralTestLoader
from model.utils.config import cfg, cfg_from_file, cfg_from_list, get_output_dir
from model.rpn.bbox_transform import clip_boxes
from model.roi_layers import nms
from model.rpn.bbox_transform import bbox_transform_inv
from model.utils.net_utils import save_net, load_net, vis_detections
from model.utils.fsod_logger import FSODInferenceLogger
from model.utils.blob import prep_im_for_blob
from utils import *  # provides parse_args, get_model, prepare_var, NMS, ...
def get_supports(imdb, sup_dir, num_shot=3):
    """Build a [N, K, 3, H, W] tensor of support images: K shots for each foreground class."""
    support_dir = os.path.join(CWD, 'data/supports', sup_dir)
    support_im_size = 320
    num_classes = len(imdb.classes)
    support_pool = [[] for _ in range(num_classes)]
    label_to_cls_name = dict(zip(range(num_classes), imdb.classes))
    # collect num_shot support image paths for every foreground class (label 0 is background)
    for _label in range(1, num_classes):
        cls_name = label_to_cls_name[_label]
        cls_dir = os.path.join(support_dir, cls_name)
        support_im_paths = [str(_p) for _p in Path(cls_dir).glob('*.jpg')]
        if len(support_im_paths) == 0:
            raise Exception(f'support data not found in {cls_dir}')
        random.seed(0)  # fix the shots
        support_im_paths = random.sample(support_im_paths, k=num_shot)
        support_pool[_label].extend(support_im_paths)
    # [N, K, 3, H, W]; background is skipped, hence num_classes - 1
    support_data_all = np.zeros((num_classes - 1, num_shot, 3, support_im_size, support_im_size),
                                dtype=np.float32)
    for class_ind, supports_per_class in enumerate(support_pool):
        for i, _path in enumerate(supports_per_class):
            support_im = imread(_path)[:, :, ::-1]  # rgb -> bgr
            target_size = np.min(support_im.shape[0:2])  # don't change the size
            support_im, _ = prep_im_for_blob(support_im, cfg.PIXEL_MEANS, target_size, cfg.TRAIN.MAX_SIZE)
            _h, _w = support_im.shape[0], support_im.shape[1]
            # resize so the longer side equals support_im_size, keeping the aspect ratio
            if _h > _w:
                resize_scale = float(support_im_size) / float(_h)
                unfit_size = int(_w * resize_scale)
                support_im = cv2.resize(support_im, (unfit_size, support_im_size), interpolation=cv2.INTER_LINEAR)
            else:
                resize_scale = float(support_im_size) / float(_w)
                unfit_size = int(_h * resize_scale)
                support_im = cv2.resize(support_im, (support_im_size, unfit_size), interpolation=cv2.INTER_LINEAR)
            h, w = support_im.shape[0], support_im.shape[1]
            # class labels start at 1, so class_ind - 1 indexes the foreground-only tensor
            support_data_all[class_ind - 1, i, :, :h, :w] = np.transpose(support_im, (2, 0, 1))
    supports = torch.from_numpy(support_data_all)
    return supports
if __name__ == '__main__':
    args = parse_args()
    print(args)
    cfg_from_file(args.cfg_file)
    cfg_from_list(args.set_cfgs)

    # prepare roidb
    cfg.TRAIN.USE_FLIPPED = False
    imdb, roidb, ratio_list, ratio_index = combined_roidb(args.imdbval_name, False)
    CWD = os.getcwd()
    support_dir = os.path.join(CWD, 'data/supports', args.sup_dir)

    # load dir
    input_dir = os.path.join(args.load_dir, "train/checkpoints")
    if not os.path.exists(input_dir):
        raise Exception('There is no input directory for loading network from ' + input_dir)
    load_name = os.path.join(input_dir,
                             'model_{}_{}.pth'.format(args.checkepoch, args.checkpoint))

    # initialize the network
    classes = ['fg', 'bg']
    model = get_model(args.net, pretrained=False, way=args.way, shot=args.shot, classes=classes)
    print("load checkpoint %s" % (load_name))
    checkpoint = torch.load(load_name)
    model.load_state_dict(checkpoint['model'])
    if len(args.device.split(',')) > 1:
        model = model.module
    if 'pooling_mode' in checkpoint.keys():
        cfg.POOLING_MODE = checkpoint['pooling_mode']
    print('load model successfully!')
    cfg.CUDA = True
    model.cuda()
    model.eval()

    # initialize the tensor holders
    holders = prepare_var(support=True)
    im_data = holders[0]
    im_info = holders[1]
    num_boxes = holders[2]
    gt_boxes = holders[3]
    support_ims_all = holders[4]

    # prepare holders for the predicted boxes
    start = time.time()
    max_per_image = 100  # set but not applied anywhere below
    thresh = 0.05
    num_images = len(imdb.image_index)
    # all_boxes[cls][image] = N x 5 array of detections [x1, y1, x2, y2, score]
    all_boxes = [[[] for _ in range(num_images)]
                 for _ in range(imdb.num_classes)]
    _t = {'im_detect': time.time(), 'misc': time.time()}
    empty_array = np.transpose(np.array([[], [], [], [], []]), (1, 0))
    imdb.competition_mode(on=True)
    dataset = InferenceLoader(0, imdb, roidb, ratio_list, ratio_index, support_dir,
                              1, len(imdb._classes), num_shot=args.shot, training=False, normalize=False)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False,
                                             num_workers=0, pin_memory=True)
    data_iter = iter(dataloader)
    # build one fixed set of supports for every class, shared across all test images
    supports = get_supports(imdb, args.sup_dir, args.shot)
    for i in tqdm(range(num_images)):
        scores_all = {}
        boxes_all = {}
        box_deltas_all = {}
        pred_boxes_all = {}
        data = next(data_iter)
        with torch.no_grad():
            im_data.resize_(data[0].size()).copy_(data[0])
            im_info.resize_(data[1].size()).copy_(data[1])
            gt_boxes.resize_(data[2].size()).copy_(data[2])
            num_boxes.resize_(data[3].size()).copy_(data[3])
            # use the fixed all-class supports instead of the loader's per-image supports (data[4])
            support_ims_all.resize_(supports.size()).copy_(supports)

        det_tic = time.time()
        with torch.no_grad():
            # run the one-class model once per foreground class and collect the raw outputs
            for cls_ind, support_ims in enumerate(support_ims_all):
                support_ims = support_ims.unsqueeze(0)
                gt_class = cls_ind + 1
                rois, cls_prob, bbox_pred, \
                    rpn_loss_cls, rpn_loss_box, \
                    RCNN_loss_cls, RCNN_loss_bbox, \
                    rois_label = model(im_data, im_info, gt_boxes, num_boxes, support_ims)
                scores_all[gt_class] = cls_prob.data
                boxes_all[gt_class] = rois.data[:, :, 1:5]
                box_deltas_all[gt_class] = bbox_pred.data
        det_toc = time.time()
        detect_time = det_toc - det_tic
        misc_tic = time.time()

        # Apply bounding-box regression deltas
        if cfg.TRAIN.BBOX_NORMALIZE_TARGETS_PRECOMPUTED:
            # Optionally normalize targets by a precomputed mean and stdev
            for cls_id, box_deltas in box_deltas_all.items():
                box_deltas = box_deltas.view(-1, 4) * torch.FloatTensor(cfg.TRAIN.BBOX_NORMALIZE_STDS).cuda() \
                             + torch.FloatTensor(cfg.TRAIN.BBOX_NORMALIZE_MEANS).cuda()
                box_deltas = box_deltas.view(1, -1, 4)
                box_deltas_all[cls_id] = box_deltas
        for cls_id, boxes in boxes_all.items():
            box_deltas = box_deltas_all[cls_id]
            pred_boxes = bbox_transform_inv(boxes, box_deltas, 1)
            pred_boxes = clip_boxes(pred_boxes, im_info.data, 1)
            # re-scale boxes to the original image scale
            pred_boxes /= data[1][0][2].item()
            pred_boxes_all[cls_id] = pred_boxes.squeeze()

        # threshold, NMS, and store the detections of each class separately
        for cls_id, scores in scores_all.items():
            # (the original 1-way evaluation only kept the ground-truth class at this point)
            scores = scores.squeeze()  # drop the batch dimension before indexing the scores
            pred_boxes = pred_boxes_all[cls_id]
            inds = torch.nonzero(scores[:, 1] > thresh).view(-1)
            if inds.numel() > 0:
                cls_scores = scores[:, 1][inds]
                cls_boxes = pred_boxes[inds, :]
                cls_dets = NMS(cls_boxes, cls_scores)
                all_boxes[cls_id][i] = cls_dets.cpu().numpy()
            else:
                all_boxes[cls_id][i] = empty_array
        misc_toc = time.time()
        nms_time = misc_toc - misc_tic

        # (the optional per-image visualization / logging block guarded by args.imlog is left out here)
        sys.stdout.write('im_detect: {:d}/{:d} {:.3f}s {:.3f}s \r'
                         .format(i + 1, num_images, detect_time, nms_time))
        sys.stdout.flush()

    output_dir = os.path.join(CWD, 'inference_output', args.eval_dir)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    det_file = os.path.join(output_dir, 'detections.pkl')
    with open(det_file, 'wb') as f:
        pickle.dump(all_boxes, f, pickle.HIGHEST_PROTOCOL)
    print('Evaluating detections')
    imdb.evaluate_detections(all_boxes, output_dir)
The changes above have been made to inference.py. I have tested the performance in the 3-shot setting without fine-tuning, which is the setting I am interested in. The results are shown below:
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.025
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.050
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.022
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.018
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.031
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.038
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.138
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.275
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.293
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.151
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.300
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.440
The AR performance seems good, but the AP performance is not as good as the 1-way performance I tested before. Is this normal?
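To localize the drop, I could also compare the per-class AP of the merged results with the earlier 1-way numbers. A minimal sketch with pycocotools, assuming the merged detections have already been exported to a COCO-format results JSON (the file names below are placeholders):

from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

# Illustrative file names; the ground-truth annotations and the merged
# detection results (COCO "results" JSON) are assumed to exist already.
coco_gt = COCO('instances_val.json')
coco_dt = coco_gt.loadRes('merged_detections.json')

coco_eval = COCOeval(coco_gt, coco_dt, iouType='bbox')
for cat_id in coco_gt.getCatIds():
    coco_eval.params.catIds = [cat_id]   # evaluate one category at a time
    coco_eval.evaluate()
    coco_eval.accumulate()
    print(coco_gt.loadCats(cat_id)[0]['name'])
    coco_eval.summarize()                # per-class AP/AR breakdown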
In Tab. 1 of the paper, the proposed DANA is compared with other works under the general detection evaluation settings, and the 18.6 COCO AP is really strong performance. However, the provided evaluation code seems to contain only the 1-way evaluation mode, which is not the general setting for real applications of object detection, and the BA block and the CISA block only support one-class inference. I have tried modifying the code to run inference on each image multiple times (once per class) and merging the results to obtain multi-class predictions, like Attention-RPN, but the resulting performance is poor. Is there any reference for this question? Thanks.
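One detail I am not sure about, for what it's worth: standard Faster R-CNN test scripts limit each image to max_per_image detections across all classes after the per-class NMS, and the modified script above sets max_per_image but never applies it. A minimal sketch of that step (the function name and arguments are illustrative, assuming all_boxes[cls][img] holds (N, 5) arrays of [x1, y1, x2, y2, score]):

import numpy as np

def cap_detections_per_image(all_boxes, image_ind, num_classes, max_per_image=100):
    """Keep at most max_per_image detections per image across all foreground
    classes, as standard Faster R-CNN test scripts do after per-class NMS
    (sketch only; would be called once per test image)."""
    image_scores = np.hstack([all_boxes[j][image_ind][:, -1]
                              for j in range(1, num_classes)])
    if len(image_scores) > max_per_image:
        image_thresh = np.sort(image_scores)[-max_per_image]
        for j in range(1, num_classes):
            keep = np.where(all_boxes[j][image_ind][:, -1] >= image_thresh)[0]
            all_boxes[j][image_ind] = all_boxes[j][image_ind][keep, :]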