huggingface / pytorch-image-models

The largest collection of PyTorch image encoders / backbones. Including train, eval, inference, export scripts, and pretrained weights -- ResNet, ResNeXT, EfficientNet, NFNet, Vision Transformer (ViT), MobileNetV4, MobileNet-V3 & V2, RegNet, DPN, CSPNet, Swin Transformer, MaxViT, CoAtNet, ConvNeXt, and more
https://huggingface.co/docs/timm
Apache License 2.0
31.56k stars 4.71k forks

new loss proposed: ASL #249

Closed mrT23 closed 3 years ago

mrT23 commented 3 years ago

we recently released a paper: "Asymmetric Loss For Multi-Label Classification". https://github.com/Alibaba-MIIL/ASL

in the paper we introduce a new loss function, ASL, which operates differently on positive samples (the object appears) and negative samples (the object does not appear). With the new loss, we reach SOTA results on all multi-label datasets. We also show improvement over focal loss in object detection, and single-label fine-grain classification.
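For reference, the core mechanism described above (asymmetric focusing plus probability shifting for negatives) can be sketched in PyTorch roughly as follows. This is an illustrative sketch following the paper's notation, not the authors' official implementation from the Alibaba-MIIL/ASL repo; defaults here mirror the paper's commonly cited settings (gamma_neg=4, gamma_pos=1, clip=0.05) but are not tuned recommendations:

```python
import torch

def asymmetric_loss(logits, targets, gamma_pos=1.0, gamma_neg=4.0,
                    clip=0.05, eps=1e-8):
    """Sketch of the Asymmetric Loss (ASL) for multi-label classification.

    logits, targets: (batch, num_classes); targets are 0/1 per class.
    """
    p = torch.sigmoid(logits)
    # Asymmetric probability shifting: hard-discard very easy negatives
    # by shifting their predicted probability down by `clip`.
    p_neg = (p - clip).clamp(min=0) if clip is not None and clip > 0 else p
    # Asymmetric focusing: separate focusing exponents for the positive
    # and negative parts of the per-class binary cross-entropy.
    loss_pos = targets * (1 - p) ** gamma_pos * torch.log(p.clamp(min=eps))
    loss_neg = (1 - targets) * p_neg ** gamma_neg * torch.log((1 - p_neg).clamp(min=eps))
    return -(loss_pos + loss_neg).sum()
```

With `gamma_pos == gamma_neg` and `clip=0`, this collapses back to a symmetric focal-style loss over independent per-class sigmoids.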

@rwightman would you be interested in pull request to merge this loss into timm ?

rwightman commented 3 years ago

@mrT23 I'd be interested, although I don't currently have the demo train/val scripts and dataset setup to support any open multi-label or fine-grained datasets ... it was on my TODO list to support a few more demo datasets like OpenImages and CUB, so perhaps this is good motivation :) ImageNet isn't the best dataset to stress multi-label or class imbalance.

I have OpenImages image-label support partially implemented (based on my recent object detection support at https://github.com/rwightman/efficientdet-pytorch/tree/more_datasets). It could make sense to add that loss as an option to my efficientdet impl too. I was also thinking of implementing SeeSaw loss to see how it compares (https://arxiv.org/abs/2008.10032)

mrT23 commented 3 years ago

Regarding fine-grained classification: someone here ran Stanford Cars with your repo and got a very good score, on the edge of SOTA.

Regarding object detection: ASL can serve there as a drop-in replacement for the widely used focal loss, which we found to be suboptimal.

Regarding multi-label classification: I know it is not as popular as single-label classification, but to be honest, I don't know why. With single-label classification, you always assume that a picture contains one and only one object, which is an unrealistic assumption. Multi-label classification is basically the simplest form of deep learning that can actually analyze "real-world" images.

Regarding OpenImages: the dataset is quite hard to "digest" (it's big, not all the download links work, it uses a "partial-labeling" methodology, and it has more issues...). However, it's worth it, because you get a very strong analysis of the image content. We released a multi-label TResNet model trained on OpenImages. If you decide to tackle it and want some tips, feel free to contact me.

Anyway, I will open a pull request.

rwightman commented 3 years ago

@mrT23 I spent a bit of time trying to get ASL working for EfficientDet object detection in place of Focal Loss. The experiment didn't work so well. The paper focuses on the imbalanced classification problem. Are the implementations here and in your original repo suited for object detection as is? Where additions or changes made to them for the COCO results mentioned in the paper?

mrT23 commented 3 years ago

@rwightman Both multi-label classification and object detection are inherently imbalanced: all the classes are decoupled (sigmoid instead of softmax), and per class, the number of "negatives" (the class does not appear in the image) is significantly higher than the number of "positives" (the class appears in the image).
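A toy illustration of that decoupling and imbalance (made-up target matrix, for intuition only). With sigmoid outputs, each class column is an independent binary problem, so even a modest number of classes leaves every class with far more negatives than positives:

```python
import torch

# Hypothetical multi-label targets: 6 images, 5 classes. Each column is an
# independent binary problem when using per-class sigmoids.
targets = torch.tensor([
    [1, 0, 0, 0, 0],
    [0, 1, 0, 0, 0],
    [1, 0, 1, 0, 0],
    [0, 0, 0, 1, 0],
    [0, 0, 0, 0, 1],
    [1, 0, 0, 0, 0],
], dtype=torch.float)

pos_per_class = targets.sum(dim=0)               # positives per class
neg_per_class = targets.shape[0] - pos_per_class  # negatives per class
print(pos_per_class.tolist())  # [3.0, 1.0, 1.0, 1.0, 1.0]
print(neg_per_class.tolist())  # [3.0, 5.0, 5.0, 5.0, 5.0]
```

As the class count grows, the negative-to-positive ratio per class grows with it, which is the imbalance ASL's asymmetric treatment is aimed at.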

I am not an expert in object detection. I took a variant of the mm-detection repo and only modified the loss there. However, I noticed two things:

  1. The loss in object detection is complicated, since it is composed of three different losses (class label, bounding box, and IoU). Every time you try to sum up losses with different scales and objectives, you need to carefully balance (actually over-fit) their relative weights. Once you modify one of the losses, you are in danger of losing that delicate balance.
  2. Focal loss default parameters are gamma_pos == gamma_neg = 2, alpha = 0.25. In my opinion, alpha = 0.25 is a plain bug: it favors negative samples, despite the fact that they are already much more frequent.

I could quite easily beat the baseline detection score when changing the original focal loss (gamma_pos == gamma_neg = 2, alpha = 0.25) to simple ASL with gamma_pos = 1, gamma_neg = 2 and, of course, alpha = 0.5. If you don't see an improvement in this configuration, maybe try adjusting the overall learning rate a bit, or the relative weight of the label loss.
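A minimal sketch of what that swap looks like, assuming a sigmoid-based per-class classification loss inside the detection head (hypothetical function name; the real change would live in the detection repo's loss module). Setting gamma_pos == gamma_neg and alpha = 0.25 recovers the standard focal-loss parameterization, while gamma_pos = 1, gamma_neg = 2, alpha = 0.5 gives the simple ASL variant described above:

```python
import torch

def focal_like_loss(logits, targets, gamma_pos=1.0, gamma_neg=2.0,
                    alpha=0.5, eps=1e-8):
    """Focal-style classification loss with asymmetric focusing exponents.

    alpha weights the positive term; (1 - alpha) weights the negative term,
    so alpha = 0.25 down-weights the (already rarer) positives.
    """
    p = torch.sigmoid(logits)
    pos = alpha * targets * (1 - p) ** gamma_pos * torch.log(p.clamp(min=eps))
    neg = (1 - alpha) * (1 - targets) * p ** gamma_neg * torch.log((1 - p).clamp(min=eps))
    return -(pos + neg).sum()
```

The comparison between the two settings then reduces to a parameter change, leaving the bounding-box and IoU loss terms untouched.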

Tal

crypdick commented 2 years ago

@mrT23 should timm get updated to use the optimized versions of ASL?