**Notice**

There are several common situations in reimplementation issues:

1. Reimplement a model in the model zoo using the provided configs
2. Reimplement a model in the model zoo on another dataset (e.g., custom datasets)
3. Reimplement a custom model where all the components are implemented in MMDetection
4. Reimplement a custom model with new modules implemented by yourself

What to do depends on the case.

For cases 1 & 3, please follow the steps in the following sections so that we can quickly identify the issue.

For cases 2 & 4, please understand that we cannot offer much help, because we usually do not know the full code, and users are responsible for the code they write.

One suggestion for cases 2 & 4: first check whether the bug lies in the self-implemented code or in the original code. For example, first make sure that the same model runs well on supported datasets. If you still need help, please describe in the issue what you have done and what you obtained, follow the steps in the following sections, and be as clear as possible so that we can better help you.
**Checklist**

1. I have searched related issues but cannot get the expected help.
2. The issue has not been fixed in the latest version.
**Describe the issue**

I applied QFL (Quality Focal Loss) to ATSSHead, but the loss values always go to 0.
        self.sampling = False
        if self.train_cfg:
            self.assigner = build_assigner(self.train_cfg.assigner)
            # SSD sampling=False so use PseudoSampler
            sampler_cfg = dict(type='PseudoSampler')
            self.sampler = build_sampler(sampler_cfg, context=self)
        self.num_dcn = num_dcn
        self.with_attn = with_attn

    def _init_layers(self):
        """Initialize layers of the head."""
        self.relu = nn.ReLU(inplace=True)
        self.cls_convs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()
        for i in range(self.stacked_convs):
            chn = self.in_channels if i == 0 else self.feat_channels
            self.cls_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=dict(
                        type='DCN',
                        deform_groups=1
                    ) if i == 0 else self.conv_cfg,
                    norm_cfg=self.norm_cfg))
            self.reg_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=dict(
                        type='DCN',
                        deform_groups=1
                    ) if i == 0 else self.conv_cfg,
                    norm_cfg=self.norm_cfg))
        assert self.num_anchors == 1, 'anchor free version'
        self.se_cls = nn.Conv2d(
            self.feat_channels, self.cls_out_channels, 3, padding=1)
        self.se_reg = nn.Conv2d(
            self.feat_channels, 4, 3, padding=1)
        self.scales = nn.ModuleList(
            [Scale(1.0) for _ in self.prior_generator.strides])

    def forward(self, feats):
        """Forward features from the upstream network.

        Args:
            feats (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            tuple: Usually a tuple of classification scores and bbox prediction
                cls_scores (list[Tensor]): Classification and quality (IoU)
                    joint scores for all scale levels, each is a 4D-tensor,
                    the channel number is num_classes.
                bbox_preds (list[Tensor]): Box predictions for all scale
                    levels, each is a 4D-tensor, the channel number is 4
                    (the output of `se_reg`).
        """
        return multi_apply(self.forward_single, feats, self.scales)

    def forward_single(self, x, scale):
        """Forward feature of a single scale level.

        Args:
            x (Tensor): Features of a single scale level.
            scale (:obj: `mmcv.cnn.Scale`): Learnable scale module to resize
                the bbox prediction.

        Returns:
            tuple:
                cls_score (Tensor): Cls and quality joint scores for a single
                    scale level, the channel number is num_classes.
                bbox_pred (Tensor): Box prediction for a single scale level,
                    the channel number is 4 (the output of `se_reg`).
        """
        cls_feat = x
        reg_feat = x
        for cls_conv in self.cls_convs:
            cls_feat = cls_conv(cls_feat)
        for reg_conv in self.reg_convs:
            reg_feat = reg_conv(reg_feat)
        cls_score = self.se_cls(cls_feat)
        bbox_pred = scale(self.se_reg(reg_feat)).float()
        return cls_score, bbox_pred

    def anchor_center(self, anchors):
        """Get anchor centers from anchors.

        Args:
            anchors (Tensor): Anchor list with shape (N, 4), "xyxy" format.

        Returns:
            Tensor: Anchor centers with shape (N, 2), "xy" format.
        """
        anchors_cx = (anchors[..., 2] + anchors[..., 0]) / 2
        anchors_cy = (anchors[..., 3] + anchors[..., 1]) / 2
        return torch.stack([anchors_cx, anchors_cy], dim=-1)

    def loss_single(self, anchors, cls_score, bbox_pred, labels, label_weights,
                    bbox_targets, stride, num_total_samples):
        """Compute loss of a single scale level.

        Args:
            anchors (Tensor): Box reference for each scale level with shape
                (N, num_total_anchors, 4).
            cls_score (Tensor): Cls and quality joint scores for each scale
                level has shape (N, num_classes, H, W).
            bbox_pred (Tensor): Box prediction for each scale level with
                shape (N, 4, H, W).
            labels (Tensor): Labels of each anchors with shape
                (N, num_total_anchors).
            label_weights (Tensor): Label weights of each anchor with shape
                (N, num_total_anchors)
            bbox_targets (Tensor): BBox regression targets of each anchor
                with shape (N, num_total_anchors, 4).
            stride (tuple): Stride in this scale level.
            num_total_samples (int): Number of positive samples that is
                reduced over all GPUs.

        Returns:
            tuple: loss_cls, loss_bbox, and the sum of the weight targets
                (used later as the averaging factor for loss_bbox).
        """
        assert stride[0] == stride[1], 'h stride is not equal to w stride!'
        anchors = anchors.reshape(-1, 4)
        cls_score = cls_score.permute(0, 2, 3,
                                      1).reshape(-1, self.cls_out_channels)
        bbox_pred = bbox_pred.permute(0, 2, 3,
                                      1).reshape(-1, 4)
        bbox_targets = bbox_targets.reshape(-1, 4)
        labels = labels.reshape(-1)
        label_weights = label_weights.reshape(-1)

        # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
        bg_class_ind = self.num_classes
        pos_inds = ((labels >= 0)
                    & (labels < bg_class_ind)).nonzero().squeeze(1)
        score = label_weights.new_zeros(labels.shape)

        if len(pos_inds) > 0:
            pos_bbox_targets = bbox_targets[pos_inds]
            pos_bbox_pred = bbox_pred[pos_inds]
            pos_anchors = anchors[pos_inds]

            weight_targets = cls_score.detach().sigmoid()
            weight_targets = weight_targets.max(dim=1)[0][pos_inds]
            pos_decode_bbox_pred = self.bbox_coder.decode(
                pos_anchors, pos_bbox_pred)
            score[pos_inds] = bbox_overlaps(
                pos_decode_bbox_pred.detach(),
                pos_bbox_targets,
                is_aligned=True)

            # regression loss
            loss_bbox = self.loss_bbox(
                pos_decode_bbox_pred,
                pos_bbox_targets,
                weight=weight_targets,
                avg_factor=1.0)
        else:
            loss_bbox = bbox_pred.sum() * 0
            weight_targets = bbox_pred.new_tensor(0)

        # cls (qfl) loss
        loss_cls = self.loss_cls(
            cls_score, (labels, score),
            weight=label_weights,
            avg_factor=num_total_samples)

        return loss_cls, loss_bbox, weight_targets.sum()

    @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             gt_bboxes_ignore=None):
        """Compute losses of the head.

        Args:
            cls_scores (list[Tensor]): Cls and quality scores for each scale
                level has shape (N, num_classes, H, W).
            bbox_preds (list[Tensor]): Box predictions for each scale level
                with shape (N, 4, H, W).
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): class indices corresponding to each box
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes_ignore (list[Tensor] | None): specify which bounding
                boxes can be ignored when computing the loss.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == self.prior_generator.num_levels

        device = cls_scores[0].device
        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, img_metas, device=device)
        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1

        cls_reg_targets = self.get_targets(
            anchor_list,
            valid_flag_list,
            gt_bboxes,
            img_metas,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            label_channels=label_channels)
        if cls_reg_targets is None:
            return None

        (anchor_list, labels_list, label_weights_list, bbox_targets_list,
         bbox_weights_list, num_total_pos, num_total_neg) = cls_reg_targets

        num_total_samples = reduce_mean(
            torch.tensor(num_total_pos, dtype=torch.float,
                         device=device)).item()
        num_total_samples = max(num_total_samples, 1.0)

        losses_cls, losses_bbox, \
            avg_factor = multi_apply(
                self.loss_single,
                anchor_list,
                cls_scores,
                bbox_preds,
                labels_list,
                label_weights_list,
                bbox_targets_list,
                self.prior_generator.strides,
                num_total_samples=num_total_samples)

        avg_factor = sum(avg_factor)
        avg_factor = reduce_mean(avg_factor).clamp_(min=1).item()
        losses_bbox = list(map(lambda x: x / avg_factor, losses_bbox))
        return dict(
            loss_cls=losses_cls, loss_bbox=losses_bbox)

    def _get_bboxes_single(self,
                           cls_score_list,
                           bbox_pred_list,
                           score_factor_list,
                           mlvl_priors,
                           img_meta,
                           cfg,
                           rescale=False,
                           with_nms=True,
                           **kwargs):
        """Transform outputs of a single image into bbox predictions.

        Args:
            cls_score_list (list[Tensor]): Box scores from all scale
                levels of a single image, each item has shape
                (num_priors * num_classes, H, W).
            bbox_pred_list (list[Tensor]): Box energies / deltas from
                all scale levels of a single image, each item has shape
                (num_priors * 4, H, W).
            score_factor_list (list[Tensor]): Score factor from all scale
                levels of a single image. GFL head does not need this value.
            mlvl_priors (list[Tensor]): Each element in the list is
                the priors of a single level in feature pyramid, has shape
                (num_priors, 4).
            img_meta (dict): Image meta info.
            cfg (mmcv.Config): Test / postprocessing configuration,
                if None, test_cfg would be used.
            rescale (bool): If True, return boxes in original image space.
                Default: False.
            with_nms (bool): If True, do nms before return boxes.
                Default: True.

        Returns:
            tuple[Tensor]: Results of detected bboxes and labels. If with_nms
                is False and mlvl_score_factor is None, return mlvl_bboxes and
                mlvl_scores, else return mlvl_bboxes, mlvl_scores and
                mlvl_score_factor. Usually with_nms is False is used for aug
                test. If with_nms is True, then return the following format

                - det_bboxes (Tensor): Predicted bboxes with shape \
                    [num_bboxes, 5], where the first 4 columns are bounding \
                    box positions (tl_x, tl_y, br_x, br_y) and the 5-th \
                    column are scores between 0 and 1.
                - det_labels (Tensor): Predicted labels of the corresponding \
                    box with shape [num_bboxes].
        """
        cfg = self.test_cfg if cfg is None else cfg
        img_shape = img_meta['img_shape']
        nms_pre = cfg.get('nms_pre', -1)

        mlvl_bboxes = []
        mlvl_scores = []
        mlvl_labels = []
        for level_idx, (cls_score, bbox_pred, stride, priors) in enumerate(
                zip(cls_score_list, bbox_pred_list,
                    self.prior_generator.strides, mlvl_priors)):
            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
            assert stride[0] == stride[1]

            bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
            scores = cls_score.permute(1, 2, 0).reshape(
                -1, self.cls_out_channels).sigmoid()

            # After https://github.com/open-mmlab/mmdetection/pull/6268/,
            # this operation keeps fewer bboxes under the same `nms_pre`.
            # There is no difference in performance for most models. If you
            # find a slight drop in performance, you can set a larger
            # `nms_pre` than before.
            results = filter_scores_and_topk(
                scores, cfg.score_thr, nms_pre,
                dict(bbox_pred=bbox_pred, priors=priors))
            scores, labels, _, filtered_results = results

            bbox_pred = filtered_results['bbox_pred']
            priors = filtered_results['priors']

            bboxes = self.bbox_coder.decode(
                self.anchor_center(priors), bbox_pred, max_shape=img_shape)
            mlvl_bboxes.append(bboxes)
            mlvl_scores.append(scores)
            mlvl_labels.append(labels)

        return self._bbox_post_process(
            mlvl_scores,
            mlvl_labels,
            mlvl_bboxes,
            img_meta['scale_factor'],
            cfg,
            rescale=rescale,
            with_nms=with_nms)

    def get_targets(self,
                    anchor_list,
                    valid_flag_list,
                    gt_bboxes_list,
                    img_metas,
                    gt_bboxes_ignore_list=None,
                    gt_labels_list=None,
                    label_channels=1,
                    unmap_outputs=True):
        """Get targets for GFL head.

        This method is almost the same as `AnchorHead.get_targets()`. Besides
        returning the targets as the parent method does, it also returns the
        anchors as the first element of the returned tuple.
        """
        num_imgs = len(img_metas)
        assert len(anchor_list) == len(valid_flag_list) == num_imgs

        # anchor number of multi levels
        num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
        num_level_anchors_list = [num_level_anchors] * num_imgs

        # concat all level anchors and flags to a single tensor
        for i in range(num_imgs):
            assert len(anchor_list[i]) == len(valid_flag_list[i])
            anchor_list[i] = torch.cat(anchor_list[i])
            valid_flag_list[i] = torch.cat(valid_flag_list[i])

        # compute targets for each image
        if gt_bboxes_ignore_list is None:
            gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
        if gt_labels_list is None:
            gt_labels_list = [None for _ in range(num_imgs)]
        (all_anchors, all_labels, all_label_weights, all_bbox_targets,
         all_bbox_weights, pos_inds_list, neg_inds_list) = multi_apply(
             self._get_target_single,
             anchor_list,
             valid_flag_list,
             num_level_anchors_list,
             gt_bboxes_list,
             gt_bboxes_ignore_list,
             gt_labels_list,
             img_metas,
             label_channels=label_channels,
             unmap_outputs=unmap_outputs)
        # no valid anchors
        if any([labels is None for labels in all_labels]):
            return None
        # sampled anchors of all images
        num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
        num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
        # split targets to a list w.r.t. multiple levels
        anchors_list = images_to_levels(all_anchors, num_level_anchors)
        labels_list = images_to_levels(all_labels, num_level_anchors)
        label_weights_list = images_to_levels(all_label_weights,
                                              num_level_anchors)
        bbox_targets_list = images_to_levels(all_bbox_targets,
                                             num_level_anchors)
        bbox_weights_list = images_to_levels(all_bbox_weights,
                                             num_level_anchors)
        return (anchors_list, labels_list, label_weights_list,
                bbox_targets_list, bbox_weights_list, num_total_pos,
                num_total_neg)

    def _get_target_single(self,
                           flat_anchors,
                           valid_flags,
                           num_level_anchors,
                           gt_bboxes,
                           gt_bboxes_ignore,
                           gt_labels,
                           img_meta,
                           label_channels=1,
                           unmap_outputs=True):
        """Compute regression, classification targets for anchors in a single
        image.

        Args:
            flat_anchors (Tensor): Multi-level anchors of the image, which are
                concatenated into a single tensor of shape (num_anchors, 4)
            valid_flags (Tensor): Multi level valid flags of the image,
                which are concatenated into a single tensor of
                    shape (num_anchors,).
            num_level_anchors (Tensor): Number of anchors of each scale level.
            gt_bboxes (Tensor): Ground truth bboxes of the image,
                shape (num_gts, 4).
            gt_bboxes_ignore (Tensor): Ground truth bboxes to be
                ignored, shape (num_ignored_gts, 4).
            gt_labels (Tensor): Ground truth labels of each box,
                shape (num_gts,).
            img_meta (dict): Meta info of the image.
            label_channels (int): Channel of label.
            unmap_outputs (bool): Whether to map outputs back to the original
                set of anchors.

        Returns:
            tuple: N is the number of total anchors in the image.
                anchors (Tensor): All anchors in the image with shape (N, 4).
                labels (Tensor): Labels of all anchors in the image with shape
                    (N,).
                label_weights (Tensor): Label weights of all anchor in the
                    image with shape (N,).
                bbox_targets (Tensor): BBox targets of all anchors in the
                    image with shape (N, 4).
                bbox_weights (Tensor): BBox weights of all anchors in the
                    image with shape (N, 4).
                pos_inds (Tensor): Indices of positive anchor with shape
                    (num_pos,).
                neg_inds (Tensor): Indices of negative anchor with shape
                    (num_neg,).
        """
        inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
                                           img_meta['img_shape'][:2],
                                           self.train_cfg.allowed_border)
        if not inside_flags.any():
            return (None, ) * 7
        # assign gt and sample anchors
        anchors = flat_anchors[inside_flags, :]

        num_level_anchors_inside = self.get_num_level_anchors_inside(
            num_level_anchors, inside_flags)
        assign_result = self.assigner.assign(anchors, num_level_anchors_inside,
                                             gt_bboxes, gt_bboxes_ignore,
                                             gt_labels)

        sampling_result = self.sampler.sample(assign_result, anchors,
                                              gt_bboxes)

        num_valid_anchors = anchors.shape[0]
        bbox_targets = torch.zeros_like(anchors)
        bbox_weights = torch.zeros_like(anchors)
        labels = anchors.new_full((num_valid_anchors, ),
                                  self.num_classes,
                                  dtype=torch.long)
        label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float)

        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds
        if len(pos_inds) > 0:
            pos_bbox_targets = sampling_result.pos_gt_bboxes
            bbox_targets[pos_inds, :] = pos_bbox_targets
            bbox_weights[pos_inds, :] = 1.0
            if gt_labels is None:
                # Only rpn gives gt_labels as None
                # Foreground is the first class
                labels[pos_inds] = 0
            else:
                labels[pos_inds] = gt_labels[
                    sampling_result.pos_assigned_gt_inds]
            if self.train_cfg.pos_weight <= 0:
                label_weights[pos_inds] = 1.0
            else:
                label_weights[pos_inds] = self.train_cfg.pos_weight
        if len(neg_inds) > 0:
            label_weights[neg_inds] = 1.0

        # map up to original set of anchors
        if unmap_outputs:
            num_total_anchors = flat_anchors.size(0)
            anchors = unmap(anchors, num_total_anchors, inside_flags)
            labels = unmap(
                labels, num_total_anchors, inside_flags, fill=self.num_classes)
            label_weights = unmap(label_weights, num_total_anchors,
                                  inside_flags)
            bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
            bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)

        return (anchors, labels, label_weights, bbox_targets, bbox_weights,
                pos_inds, neg_inds)

    def get_num_level_anchors_inside(self, num_level_anchors, inside_flags):
        split_inside_flags = torch.split(inside_flags, num_level_anchors)
        num_level_anchors_inside = [
            int(flags.sum()) for flags in split_inside_flags
        ]
        return num_level_anchors_inside

Basically, I copied the code from gfl_head.py and changed the parts that use DistributionFocalLoss, specifically `_init_layers`, `forward_single`, `loss_single`, and the `bbox_coder` param.
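
One detail in the code above that may be worth checking: `_get_bboxes_single` decodes from `self.anchor_center(priors)`, while `loss_single` passes `pos_anchors` (four columns) straight to `self.bbox_coder.decode`. The sketch below is a framework-free illustration of how much the reference point matters. It assumes a distance-based point coder (left, top, right, bottom distances from a reference point, the scheme GFL uses); the `bbox_coder` actually configured is not shown in this issue, and `anchor_center` and `distance_to_box` here are hypothetical stand-ins written for this example, not the MMDetection functions.

```python
def anchor_center(anchor):
    """Center (cx, cy) of one anchor given in xyxy format."""
    x1, y1, x2, y2 = anchor
    return ((x1 + x2) / 2, (y1 + y2) / 2)


def distance_to_box(point, distances):
    """Decode (left, top, right, bottom) distances around a reference point
    into an xyxy box. Hypothetical helper, for illustration only."""
    px, py = point
    left, top, right, bottom = distances
    return (px - left, py - top, px + right, py + bottom)


anchor = (32.0, 32.0, 96.0, 96.0)      # xyxy anchor
distances = (10.0, 10.0, 10.0, 10.0)   # same prediction in both cases

# Reference point = anchor center, as in `_get_bboxes_single`.
box_from_center = distance_to_box(anchor_center(anchor), distances)
# Reference point = first two anchor columns (x1, y1), which is what a
# point-based decode would read if it were handed the raw anchor.
box_from_corner = distance_to_box(anchor[:2], distances)

print(box_from_center)
print(box_from_corner)
```

Under that assumption, the same prediction decodes to two boxes that do not overlap at all, so the training-time IoU targets could differ a lot from what inference sees.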
4. What dataset did you use? Synthetic Fruit dataset: https://public.roboflow.com/object-detection/synthetic-fruit/1
**Environment**
1. Please run `python mmdet/utils/collect_env.py` to collect necessary environment information and paste it here.
sys.platform: win32
Python: 3.8.12 (default, Oct 12 2021, 03:01:40) [MSC v.1916 64 bit (AMD64)]
CUDA available: True
GPU 0: NVIDIA GeForce GTX 1050
CUDA_HOME: C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.0
NVCC: Not Available
GCC: n/a
PyTorch: 1.11.0
PyTorch compiling details: PyTorch built with:
C++ Version: 199711
MSVC 192829337
Intel(R) Math Kernel Library Version 2020.0.2 Product Build 20200624 for Intel(R) 64 architecture applications
2. You may add additional information that may be helpful for locating the problem, such as
1. How you installed PyTorch \[e.g., pip, conda, source\]: conda
2. Other environment variables that may be related (such as `$PATH`, `$LD_LIBRARY_PATH`, `$PYTHONPATH`, etc.)
**Results**
If applicable, paste the related results here, e.g., what you expect and what you get.
Result from my implementation of QFL on ATSSHead
![image](https://user-images.githubusercontent.com/69593462/173114470-3fdf59e3-c000-4955-bacf-e94e116d9f1e.png)
Result from GFLHead which includes QFL and DFL
![image](https://user-images.githubusercontent.com/69593462/173116010-48908d7a-846c-4d85-a8f5-cb5e08e42048.png)
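
As a reference when comparing the two curves, here is a minimal sketch of the QFL arithmetic for a single logit. It is written from the QFL formula as I understand it (binary cross-entropy against a soft IoU target, scaled by `|target - sigmoid(logit)| ** beta`) and is not copied from MMDetection; `sigmoid` and `qfl_element` are hypothetical names for this example. It shows one way the classification loss can approach 0: if the IoU target of a positive location is 0, a strongly negative logit makes the loss vanish.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def qfl_element(logit, target, beta=2.0):
    """QFL for one logit: BCE against a soft target in [0, 1],
    scaled by |target - sigmoid(logit)| ** beta."""
    p = sigmoid(logit)
    bce = -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))
    return bce * abs(target - p) ** beta


# Positive location whose IoU target is 0 (decoded box misses the GT box).
loss_zero_target = qfl_element(logit=-10.0, target=0.0)
# Same logit, but the decoded box overlaps the GT box with IoU 0.8.
loss_good_target = qfl_element(logit=-10.0, target=0.8)

print(loss_zero_target, loss_good_target)
```

With a nonzero IoU target the same logit is penalized heavily, so a classification loss that stays at 0 suggests checking whether `score[pos_inds]` in `loss_single` ever becomes positive.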
**Issue fix**
If you have already identified the reason, you can provide the information here. If you are willing to create a PR to fix it, please also leave a comment here and that would be much appreciated!
**Reproduction**

# The pasted snippet omits the next four imports; the methods of this head
# use torch, nn, ConvModule, Scale and force_fp32 (assumed to match gfl_head.py).
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, Scale
from mmcv.runner import force_fp32

from mmdet.core import (anchor_inside_flags, bbox_overlaps, build_assigner,
                        build_sampler, images_to_levels, multi_apply,
                        reduce_mean, unmap)
from mmdet.core.utils import filter_scores_and_topk
from ..builder import HEADS, build_loss
from .anchor_head import AnchorHead


@HEADS.register_module()
class SEHead(AnchorHead):

    def __init__(self,
                 num_classes,
                 in_channels,
                 stacked_convs=4,
                 num_dcn=1,
                 with_attn=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='GN', num_groups=32, requires_grad=True),
                 **kwargs):
        self.stacked_convs = stacked_convs
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        super(SEHead, self).__init__(num_classes, in_channels, **kwargs)
TorchVision: 0.12.0
OpenCV: 4.5.5
MMCV: 1.4.7
MMCV Compiler: MSVC 192930140
MMCV CUDA Compiler: 11.3
MMDetection: 2.24.1+157623a