WongKinYiu / YOLO

An MIT rewrite of YOLOv9
MIT License

BoxMatcher fails when the target is empty #87

Closed: Abdul-Mukit closed this issue 3 weeks ago

Abdul-Mukit commented 2 months ago

Describe the bug

I encountered a case where the batch size was 1 (the last image of the epoch) and, because of the augmentations, the target was empty, i.e. of shape (1, 0, 5) = [batch_size, num_targets, 5]. That caused this line to fail: https://github.com/WongKinYiu/YOLO/blob/8228669808a626fc5f9c233fdb35550b5e041fae/yolo/utils/bounding_box_utils.py#L221 The target_matrix there had the shape [1, 0, 8400].
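A minimal reproduction of why an empty target breaks matching. It assumes (not confirmed from the linked line) that the failing operation is a reduction such as `argmax` over the target dimension (dim=1) of `target_matrix`; any such reduction needs at least one element along that dimension. `pick_best_target_per_anchor` is a hypothetical name used only for illustration.

```python
import torch

# Shapes from the report: batch of 1, zero targets, 8400 anchors.
batch_size, num_targets, num_anchors = 1, 0, 8400
target_matrix = torch.zeros(batch_size, num_targets, num_anchors)


def pick_best_target_per_anchor(matrix: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for a per-anchor reduction over the target dim (dim=1)."""
    return matrix.argmax(dim=1)


def reduction_fails_on_empty(matrix: torch.Tensor) -> bool:
    """Return True if reducing over the target dimension raises."""
    try:
        pick_best_target_per_anchor(matrix)
    except (RuntimeError, IndexError) as err:
        # Depending on the PyTorch version this surfaces as RuntimeError or IndexError.
        print(f"{type(err).__name__}: {err}")
        return True
    return False


print(target_matrix.shape)                      # torch.Size([1, 0, 8400])
print(reduction_fails_on_empty(target_matrix))  # True
```

With even one target (shape [1, 1, 8400]) the same reduction succeeds, which is why the failure only appears on batches whose labels were all augmented away.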

Could you suggest what the solution should be? I found this "negative batch" handling in the PPYOLOE_pytorch TAL assigner: https://github.com/Nioolek/PPYOLOE_pytorch/blob/41c9928124bf705cbc7eb37fdcbf34ebf85bb456/ppyoloe/assigner/tal_assigner.py#L76

```python
# negative batch
if num_max_boxes == 0:
    assigned_labels = torch.full([batch_size, num_anchors], bg_index)
    assigned_bboxes = torch.zeros([batch_size, num_anchors, 4])
    assigned_scores = torch.zeros(
        [batch_size, num_anchors, num_classes])
    return assigned_labels, assigned_bboxes, assigned_scores
```
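The same "negative batch" branch can be sketched as a standalone function to see what it produces. `assign_negative_batch` is a hypothetical name, `bg_index = num_classes` (80) is an assumed convention, and the explicit `dtype`/`device` arguments are additions so the tensors can be placed next to the predictions; the quoted snippet allocates on the default device.

```python
from typing import Tuple

import torch


def assign_negative_batch(
    batch_size: int,
    num_anchors: int,
    num_classes: int,
    bg_index: int,
    device: torch.device = torch.device("cpu"),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Hypothetical helper: every anchor gets the background label,
    a zero box and zero class scores, so the batch has no positives."""
    assigned_labels = torch.full(
        (batch_size, num_anchors), bg_index, dtype=torch.long, device=device
    )
    assigned_bboxes = torch.zeros((batch_size, num_anchors, 4), device=device)
    assigned_scores = torch.zeros((batch_size, num_anchors, num_classes), device=device)
    return assigned_labels, assigned_bboxes, assigned_scores


# Assumed sizes: 80 classes, 8400 anchors, background label = num_classes.
labels, bboxes, scores = assign_negative_batch(1, 8400, 80, bg_index=80)
print(labels.shape, bboxes.shape, scores.shape)
# torch.Size([1, 8400]) torch.Size([1, 8400, 4]) torch.Size([1, 8400, 80])
```

The key property is that the output shapes match the normal (non-empty) path, so the caller does not need a special case.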

Should we add something like that to BoxMatcher.__call__()? https://github.com/WongKinYiu/YOLO/blob/8228669808a626fc5f9c233fdb35550b5e041fae/yolo/utils/bounding_box_utils.py#L224

Please let me know what you think the right fix is. Thank you again for your time, @henrytsui000.

Abdul-Mukit commented 2 months ago

I tried the following change in BoxMatcher.__call__() and it worked. I'll open a PR shortly and would appreciate your feedback. Thank you.

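```python
from typing import Tuple
import torch
from torch import Tensor


class BoxMatcher:
    # ... __init__ and the other methods are unchanged ...
    def __call__(self, target: Tensor, predict: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
        predict_cls, predict_bbox = predict
        n_targets = target.shape[1]
        if n_targets == 0:
            # No ground-truth boxes in this batch: return all-zero targets and mark every anchor invalid.
            device = predict_bbox.device
            align_cls = torch.zeros_like(predict_cls, device=device)
            align_bbox = torch.zeros_like(predict_bbox, device=device)
            valid_mask = torch.zeros(predict_cls.shape[:2], dtype=torch.bool, device=device)
            return torch.cat([align_cls, align_bbox], dim=-1), valid_mask
        # ... the existing matching logic continues here ...
```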
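To sanity-check the early exit without the rest of the repository, the branch can be copied into a standalone function. `empty_match` is a hypothetical name; the 8400 anchors and the 5-column target come from the report, while 80 classes is an assumed value.

```python
from typing import Tuple

import torch
from torch import Tensor


def empty_match(target: Tensor, predict_cls: Tensor, predict_bbox: Tensor) -> Tuple[Tensor, Tensor]:
    """Hypothetical standalone copy of the proposed early-exit branch."""
    if target.shape[1] != 0:
        raise ValueError("only meant for batches with no targets")
    align_cls = torch.zeros_like(predict_cls)    # [B, A, num_classes]
    align_bbox = torch.zeros_like(predict_bbox)  # [B, A, 4]
    valid_mask = torch.zeros(predict_cls.shape[:2], dtype=torch.bool, device=predict_cls.device)
    return torch.cat([align_cls, align_bbox], dim=-1), valid_mask


# A single image whose labels were all removed by augmentation.
target = torch.zeros(1, 0, 5)
predict_cls = torch.rand(1, 8400, 80)
predict_bbox = torch.rand(1, 8400, 4)

matched, valid_mask = empty_match(target, predict_cls, predict_bbox)
print(matched.shape)          # torch.Size([1, 8400, 84])
print(valid_mask.shape)       # torch.Size([1, 8400])
print(int(valid_mask.sum()))  # 0
```

Because `valid_mask` is all `False`, whatever consumes it downstream (for example a loss normalized by the number of positive anchors) would still have to tolerate a zero count; this sketch does not cover that part.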