Open ouening opened 3 years ago
This is a good idea. However, you didn't apply the regression to the anchors quite right in your code.
I can't train a network right now, but the code below should work as a replacement for losses.py
(I borrowed the author's implementation):
import numpy as np
import torch
import torch.nn as nn
def calc_iou(a, b):
    """Return the pairwise IoU between every box in ``a`` and every box in ``b``.

    Both arguments are indexed as 2-D tensors whose first four columns are
    read as ``x1, y1, x2, y2``.

    Args:
        a: tensor of boxes, one box per row.
        b: tensor of boxes, one box per row.

    Returns:
        Tensor with one row per box in ``a`` and one column per box in ``b``
        holding the intersection-over-union of each pair.
    """
    area_a = (a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1])
    area_b = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])

    # Unsqueezing ``a`` to a column makes the min/max broadcast over all
    # (a, b) pairs.
    iw = torch.min(torch.unsqueeze(a[:, 2], dim=1), b[:, 2]) - torch.max(
        torch.unsqueeze(a[:, 0], dim=1), b[:, 0]
    )
    ih = torch.min(torch.unsqueeze(a[:, 3], dim=1), b[:, 3]) - torch.max(
        torch.unsqueeze(a[:, 1], dim=1), b[:, 1]
    )

    # Negative extents mean the boxes do not overlap along that axis.
    intersection = torch.clamp(iw, min=0) * torch.clamp(ih, min=0)

    # Clamp the union so degenerate (zero-area) pairs do not divide by zero.
    union = torch.clamp(
        torch.unsqueeze(area_a, dim=1) + area_b - intersection, min=1e-8
    )
    return intersection / union
def bbox_overlaps_giou(bboxes1, bboxes2):
    """Return the element-wise Generalized IoU of two aligned sets of boxes.

    Row ``i`` of ``bboxes1`` is compared with row ``i`` of ``bboxes2``
    (the usual broadcasting rules apply if one side has a single row).
    The first four columns of each input are read as ``x1, y1, x2, y2``.

    Args:
        bboxes1: 2-D tensor of boxes, one box per row.
        bboxes2: 2-D tensor of boxes, one box per row.

    Returns:
        1-D tensor of GIoU values clamped to ``[-1, 1]``. If either input
        has no rows, a zero tensor of shape ``(rows1, rows2)`` is returned
        on the same device and with the same dtype as ``bboxes1``.
    """
    eps = 1e-8
    rows = bboxes1.shape[0]
    cols = bboxes2.shape[0]
    if rows * cols == 0:
        # Allocate from the input so the result follows its device/dtype
        # instead of always landing on the CPU as float32.
        return bboxes1.new_zeros((rows, cols))

    area1 = (bboxes1[:, 2] - bboxes1[:, 0]) * (bboxes1[:, 3] - bboxes1[:, 1])
    area2 = (bboxes2[:, 2] - bboxes2[:, 0]) * (bboxes2[:, 3] - bboxes2[:, 1])

    # Corners of the intersection rectangle.
    inter_max_xy = torch.min(bboxes1[:, 2:], bboxes2[:, 2:])
    inter_min_xy = torch.max(bboxes1[:, :2], bboxes2[:, :2])
    # Corners of the smallest rectangle enclosing both boxes.
    out_max_xy = torch.max(bboxes1[:, 2:], bboxes2[:, 2:])
    out_min_xy = torch.min(bboxes1[:, :2], bboxes2[:, :2])

    inter = torch.clamp(inter_max_xy - inter_min_xy, min=0)
    inter_area = inter[:, 0] * inter[:, 1]
    outer = torch.clamp(out_max_xy - out_min_xy, min=0)
    outer_area = outer[:, 0] * outer[:, 1]

    # Clamp both denominators: degenerate zero-area boxes would otherwise
    # produce 0/0 = NaN and poison the loss.
    union = torch.clamp(area1 + area2 - inter_area, min=eps)
    closure = torch.clamp(outer_area, min=eps)

    # Every operation above is symmetric in its two arguments and the result
    # is 1-D, so no argument swap or transpose is needed.
    gious = inter_area / union - (closure - union) / closure
    return torch.clamp(gious, min=-1.0, max=1.0)
