fundamentalvision / Deformable-DETR

Deformable DETR: Deformable Transformers for End-to-End Object Detection.
Apache License 2.0
3.15k stars 513 forks source link

self.class_embed = nn.Linear(hidden_dim, num_classes) #152

Closed Innary closed 1 year ago

Innary commented 2 years ago

不好意思想问一下为什么 deformable-DETR中类别分类是 self.class_embed = nn.Linear(hidden_dim, num_classes) num_classes=91 在这里是91类,而在标准的DETR中是92类(91类+1背景类)? 是做了什么映射操作么

EMU1337X commented 2 years ago

同问

cdluminate commented 2 years ago

原版DETR用的cross entropy loss, 因为COCO自身只有一些N/A类别,不带背景类(用于表示该prediction head所对应的物体什么都不是),所以增加背景类扩展为了92。 Deformable DETR用的focal loss (binary cross entropy),这种情况下因为label可以是zero-hot vector, 不需要扩展到92类。

Pujin823 commented 1 year ago

所以只要在one-hot编码中对应全0的都会被视为背景类别吗,因为coco只有80类,但最只是最高的类别索引为90,中间一些索引并不对应类别(比如索引0),这些经过one-hot之后都是全0的吧