Open AlbertoSabater opened 3 years ago
I have finally managed to extract the logits along with the detections. To do so, I had to perform NMS on all the detections at once, without splitting them by class. Since your method regresses one bounding box per class for each detection, I had to choose the class with the highest confidence before selecting its associated bounding box. This is the final code I modified, in case someone is interested:
def filter_results_v2(self, boxlist, num_classes):
    """Returns bounding-box detection results by thresholding on scores and
    applying non-maximum suppression (NMS) once over all classes.
    """
    boxes = boxlist.bbox.reshape(-1, num_classes * 4)
    scores = boxlist.get_field("scores").reshape(-1, num_classes)
    device = scores.device

    # Pick the best non-background class per detection and threshold on its score
    best_scores, best_class_inds = torch.max(scores[:, 1:], 1)
    # torch.max ran on scores[:, 1:], so shift the indices back past the
    # background column to address the matching box slice and label
    best_class_inds = best_class_inds + 1
    inds_v2 = best_scores > self.score_thresh
    best_scores = best_scores[inds_v2]
    # Full per-class score vector of each kept detection (background dropped)
    best_logits = scores[inds_v2, 1:]
    best_class_inds = best_class_inds[inds_v2]
    boxes = boxes[inds_v2]

    # Keep only the box regressed for the best class of each detection
    if len(boxes) > 0:
        boxes = torch.stack(
            [boxes[i, c * 4 : (c + 1) * 4] for i, c in enumerate(best_class_inds.tolist())]
        )
    else:
        boxes = torch.zeros((0, 4), dtype=torch.float32, device=device)

    boxlist_for_class = BoxList(boxes, boxlist.size, mode="xyxy")
    boxlist_for_class.add_field("scores", best_scores)
    boxlist_for_class.add_field("labels", best_class_inds)
    boxlist_for_class.add_field("logits", best_logits)
    # Single NMS pass over all detections, regardless of class
    result = boxlist_nms(boxlist_for_class, self.nms)
    number_of_detections = len(result)

    # Limit to max_per_image detections **over all classes**
    if number_of_detections > self.detections_per_img > 0:
        cls_scores = result.get_field("scores")
        image_thresh, _ = torch.kthvalue(
            cls_scores.cpu(), number_of_detections - self.detections_per_img + 1
        )
        keep = cls_scores >= image_thresh.item()
        keep = torch.nonzero(keep).squeeze(1)
        result = result[keep]
    return result
Code below belongs to this file.
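For anyone who wants to try the idea outside the repository, here is a minimal sketch of the same steps (best class per detection, one class-agnostic NMS pass, full score vector kept for each survivor) in plain PyTorch. It is an illustration under assumptions, not the library's implementation: the function names `pairwise_iou`, `class_agnostic_nms` and `detections_with_score_vectors`, the default thresholds, and the toy tensors are all hypothetical, and the greedy NMS loop stands in for `boxlist_nms`.

```python
import torch


def pairwise_iou(box, boxes):
    """IoU between one xyxy box of shape (4,) and xyxy boxes of shape (N, 4)."""
    x1 = torch.maximum(box[0], boxes[:, 0])
    y1 = torch.maximum(box[1], boxes[:, 1])
    x2 = torch.minimum(box[2], boxes[:, 2])
    y2 = torch.minimum(box[3], boxes[:, 3])
    inter = (x2 - x1).clamp(min=0) * (y2 - y1).clamp(min=0)
    area = (box[2] - box[0]) * (box[3] - box[1])
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    return inter / (area + areas - inter)


def class_agnostic_nms(boxes, scores, iou_thresh):
    """Greedy NMS over all boxes at once; returns kept indices, best score first."""
    order = torch.argsort(scores, descending=True)
    keep = []
    while order.numel() > 0:
        i = order[0]
        keep.append(i.item())
        rest = order[1:]
        if rest.numel() == 0:
            break
        # Drop every remaining box that overlaps the kept one too much
        order = rest[pairwise_iou(boxes[i], boxes[rest]) <= iou_thresh]
    return torch.tensor(keep, dtype=torch.long, device=boxes.device)


def detections_with_score_vectors(boxes, scores, score_thresh=0.05,
                                  iou_thresh=0.5, max_dets=100):
    """boxes: (N, C * 4) per-class boxes; scores: (N, C), column 0 = background.

    Returns boxes (M, 4), labels (M,), best scores (M,), score vectors (M, C - 1).
    """
    best_scores, labels = scores[:, 1:].max(dim=1)
    labels = labels + 1  # shift past the background column
    keep = best_scores > score_thresh
    boxes, scores = boxes[keep], scores[keep]
    best_scores, labels = best_scores[keep], labels[keep]
    # Gather the 4 coordinates regressed for the best class of each detection
    cols = labels[:, None] * 4 + torch.arange(4, device=boxes.device)
    best_boxes = boxes.gather(1, cols)
    # NMS output is sorted by score, so slicing keeps the top max_dets
    kept = class_agnostic_nms(best_boxes, best_scores, iou_thresh)[:max_dets]
    return best_boxes[kept], labels[kept], best_scores[kept], scores[kept, 1:]


# Toy input: 3 detections, background + 2 classes
toy_scores = torch.tensor([[0.1, 0.8, 0.1],
                           [0.1, 0.7, 0.2],
                           [0.2, 0.2, 0.6]])
toy_boxes = torch.tensor([
    [0., 0., 0., 0., 0., 0., 10., 10., 0., 0., 5., 5.],
    [0., 0., 0., 0., 1., 1., 10., 10., 50., 50., 60., 60.],
    [0., 0., 0., 0., 0., 0., 1., 1., 20., 20., 30., 30.],
])
out_boxes, out_labels, out_scores, out_vectors = detections_with_score_vectors(
    toy_boxes, toy_scores
)
```

In the toy input the first two detections pick class 1 and overlap with IoU 0.81, so the lower-scoring one is suppressed, while the third (class 2) survives. `out_vectors` holds the full per-class score vector for each surviving box, which is the information the original per-class NMS loop throws away.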
Hi! When running object detection inference, I would like to return the full class confidence vector for each bbox, not only the best score/class. Can you give me some pointers on which files I should modify to get them?
Thank you in advance, Alberto