pytorch / vision

Datasets, Transforms and Models specific to Computer Vision
https://pytorch.org/vision
BSD 3-Clause "New" or "Revised" License

FCOS empty box images #5266

Open barschiiii opened 2 years ago

barschiiii commented 2 years ago

I am playing around with the new FCOS models (thanks for those) and am encountering issues when providing images without box annotations. This is a common use case in object detection, and it works for the other detector models in torchvision.

A simple example to replicate:

import torch
from torchvision.models.detection import fcos_resnet50_fpn

model = fcos_resnet50_fpn(pretrained=True)
model(torch.zeros((1, 3, 512, 512)), targets=[{"boxes": torch.empty(0, 4), "labels": torch.empty(0, 1).to(torch.int64)}])

An indexing error happens in FCOSHead when running compute_loss in this part:

all_gt_classes_targets = []
all_gt_boxes_targets = []
for targets_per_image, matched_idxs_per_image in zip(targets, matched_idxs):
    gt_classes_targets = targets_per_image["labels"][matched_idxs_per_image.clip(min=0)]
    gt_classes_targets[matched_idxs_per_image < 0] = -1  # background
    gt_boxes_targets = targets_per_image["boxes"][matched_idxs_per_image.clip(min=0)]
    all_gt_classes_targets.append(gt_classes_targets)
    all_gt_boxes_targets.append(gt_boxes_targets)
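
For reference, a minimal sketch (not from the torchvision source) of why the indexing fails when an image has no boxes: every anchor gets matched to background, so matched_idxs_per_image is all -1, and clip(min=0) then indexes into empty labels/boxes tensors.

import torch

labels = torch.empty(0, dtype=torch.int64)   # no ground-truth labels
matched_idxs = torch.full((100,), -1)        # every anchor is background

# clip(min=0) turns -1 into 0, but index 0 does not exist in an empty tensor:
labels[matched_idxs.clip(min=0)]  # IndexError: index 0 is out of bounds for dimension 0 with size 0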

A workaround seems to be necessary when the targets are empty. Happy for any guidance; maybe there is also a different way for me to train on empty images.
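
One rough idea for a local guard (just a sketch on my side, not necessarily how the fix should look) would be to special-case images without boxes inside the loop above, so the gather is skipped and everything is treated as background:

    # hypothetical guard at the top of the loop body shown above
    if targets_per_image["boxes"].numel() == 0:
        gt_classes_targets = torch.full_like(matched_idxs_per_image, -1)
        gt_boxes_targets = torch.zeros(
            (matched_idxs_per_image.shape[0], 4), device=matched_idxs_per_image.device
        )
        all_gt_classes_targets.append(gt_classes_targets)
        all_gt_boxes_targets.append(gt_boxes_targets)
        continue

This would keep images without annotations as pure negatives instead of dropping them from the batch.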

@jdsgomes @xiaohu2015 @zhiqwang

Versions

Torchvision @ master

xiaohu2015 commented 2 years ago

@barschiiii I also found this issue; we plan to add another PR to handle it.

barschiiii commented 2 years ago

@xiaohu2015 that's great to hear, hope I will be able to try it soon :)

xiaohu2015 commented 2 years ago

@barschiiii hi, we fixed this bug in https://github.com/pytorch/vision/pull/5267

barschiiii commented 2 years ago

@xiaohu2015 thanks - technically it works! However, I am running into NaN loss after a short period of time. Not sure whether this is related; will try to explore. I am using AMP mixed precision and Adam.

datumbox commented 2 years ago

Using AMP mixed precision and Adam.

Thanks for providing this info, it helps the debugging. Could you try @xiaohu2015's patch without them and let us know if it's fixed? Thanks!

barschiiii commented 2 years ago

@datumbox this is after applying the patch - do you mean running it without AMP?

datumbox commented 2 years ago

Yes, exactly. Using AMP+Adam without gradient clipping can cause instabilities. Running without them will tell us if the NaNs are caused by a division by zero or by some other instability.
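
For reference, when combining AMP with gradient clipping, the gradients need to be unscaled before clipping, otherwise the clip threshold is applied to scaled gradients. A minimal training-loop sketch (model, optimizer and data_loader are placeholders, not your actual training code):

import torch

scaler = torch.cuda.amp.GradScaler()

for images, targets in data_loader:
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss_dict = model(images, targets)   # torchvision detection models return a loss dict in train mode
        loss = sum(loss_dict.values())
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)               # unscale so clipping sees the true gradient norms
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
    scaler.step(optimizer)
    scaler.update()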

barschiiii commented 2 years ago

I am actually running it with gradient clipping, but will try without AMP and also with SGD.

barschiiii commented 2 years ago

SGD with AMP, and Adam without AMP, both seem to run fine. Adam with AMP and gradient clipping runs into instability (NaN) issues, also with different learning rates.

barschiiii commented 2 years ago

It might be related to the default initialization of anchor boxes:

if anchor_generator is None:
    anchor_sizes = ((8,), (16,), (32,), (64,), (128,))  # equal to strides of multi-level feature map
    aspect_ratios = ((1.0,),) * len(anchor_sizes)  # set only one anchor
    anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)

From my understanding, these sizes should be lower if they are meant to match the strides of e.g. a ResNet-50 backbone. In fact, if I lower them, I avoid the NaN losses.
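
For experimenting with this, a sketch of overriding the defaults (assuming fcos_resnet50_fpn forwards extra keyword arguments to the FCOS constructor; the sizes below are only an illustration):

from torchvision.models.detection import fcos_resnet50_fpn
from torchvision.models.detection.anchor_utils import AnchorGenerator

anchor_sizes = ((4,), (8,), (16,), (32,), (64,))   # illustrative values, one size per FPN level
aspect_ratios = ((1.0,),) * len(anchor_sizes)      # a single "anchor" per location, as in the default
anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)

model = fcos_resnet50_fpn(pretrained=False, anchor_generator=anchor_generator)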

xiaohu2015 commented 2 years ago

@barschiiii there are actually no anchors in FCOS, but we borrow the anchor abstraction from RetinaNet. Here, an anchor corresponds to a cell (grid location) of the feature map, so the anchor size equals the stride of the corresponding feature-map level. If you want to modify the label assignment, I think you can adjust the center_sampling_radius, or the lower and upper bounds:

            # each anchor is only responsible for certain scale range.
            lower_bound = anchor_sizes * 4
            lower_bound[: num_anchors_per_level[0]] = 0
            upper_bound = anchor_sizes * 8
            upper_bound[-num_anchors_per_level[-1] :] = float("inf")

But I think that, in most cases, you don't need to do that.
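
If you do want to play with the matching, a small sketch (assuming center_sampling_radius is forwarded from fcos_resnet50_fpn to the FCOS constructor; the value is only an example):

from torchvision.models.detection import fcos_resnet50_fpn

# a larger radius matches each ground-truth box to more feature-map cells
model = fcos_resnet50_fpn(pretrained=False, center_sampling_radius=2.0)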

barschiiii commented 2 years ago

But shouldn't the anchor sizes then be derived from the strides of the multi-level feature maps rather than hardcoded? I can see in the detectron2 implementation that they calculate the strides each time from the backbone that is passed in. Here the anchor sizes are hardcoded, and I am wondering how these hard-coded values were decided; the default ones seem not right to me.

barschiiii commented 2 years ago

I might have been wrong; stability issues still happen. Will explore further, but if anyone has an idea where they could come from, please let me know.

xiaohu2015 commented 2 years ago

But shouldn't the anchor sizes then be derived from the strides of the multi-level feature maps rather than hardcoded? I can see in the detectron2 implementation that they calculate the strides each time from the backbone that is passed in. Here the anchor sizes are hardcoded, and I am wondering how these hard-coded values were decided; the default ones seem not right to me.

Yes, ideally the anchor sizes shouldn't be set manually. As you say, in detectron2 the strides of the multi-level feature maps can be obtained from the backbone's output_shape method, but torchvision does not implement that interface, so the default anchor sizes are hardcoded.
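
For what it's worth, the strides can still be checked with a dummy forward pass through the FPN backbone (a quick sketch, not an official interface); for the default ResNet-50 FPN they should come out to 8, 16, 32, 64 and 128, matching the hardcoded anchor sizes:

import torch
from torchvision.models.detection import fcos_resnet50_fpn

model = fcos_resnet50_fpn(pretrained=False).eval()
with torch.no_grad():
    features = model.backbone(torch.zeros(1, 3, 512, 512))   # OrderedDict of FPN feature maps
strides = [512 // feat.shape[-1] for feat in features.values()]
print(strides)   # expected: [8, 16, 32, 64, 128]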

xiaohu2015 commented 2 years ago

I might have been wrong; stability issues still happen. Will explore further, but if anyone has an idea where they could come from, please let me know.

Did you also test your dataset with detectron2? Stability issues can often happen with detection models; maybe you should adjust the training hyperparameters.

barschiiii commented 2 years ago

Trying a lot of different settings, it seems the first forward pass in the classification head is producing NaNs and causing my instability issues. I could not resolve it so far, even when forcing an fp32 forward pass for this part.
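
What I mean by forcing fp32 is roughly along these lines (a hacky sketch that monkey-patches the head's forward; attribute names assume the current FCOSHead layout, and model is the FCOS instance):

import torch

cls_head = model.head.classification_head
orig_forward = cls_head.forward

def fp32_forward(features):
    # run only the classification head outside autocast, in full precision
    with torch.cuda.amp.autocast(enabled=False):
        return orig_forward([f.float() for f in features])

cls_head.forward = fp32_forward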

datumbox commented 2 years ago

Can you confirm you still face the problem on the latest main branch?

Isalia20 commented 1 year ago

@barschiiii there are actually no anchors in FCOS, but we borrow the anchor abstraction from RetinaNet. Here, an anchor corresponds to a cell (grid location) of the feature map, so the anchor size equals the stride of the corresponding feature-map level. If you want to modify the label assignment, I think you can adjust the center_sampling_radius, or the lower and upper bounds:

            # each anchor is only responsible for certain scale range.
            lower_bound = anchor_sizes * 4
            lower_bound[: num_anchors_per_level[0]] = 0
            upper_bound = anchor_sizes * 8
            upper_bound[-num_anchors_per_level[-1] :] = float("inf")

But I think that, in most cases, you don't need to do that.

I'm not sure I understand why we use anchors if FCOS doesn't need anchors. Can you explain a bit more?