facebookresearch / detr

End-to-End Object Detection with Transformers
Apache License 2.0

Suggestion - change to ResNeST50 backbone (new split attention arch) #29

Open lessw2020 opened 4 years ago

lessw2020 commented 4 years ago

One way to boost DETR's already impressive results might be to swap in the new ResNeSt-50 backbone (released last month by Amazon AI and UC Davis). In every architecture they tested, it immediately provided roughly a 3-4 point AP boost on COCO.

From the paper: "This improvement also helps downstream tasks including object detection, instance segmentation and semantic segmentation. For example, by simply replacing the ResNet-50 backbone with ResNeSt-50, we improve the mAP of Faster-RCNN on MS-COCO from 39.3% to 42.3% and the mIoU for DeeplabV3 on ADE20K from 42.1% to 45.1%."
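The quoted gains are absolute points (39.3 to 42.3 mAP is +3.0), which is larger in relative terms. A quick sketch of that arithmetic on the paper's numbers; `gains` is a hypothetical helper:

```python
# Baseline vs. ResNeSt-50 numbers quoted above from the ResNeSt paper.
results = {
    "Faster-RCNN mAP (MS-COCO)": (39.3, 42.3),
    "DeeplabV3 mIoU (ADE20K)": (42.1, 45.1),
}

def gains(before, after):
    """Return (absolute gain in points, relative gain in percent), rounded to 0.1."""
    absolute = after - before
    relative = 100.0 * absolute / before
    return round(absolute, 1), round(relative, 1)

for metric, (before, after) in results.items():
    abs_gain, rel_gain = gains(before, after)
    print(f"{metric}: +{abs_gain} points ({rel_gain}% relative)")
```

So "3-4%" here means AP points; relative to the baseline it is closer to 7-8%.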

It should be plug-and-play. I've been using it for classification work and it was a nice improvement there, and its idea of better global context maps well onto the improvements DETR brings to the head architecture.
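To make the "global context" idea concrete, here is a minimal sketch of split attention as described in the ResNeSt paper. The `radix` splits of a feature map are summed and globally average-pooled. A per-channel softmax across the splits (the "r-softmax") then decides how to mix them. It assumes cardinality 1, represents feature maps as plain Python lists, and replaces the learned FC layers with a hypothetical per-channel scalar gate (`gate_weights`). It is an illustration, not the reference implementation:

```python
import math

def split_attention(splits, gate_weights):
    """Combine radix feature-map splits with per-channel softmax attention.

    splits: list (length radix) of feature maps, each [channels][positions].
    gate_weights: hypothetical [radix][channels] scalars standing in for the
        learned FC layers; logit[i][c] = gate_weights[i][c] * pooled[c].
    """
    radix, channels, positions = len(splits), len(splits[0]), len(splits[0][0])
    out = []
    for c in range(channels):
        # 1) fuse the splits by element-wise sum, then global-average-pool channel c
        pooled = sum(splits[i][c][k] for i in range(radix) for k in range(positions)) / positions
        # 2) one attention logit per split from the pooled (global) context
        logits = [gate_weights[i][c] * pooled for i in range(radix)]
        # 3) r-softmax: normalize across the radix splits for this channel
        peak = max(logits)
        exps = [math.exp(z - peak) for z in logits]
        attn = [e / sum(exps) for e in exps]
        # 4) attention-weighted sum of the splits
        out.append([sum(attn[i] * splits[i][c][k] for i in range(radix))
                    for k in range(positions)])
    return out

# Two splits, one channel, four positions; zero gates give equal weights.
print(split_attention([[[1, 1, 1, 1]], [[3, 3, 3, 3]]], [[0.0], [0.0]]))  # → [[2.0, 2.0, 2.0, 2.0]]
```

Because the weights come from globally pooled statistics, every position in a channel is reweighted using context from the whole feature map.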

https://arxiv.org/abs/2004.08955v1 https://github.com/zhanghang1989/ResNeSt

(I plan to test this on my own datasets but won't have time to train it on COCO proper; conceptually, I think it's a great match for DETR regardless.)

raijinspecial commented 4 years ago

I have tried this and it works well, at least for one epoch on a machine with GPUs.

I don't have any numbers yet, as my main goal is to run this with torch_xla, which also sort of works, at least in the sense that it completes a forward pass.

lessw2020 commented 4 years ago

Awesome, thanks for the initial feedback on putting ResNeSt to use with DETR!

I should also note that the official ResNeSt implementation currently can't be exported with JIT scripting by default. Apparently the split architecture makes JIT think each split should have its own BN. Fortunately @rwightman fixed this, and I've verified that you can JIT-script export ResNeSt nicely with his version.

Here's a link to his .py with the JIT fix: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/resnest.py

You can get the pretrained weights etc. from his timm model zoo (just use the create_model call): https://github.com/rwightman/pytorch-image-models

raijinspecial commented 4 years ago

No problemo!

You are right about the tracing issue. I followed the example in pytorch-image-models and adapted @rwightman's fixes to the original splat.py from zhanghang1989's repo, since I found ResNeSt a bit easier to install that way, but either should work just fine.

You can have DETR use ResNeSt with a slight modification to the Backbone class in models/backbone.py. For example, using the zhanghang1989 repo:

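# In DETR's models/backbone.py, where BackboneBase and FrozenBatchNorm2d are defined
import resnest.torch

class Backbone(BackboneBase):
    """ResNeSt backbone with frozen BatchNorm."""
    def __init__(self, name: str,
                 train_backbone: bool,
                 return_interm_layers: bool,
                 dilation: bool):
        # name is a constructor in resnest.torch, e.g. 'resnest50'; every BN layer
        # is built as FrozenBatchNorm2d. Note: `dilation` (the DC5 option) is ignored here.
        backbone = getattr(resnest.torch, name)(
            pretrained=True, norm_layer=FrozenBatchNorm2d)
        # all ResNeSt variants (50/101/200/269) output 2048 channels from layer4
        num_channels = 2048
        super().__init__(backbone, train_backbone, num_channels, return_interm_layers)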

It should be similar for the pytorch-image-models version as well.
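Whichever implementation supplies the weights, DETR itself only consumes the backbone's last feature map. models/detr.py projects it with a 1x1 `input_proj` conv from `backbone.num_channels` to the transformer's `hidden_dim` (256 by default), and each spatial location becomes one transformer token. A rough bookkeeping sketch follows, assuming a total stride of 32 (16 for the DC5 variant) and rounding up at each stride-2 stage. `detr_input_stats` is a hypothetical helper, not part of DETR:

```python
import math

def detr_input_stats(img_h, img_w, num_channels=2048, hidden_dim=256, stride=32):
    """Feature-map size, token count, and input_proj parameter count (assumed layout)."""
    # repeated ceil-halving of an integer equals one ceil division by the total stride
    feat_h, feat_w = math.ceil(img_h / stride), math.ceil(img_w / stride)
    tokens = feat_h * feat_w  # one transformer token per spatial location
    # 1x1 conv from backbone channels to hidden_dim: weights + bias
    proj_params = num_channels * hidden_dim + hidden_dim
    return feat_h, feat_w, tokens, proj_params

print(detr_input_stats(800, 1066))  # → (25, 34, 850, 524544)
```

Swapping ResNet-50 for ResNeSt-50 leaves all of these numbers unchanged, which is why the swap is mechanically a drop-in. Whether it trains as well is a separate question.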

rwightman commented 4 years ago

FYI, I am working towards a more consistent model-creation interface, adapters, and a layer/BN freeze interface in timm, to make all the models suitable as backbones for various detection/segmentation applications. This codebase was one of my intended targets, along with my EfficientDet implementation and possibly detectron2/mmdetection. Right now I'm getting lost down the rabbit hole of trying to make some other models JIT-script compatible...

lessw2020 commented 4 years ago

Looks like we should just go directly to the new FBNetV3 architecture. It matches ResNeSt-50 accuracy with 5x fewer FLOPs, so it should be much more performant:
https://arxiv.org/abs/2006.02049v1

rwightman commented 4 years ago

@lessw2020 FBNetV3 is another 'MBConvNet' (the model can be implemented in 10 extra lines in timm). It'll behave much like an EfficientNet, but perhaps a little faster thanks to hard-swish and a few other differences. Don't pay much attention to the wide accuracy-per-FLOP gap versus the ResNet family. Heavy use of depthwise convs means it'll be memory-throughput bound on GPUs, just like EfficientNets, so you'll see lower practical throughput and higher GPU memory usage per FLOP. A 5x FLOP reduction will thus not result in anything close to a 5x img/sec increase or a 5x reduction in memory consumption.
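A rough roofline estimate illustrates this point. A depthwise conv cuts FLOPs by a factor of the channel count but moves about as many bytes (the activations dominate). Its arithmetic intensity (FLOPs per byte) therefore collapses, and the layer becomes bandwidth-bound. The GPU figures below are hypothetical. The model covers a single layer and ignores caching, kernel fusion, and launch overhead:

```python
def conv3x3_cost(h, w, c, depthwise, bytes_per_elem=4):
    """FLOPs and minimum DRAM traffic (bytes) for a 3x3, C -> C conv on an h x w map.
    Traffic = read input + write output + read weights once (no cache effects)."""
    in_per_out = 1 if depthwise else c  # depthwise: each output sees one input channel
    flops = 2 * h * w * c * 9 * in_per_out  # 2 FLOPs per multiply-accumulate
    weights = 9 * c * in_per_out
    bytes_moved = bytes_per_elem * (2 * h * w * c + weights)
    return flops, bytes_moved

def roofline_tflops(flops, bytes_moved, peak_tflops, bandwidth_tb_s):
    """Attainable throughput: min(compute peak, arithmetic intensity * bandwidth)."""
    intensity = flops / bytes_moved  # FLOPs per byte
    return min(peak_tflops, intensity * bandwidth_tb_s)

# Hypothetical GPU: 15 TFLOPS fp32 peak, 0.9 TB/s memory bandwidth.
PEAK, BW = 15.0, 0.9
for depthwise in (False, True):
    flops, nbytes = conv3x3_cost(56, 56, 256, depthwise)
    print(f"depthwise={depthwise}: {flops / 1e9:.3f} GFLOP, "
          f"~{roofline_tflops(flops, nbytes, PEAK, BW):.1f} TFLOPS attainable")
```

With these assumed numbers, the depthwise layer needs 256x fewer FLOPs. However, it runs at a small fraction of peak, so its time drops by far less than 256x.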

It'll be interesting to try applying their searched hparams to other models in timm... I've been applying a similar recipe of EMA weight averaging + RMSProp (modified for stability), stochastic depth with AutoAugment/RandAugment, and mixup for some time now, after exploring ideas from the EfficientNet papers. I can basically just plug their numbers in, which are a bit different from the EfficientNet and MobileNetV3 defaults...
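EMA weight averaging keeps a shadow copy of the weights. The copy is updated after every optimizer step and used for evaluation. Here is a minimal sketch with plain floats standing in for tensors. Real training uses a decay close to 1 (the exact value is a tuning choice); `decay=0.5` below only makes the lag easy to see:

```python
def ema_update(ema_params, model_params, decay):
    """One EMA step, in place: ema = decay * ema + (1 - decay) * current weights."""
    for name, value in model_params.items():
        ema_params[name] = decay * ema_params[name] + (1.0 - decay) * value
    return ema_params

# Toy run: the "model" weight jumps to 1.0 and stays there; the EMA copy lags behind.
ema = {"w": 0.0}
for _ in range(3):
    ema_update(ema, {"w": 1.0}, decay=0.5)
print(ema["w"])  # → 0.875
```

The lag is the point: the EMA copy smooths out step-to-step noise in the weights, which is why it often evaluates better than the raw weights late in training.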

Also, I'm quite curious whether they actually used PyTorch's native RMSProp, since it tends to blow up with training schemes like this... or perhaps they modified it to be closer to the TF implementation, like I did.
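One commonly cited difference between the two implementations may explain the blow-ups. TF1's RMSProp starts its mean-square accumulator at 1 and puts epsilon inside the square root. PyTorch's `torch.optim.RMSprop` starts it at 0 and adds epsilon after the root. As a result, PyTorch's first steps divide by a tiny denominator. The scalar sketch below compares the two update rules with identical, assumed hyperparameters. It is an illustration of that difference, not rwightman's actual optimizer:

```python
import math

def rmsprop_torch_style(grads, lr, alpha, eps, momentum):
    """PyTorch-style RMSProp on one scalar; returns the size of each parameter step.
    square_avg starts at 0 and eps is added *after* the square root."""
    square_avg, buf, steps = 0.0, 0.0, []
    for g in grads:
        square_avg = alpha * square_avg + (1 - alpha) * g * g
        buf = momentum * buf + g / (math.sqrt(square_avg) + eps)
        steps.append(lr * buf)  # parameter moves by -lr * buf
    return steps

def rmsprop_tf_style(grads, lr, decay, eps, momentum):
    """TF1-style RMSProp on one scalar; returns the size of each parameter step.
    The mean square starts at 1, eps sits *inside* the square root, and the
    learning rate is folded into the momentum buffer."""
    ms, mom, steps = 1.0, 0.0, []
    for g in grads:
        ms = decay * ms + (1 - decay) * g * g
        mom = momentum * mom + lr * g / math.sqrt(ms + eps)
        steps.append(mom)  # parameter moves by -mom
    return steps

# Same (assumed) hyperparameters for both, small constant gradient.
grads = [0.01] * 3
torch_steps = rmsprop_torch_style(grads, lr=0.01, alpha=0.9, eps=1e-3, momentum=0.9)
tf_steps = rmsprop_tf_style(grads, lr=0.01, decay=0.9, eps=1e-3, momentum=0.9)
print(torch_steps[0] / tf_steps[0])  # first-step ratio: the PyTorch-style step is far larger
```

With a zero-initialized accumulator, the first PyTorch-style step is roughly `lr / sqrt(1 - alpha)` regardless of the gradient's scale. That is a large kick at the start of training, when the TF-style step is still tiny.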

lessw2020 commented 4 years ago

@rwightman - thanks for weighing in on this! I did see the MBConv architecture but didn't realize the memory ramifications you pointed out, so thanks for the feedback. I'll stick with ResNeSt for now then. EMA weight averaging does seem really promising, and it's great to hear that you are exploring that space. One topic regarding augmentation (now I can't find the paper...): there was a new augmentation that did CutMix, but based on feature extraction with CV adaptive background subtraction to ensure the cut region was relevant and not, say, background. That set a new accuracy high for ResNet models, so it's worth exploring.
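For reference, plain CutMix pastes a random rectangle from one image into another and mixes the labels by visible area. This is not the background-aware variant described above, whose paper isn't identified here. The sketch below is pure Python and only computes the box and the mixed label. The feature-guided version would presumably replace the uniform choice of box center with one biased toward salient regions:

```python
import math
import random

def cutmix_box(width, height, lam, rng):
    """Sample a CutMix rectangle covering roughly (1 - lam) of the image.
    Returns the clipped box (x1, y1, x2, y2) and lam recomputed from its real area."""
    cut_ratio = math.sqrt(1.0 - lam)
    cut_w, cut_h = int(width * cut_ratio), int(height * cut_ratio)
    cx, cy = rng.randrange(width), rng.randrange(height)  # uniform box center
    x1, x2 = max(cx - cut_w // 2, 0), min(cx + cut_w // 2, width)
    y1, y2 = max(cy - cut_h // 2, 0), min(cy + cut_h // 2, height)
    # clipping at the image border shrinks the box, so lam is recomputed
    lam_adjusted = 1.0 - (x2 - x1) * (y2 - y1) / (width * height)
    return (x1, y1, x2, y2), lam_adjusted

def cutmix_labels(onehot_a, onehot_b, lam):
    """Mix the two one-hot targets in proportion to each image's visible area."""
    return [lam * a + (1.0 - lam) * b for a, b in zip(onehot_a, onehot_b)]

rng = random.Random(0)
lam = rng.betavariate(1.0, 1.0)  # Beta(alpha, alpha) with alpha = 1.0 (assumed)
box, lam_adj = cutmix_box(224, 224, lam, rng)
print(box, cutmix_labels([1.0, 0.0], [0.0, 1.0], lam_adj))
```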

And regarding optimizers - coolMomentum looks promising: https://arxiv.org/abs/2005.14605v1

m-klasen commented 4 years ago

Hi, did you manage to improve your results with a ResNeSt backbone? I've been struggling to get any meaningful results compared to the default R50 backbone (yellow line). Maybe it's just an issue of training with the different bottleneck convs...

lessw2020 commented 4 years ago

Hi @mlk1337 - planning to test that this week. Which ResNeSt version did you use, and could I see your mods to the backbone function for hooking it in?
I hooked in a ResNeSt-50 from Hang Zhang's torch hub, but I see the BN is not frozen. Did you freeze it for training, since DETR freezes it in the default torchvision resnet50?

lessw2020 commented 4 years ago

Note: I'm training now, and I went ahead and froze the BN layers and their weights, which should be equivalent to the FrozenBatchNorm2d used with the ResNet.
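This equivalence can be checked per channel. A regular BN layer in eval mode, with its weight and bias frozen, computes the same function as DETR's FrozenBatchNorm2d, which folds the stored statistics into a fixed scale and shift (both use eps = 1e-5). One caveat: the equivalence holds only while the BN layers stay in eval mode. A later `model.train()` call switches `nn.BatchNorm2d` back to batch statistics and running-stat updates, which FrozenBatchNorm2d never does. Pure-Python sketch with made-up statistics:

```python
import math

def batchnorm_eval(x, weight, bias, running_mean, running_var, eps=1e-5):
    """nn.BatchNorm2d in eval mode, for one channel: normalize with stored statistics."""
    return weight * (x - running_mean) / math.sqrt(running_var + eps) + bias

def frozen_batchnorm(x, weight, bias, running_mean, running_var, eps=1e-5):
    """The FrozenBatchNorm2d formulation: fold the statistics into a fixed scale/shift."""
    scale = weight / math.sqrt(running_var + eps)
    shift = bias - running_mean * scale
    return x * scale + shift

stats = dict(weight=1.5, bias=-0.2, running_mean=0.3, running_var=2.0)  # made-up values
for x in (-1.0, 0.0, 2.5):
    print(x, batchnorm_eval(x, **stats), frozen_batchnorm(x, **stats))
```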

m-klasen commented 4 years ago

Hi, I've tried both Hang Zhang's and rwightman's ResNeSt versions and modified DETR's build function to load them. But so far I could not get either to outperform the default R50; I'm a few mAP short when converged. I haven't figured out why yet; maybe some non-FrozenBatchNorm training is required.

lessw2020 commented 4 years ago

Hi @mlk1337 - thanks for sharing your results!
I did test it, and the results were a bit inconclusive. The mAP rose quickly at the start, but the final mAP was not very good. When I ran some images, it turned out that basically every query had found the object, each with a slightly different bbox, so they all had a confidence of about 0.0001 when rounded. This contrasts with the same dataset using ResNet, where I would get the normal one or two high-confidence detections and the other queries predicted no object. The one improvement was that on some 'hard images' the ResNeSt version detected objects the regular version did not. Anyway, for now, like you, I am reverting to the default ResNet. I would like to try unfrozen BN in the future, but I'm under time pressure, so I'm just going to run with the default ResNet for now. I should note that this dataset was tiny, so I can't draw big conclusions yet. The fact that it caught objects the ResNet version missed still suggests it has some advantages, but it seems it will take more work than simply dropping it into DETR.
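The failure mode described above is every query landing on the same object with near-zero confidence. One quick check is to score each query the way DETR's post-processing does: softmax over the classes plus the final "no object" class, then take the best real-class probability. Then count how many other boxes nearly coincide with the top query's box. Pure-Python sketch on hypothetical toy outputs. With real outputs you would take `pred_logits` and `pred_boxes` for one image, converting the boxes from DETR's normalized (cx, cy, w, h) format to corners first:

```python
import math

def query_scores(logits_per_query):
    """Per-query confidence as in DETR's post-processing: softmax over all classes
    plus the trailing 'no object' class, then the best *real* class probability."""
    scores = []
    for logits in logits_per_query:
        peak = max(logits)
        exps = [math.exp(z - peak) for z in logits]
        scores.append(max(exps[:-1]) / sum(exps))  # drop the no-object column
    return scores

def iou(a, b):
    """Intersection over union of two (x1, y1, x2, y2) boxes."""
    ix = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    iy = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = ix * iy
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0

def count_near_duplicates(boxes, scores, iou_thresh=0.9):
    """How many other queries put their box almost exactly on the top query's box."""
    best = max(range(len(scores)), key=scores.__getitem__)
    return sum(1 for i, b in enumerate(boxes)
               if i != best and iou(boxes[best], b) >= iou_thresh)

# Toy outputs: 3 queries, 2 real classes + no-object; all boxes on one object.
logits = [[0.0, 0.0, 10.0], [0.1, 0.0, 10.0], [0.0, 0.2, 10.0]]
boxes = [(10, 10, 50, 50), (11, 10, 50, 51), (10, 11, 51, 50)]
scores = query_scores(logits)
print(max(scores) < 0.001, count_near_duplicates(boxes, scores))  # → True 2
```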

ririya commented 4 years ago

I tried ResNeSt101 and EfficientNet, and neither outperformed ResNet101 on my dataset. ResNet101-DC5 still outperforms every backbone I tried.

munirfarzeen commented 3 years ago

Hi, I would like to use my own pretrained weights for the backbone. Where can I define them? How would that affect the rest of the network, and how should I initialize the rest of it?
In backbone = getattr(resnest.torch, name)(pretrained=True, norm_layer=FrozenBatchNorm2d), what does getattr do? @raijinspecial, how did you add the backbone? Can you provide the code?
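On the `getattr` question: `getattr(obj, "name")` fetches an attribute by its string name. So `getattr(resnest.torch, name)` with `name='resnest50'` is just `resnest.torch.resnest50`, which is then called with `pretrained=True, norm_layer=FrozenBatchNorm2d`. For custom weights there are two options:

- Build the backbone with `pretrained=False` and load your own state dict into it before it is wrapped.
- Rename your checkpoint's keys to DETR's layout and load them into the full model with `load_state_dict(..., strict=False)`.

With `strict=False`, the other modules (input projection, transformer, heads) keep their normal initialization. The sketch below uses a hypothetical stand-in namespace instead of the real `resnest.torch`. The `backbone.0.body.` prefix is an assumption based on DETR's module layout, so check it against `model.state_dict().keys()`:

```python
from types import SimpleNamespace

def fake_resnest50(pretrained=False, **kwargs):
    """Hypothetical stand-in for resnest.torch.resnest50."""
    return f"resnest50(pretrained={pretrained})"

# getattr looks an attribute up by its string name: same as fake_module.resnest50
fake_module = SimpleNamespace(resnest50=fake_resnest50)
constructor = getattr(fake_module, "resnest50")
print(constructor(pretrained=False))

def to_detr_backbone_keys(classifier_state, prefix="backbone.0.body."):
    """Rename classification-checkpoint keys to (assumed) DETR backbone key names and
    drop the classifier head (fc.*), which DETR's backbone does not contain."""
    return {prefix + k: v for k, v in classifier_state.items() if not k.startswith("fc.")}

print(to_detr_backbone_keys({"conv1.weight": 1, "fc.weight": 2, "layer4.0.conv1.weight": 3}))
```

With real tensors, `model.load_state_dict(...)` returns the lists of missing and unexpected keys, which is a quick way to confirm the prefix matched.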