class BBoxRegress:
    """Apply predicted regression deltas to boxes to obtain decoded boxes.

    The deltas are de-normalised with ``mean``/``std`` and then interpreted
    as ``(dx, dy, dw, dh)``: centre offsets relative to the box size and
    log-scale factors for width and height.
    """

    def __init__(self, mean=None, std=None):
        """Store the normalisation statistics.

        Args:
            mean: four per-component offsets; defaults to zeros.
            std: four per-component scales; defaults to ones.
        """
        # Defaults are created on the CPU; __call__ moves them to wherever
        # the inputs live, so no global CUDA check is needed here.
        self.mean = torch.zeros(4, dtype=torch.float32) if mean is None else mean
        self.std = torch.ones(4, dtype=torch.float32) if std is None else std

    def __call__(self, boxes, deltas):
        """Decode ``deltas`` relative to ``boxes``.

        Both inputs are indexed as 3-D tensors whose last axis holds at
        least four components (``x1, y1, x2, y2`` for ``boxes`` and
        ``dx, dy, dw, dh`` for ``deltas``).

        Returns:
            Tensor of decoded boxes stacked as ``x1, y1, x2, y2`` along
            axis 2.
        """
        # Align the statistics with the deltas' device and dtype; relying on
        # torch.cuda.is_available() breaks when CPU tensors are used on a
        # machine that also has a GPU.
        mean = torch.as_tensor(self.mean, dtype=deltas.dtype, device=deltas.device)
        std = torch.as_tensor(self.std, dtype=deltas.dtype, device=deltas.device)

        widths = boxes[:, :, 2] - boxes[:, :, 0]
        heights = boxes[:, :, 3] - boxes[:, :, 1]
        ctr_x = boxes[:, :, 0] + 0.5 * widths
        ctr_y = boxes[:, :, 1] + 0.5 * heights

        dx = deltas[:, :, 0] * std[0] + mean[0]
        dy = deltas[:, :, 1] * std[1] + mean[1]
        dw = deltas[:, :, 2] * std[2] + mean[2]
        dh = deltas[:, :, 3] * std[3] + mean[3]

        pred_ctr_x = ctr_x + dx * widths
        pred_ctr_y = ctr_y + dy * heights
        # Width/height deltas are in log space, hence the exponential.
        pred_w = torch.exp(dw) * widths
        pred_h = torch.exp(dh) * heights

        half_w = 0.5 * pred_w
        half_h = 0.5 * pred_h
        return torch.stack(
            [
                pred_ctr_x - half_w,
                pred_ctr_y - half_h,
                pred_ctr_x + half_w,
                pred_ctr_y + half_h,
            ],
            dim=2,
        )
class FocalLoss(nn.Module):
    """Focal classification loss combined with a GIoU box-regression loss."""

    def __init__(self):
        super().__init__()
        # Built once here rather than on every forward pass.
        self.bb_transform = BBoxRegress()

    def forward(self, classifications, regressions, anchors, annotations):
        """Compute the batch-averaged classification and regression losses.

        Args:
            classifications: per-image class scores, indexed as
                ``[image, anchor, class]``; values are clamped to
                ``[1e-4, 1 - 1e-4]`` before use.
            regressions: per-image box deltas, indexed as
                ``[image, anchor, component]``.
            anchors: anchor boxes; only ``anchors[0]`` is used, read as
                ``x1, y1, x2, y2`` per row.
            annotations: per-image ground-truth rows; columns 0-3 are the
                box and column 4 is the class index, with ``-1`` marking
                rows to be ignored.

        Returns:
            Tuple ``(classification_loss, regression_loss)``, each a tensor
            of shape ``(1,)`` holding the mean over the images in the batch.
        """
        alpha = 0.25
        gamma = 2.0
        batch_size = classifications.shape[0]
        classification_losses = []
        regression_losses = []

        anchor = anchors[0, :, :]

        for j in range(batch_size):
            classification = torch.clamp(classifications[j, :, :], 1e-4, 1.0 - 1e-4)
            regression = regressions[j, :, :]

            bbox_annotation = annotations[j, :, :]
            bbox_annotation = bbox_annotation[bbox_annotation[:, 4] != -1]

            if bbox_annotation.shape[0] == 0:
                # No ground truth: every anchor is a negative for every class.
                focal_weight = (1.0 - alpha) * torch.pow(classification, gamma)
                bce = -torch.log(1.0 - classification)
                classification_losses.append((focal_weight * bce).sum())
                # Allocate the zero from the inputs so every entry that is
                # later stacked lives on the same device.
                regression_losses.append(regression.new_zeros(()))
                continue

            # num_anchors x num_annotations
            IoU = calc_iou(anchor, bbox_annotation[:, :4])
            IoU_max, IoU_argmax = torch.max(IoU, dim=1)

            # --- classification loss ---
            # -1 = ignored, 0 = negative, 1 = positive.
            targets = torch.full_like(classification, -1.0)
            targets[torch.lt(IoU_max, 0.4), :] = 0

            positive_indices = torch.ge(IoU_max, 0.5)
            num_positive_anchors = positive_indices.sum()
            assigned_annotations = bbox_annotation[IoU_argmax, :]

            targets[positive_indices, :] = 0
            targets[positive_indices, assigned_annotations[positive_indices, 4].long()] = 1

            is_positive = torch.eq(targets, 1.0)
            alpha_factor = torch.full_like(targets, alpha)
            alpha_factor = torch.where(is_positive, alpha_factor, 1.0 - alpha_factor)
            focal_weight = torch.where(is_positive, 1.0 - classification, classification)
            focal_weight = alpha_factor * torch.pow(focal_weight, gamma)

            bce = -(
                targets * torch.log(classification)
                + (1.0 - targets) * torch.log(1.0 - classification)
            )
            cls_loss = focal_weight * bce
            # Anchors left at -1 contribute nothing to the loss.
            cls_loss = torch.where(
                torch.ne(targets, -1.0), cls_loss, torch.zeros_like(cls_loss)
            )
            classification_losses.append(
                cls_loss.sum() / torch.clamp(num_positive_anchors.float(), min=1.0)
            )

            # --- regression loss ---
            if num_positive_anchors > 0:
                # Decode the predicted deltas into boxes, then score them
                # against their assigned ground truth with GIoU.
                transformed_bbox = self.bb_transform(
                    anchor[positive_indices, :].unsqueeze(0),
                    regression[positive_indices, :].unsqueeze(0),
                )[0, :, :]
                giou_loss = 1 - bbox_overlaps_giou(
                    transformed_bbox, assigned_annotations[positive_indices, :4]
                )
                regression_losses.append(giou_loss.mean())
            else:
                regression_losses.append(regression.new_zeros(()))

        return (
            torch.stack(classification_losses).mean(dim=0, keepdim=True),
            torch.stack(regression_losses).mean(dim=0, keepdim=True),
        )
@yhenon Thank you! I'll try it.
Tell me how it goes, I might be interested in merging this
Tell me how it goes, I might be interested in merging this
The code works, but the result is worse than smooth l1 (same configuration except for reg loss).
Interesting, it might be necessary to change the relative weighting of the regression vs classification. Let me look at that
Thanks!
Hi @yhenon did you check the appropriate ratio of classification vs regression loss?
Any update here ?
Hi yhenon, thanks for your open-source RetinaNet code. In my experiments, the result is much better than https://github.com/fizyr/keras-retinanet. I sincerely hope that you can implement other loss types for the classification loss or the regression loss. I tried GIoU loss but failed. My code (a modified `losses.py`) is below. I would be grateful if you can help, thank you!