Vallum opened this issue 2 years ago
Hey, thanks for your interest in our work. We are glad to see that you implemented DINO based on DN-DETR. Is there any difference between your re-implemented DINO and our DINO? Can it achieve the same performance as our DINO? Maybe you could provide more information so we can discuss it in more detail.
First of all, with ResNet-50 + Deformable DETR + DN + DINO, 36 epochs:
IoU metric: bbox
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.509
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.691
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.556
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.337
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.542
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.653
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.380
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.658
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.730
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.570
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.772
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.880
I think the performance is almost the same as your official DINO.
For MQS, I wanted to keep the option of choosing between the original Deformable DETR two-stage variant and your other variants:
if self.two_stage:
    output_memory, output_proposals = self.gen_encoder_output_proposals(memory, mask_flatten, spatial_shapes)
    # hack implementation for two-stage Deformable DETR
    enc_outputs_class = self.decoder.class_embed[self.decoder.num_layers](output_memory)
    enc_outputs_coord_unact = self.decoder.bbox_embed[self.decoder.num_layers](output_memory) + output_proposals
    topk = self.two_stage_num_proposals
    topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1]
    topk_coords_unact = torch.gather(enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4))
    topk_coords_unact = topk_coords_unact.detach()
    reference_points = topk_coords_unact.sigmoid()
    # MQS is dab + mqs
    if self.use_mqs:
        assert self.use_dab
        reference_points_mqs = reference_points
        # sometimes the target is empty, add a zero-weighted term of tgt_embed to avoid unused parameters
        reference_points_mqs += self.tgt_embed.weight[0][0] * torch.tensor(0).cuda()
        tgt_mqs = self.tgt_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1)  # bs, nq, d_model
        # query_embed is not None when training.
        if query_embed is not None:
            reference_points_dab = query_embed[..., self.d_model:].sigmoid()
            tgt_dab = query_embed[..., :self.d_model]
            reference_points = torch.cat([reference_points_dab, reference_points_mqs], dim=1)
            tgt = torch.cat([tgt_dab, tgt_mqs], dim=1)
        else:
            reference_points = reference_points_mqs
            tgt = tgt_mqs
    else:
        pos_trans_out = self.pos_trans_norm(self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact, self.d_model)))
        query_embed, tgt = torch.split(pos_trans_out, c, dim=2)
else:
    if self.use_dab:
        reference_points = query_embed[..., self.d_model:].sigmoid()
        tgt = query_embed[..., :self.d_model]
        # tgt = tgt.unsqueeze(0).expand(bs, -1, -1)
    else:
        query_embed, tgt = torch.split(query_embed, c, dim=1)
        query_embed = query_embed.unsqueeze(0).expand(bs, -1, -1)
        tgt = tgt.unsqueeze(0).expand(bs, -1, -1)
        reference_points = self.reference_points(query_embed).sigmoid()
        # bs, num_queries, 2
init_reference_out = reference_points
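To make the MQS branch above easier to follow, here is a framework-free sketch of the query composition it performs, using plain Python lists instead of tensors and a single image instead of a batch. It assumes the top-k encoder boxes have already been selected and detached; `compose_mqs_queries` and `dn_queries` are hypothetical names for illustration, not this repo's API. The positional part (reference points) comes from the sigmoid of the selected encoder boxes, the content part stays static (learned embeddings shared by all images), and during training the DN queries are placed in front, mirroring the `torch.cat(..., dim=1)` order in the snippet.

```python
import math
from typing import List, Optional, Tuple

Vector = List[float]


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def compose_mqs_queries(
    topk_coords_unact: List[Vector],  # k boxes (cx, cy, w, h) in logit space, already detached
    tgt_embed: List[Vector],          # k learnable content embeddings, shared by all images
    dn_queries: Optional[List[Tuple[Vector, Vector]]] = None,  # hypothetical (content, ref_logits) pairs
) -> Tuple[List[Vector], List[Vector]]:
    """Return (content queries, reference points) for one image."""
    # Positional part: selected encoder boxes squashed into [0, 1].
    ref_mqs = [[sigmoid(v) for v in box] for box in topk_coords_unact]
    # Content part: static embeddings, not gathered from encoder features.
    tgt_mqs = [list(e) for e in tgt_embed]
    if dn_queries is None:  # inference: matching queries only
        return tgt_mqs, ref_mqs
    # Training: denoising queries first, then the MQS matching queries.
    tgt_dn = [list(content) for content, _ in dn_queries]
    ref_dn = [[sigmoid(v) for v in ref] for _, ref in dn_queries]
    return tgt_dn + tgt_mqs, ref_dn + ref_mqs


tgt, ref = compose_mqs_queries([[0.0, 0.0, 0.0, 0.0]], [[0.1, 0.2]])
print(tgt, ref)
```

The design point this illustrates is that MQS only borrows *where* to look from the encoder; the content queries are not taken from the selected encoder features, which is what distinguishes it from the `pos_trans` branch of the original two-stage Deformable DETR.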
For negative sampling, the official DINO moved the negative sampling out of dn_components.py, which made it hard for me to use DN-DETR and DINO simultaneously. So I implemented the double (positive + negative) sampling inside dn_components.py:
# in def prepare_for_dn
if training:
    if contrastive:
        new_targets = []
        for t in targets:
            new_t = {}
            new_t['labels'] = torch.cat([t['labels'], torch.tensor(len(t['labels']) * [num_classes], dtype=torch.int64).cuda()], dim=0)
            new_t['boxes'] = torch.cat([t['boxes'], t['boxes']], dim=0)
            new_targets.append(new_t)
        targets = new_targets
    known = [(torch.ones_like(t['labels'])).cuda() for t in targets]  # [ [ 1, 1], [1, 1, 1], ... ]
    know_idx = [torch.nonzero(t) for t in known]  # [ [0, 1], [0, 1, 2], ... ]
    known_num = [sum(k) for k in known]  # [ 2, 3, ... ]
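A list-based sketch of what the `contrastive` branch above does to each image's targets, assuming targets are plain dicts of Python lists rather than tensors (`double_targets_for_cdn` and `known_counts` are hypothetical names). Every ground-truth box is repeated once: the first copy keeps its label (positive query), and the second copy is labelled `num_classes`, an index outside the real classes, which the snippet uses to mark a negative (background) target. The noise that makes negatives harder to reconstruct is applied later and is not modelled here.

```python
from typing import Dict, List

Target = Dict[str, list]


def double_targets_for_cdn(targets: List[Target], num_classes: int) -> List[Target]:
    """Append one negative copy of every GT box, labelled num_classes."""
    new_targets = []
    for t in targets:
        labels, boxes = t["labels"], t["boxes"]
        new_targets.append({
            # positives keep their class; negatives get the out-of-range label
            "labels": list(labels) + [num_classes] * len(labels),
            # the same boxes twice; box noise is added in a later step
            "boxes": [list(b) for b in boxes] + [list(b) for b in boxes],
        })
    return new_targets


def known_counts(targets: List[Target]) -> List[int]:
    """Number of known (positive + negative) objects per image, like known_num."""
    return [len(t["labels"]) for t in targets]


targets = [{"labels": [3, 7], "boxes": [[0.5, 0.5, 0.2, 0.2], [0.1, 0.1, 0.1, 0.1]]}]
doubled = double_targets_for_cdn(targets, num_classes=91)
print(doubled[0]["labels"], known_counts(doubled))
```

Keeping the doubling inside `prepare_for_dn` means the rest of the DN pipeline (grouping, attention mask, loss indexing) sees twice as many known objects and needs no separate code path for negatives.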
With this implementation, I measured the performance at each step from DN-DETR to DN+DINO.

12-epoch schedule:

source | model (ResNet-50) | epochs | AP | AP50 | AP75 | APS | APM | APL |
---|---|---|---|---|---|---|---|---|
paper | dino-MQS-LFT-4scale | 12 | 47.9 | 65.3 | 52.1 | 31.2 | 50.9 | 61.9 |
paper | dino-MQS-LFT-5scale | 12 | 48.3 | 65.8 | 52.4 | 32.2 | 51.3 | 62.2 |
paper | DN-DDETR-4scale | 12 | 43.4 | 61.9 | 47.2 | 24.8 | 46.8 | 59.4 |
self | DN-DDETR-MQS-4scale | 12 | 48.2 | 66.0 | 52.6 | 29.9 | 51.4 | 63.0 |
self | DN-DDETR-MQS-LFT-4scale | 12 | 48.1 | 65.3 | 52.4 | 30.4 | 51.3 | 62.7 |
self | DN-DDETR-CDN-MQS-LFT-4scale | 12 | 48.2 | 65.4 | 52.5 | 31.1 | 51.1 | 63.3 |
36-epoch schedule (the DN-DDETR paper row uses 50 epochs):

source | model (ResNet-50) | epochs | AP | AP50 | AP75 | APS | APM | APL |
---|---|---|---|---|---|---|---|---|
paper | dino-MQS-LFT-4scale | 36 | 50.5 | 68.3 | 55.1 | 32.7 | 53.9 | 64.9 |
paper | dino-MQS-LFT-5scale | 36 | 51.0 | 69.0 | 55.6 | 34.1 | 53.6 | 65.6 |
paper | DN-DDETR-4scale | 50 | 48.6 | 67.4 | 52.7 | 31.0 | 52.0 | 63.7 |
self | DN-DDETR-MQS-4scale | 36 | 49.9 | 68.2 | 54.0 | 34.6 | 53.2 | 64.6 |
self | DN-DDETR-MQS-LFT-4scale | 36 | 50.3 | 68.2 | 55.0 | 32.7 | 53.3 | 65.1 |
self | DN-DDETR-CDN-MQS-LFT-4scale | 36 | 50.7 | 68.7 | 55.4 | 33.2 | 54.1 | 65.2 |
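To read the tables as per-component gains, a small script that recomputes the AP differences between consecutive self-trained rows. The AP values are copied from the tables above; the dictionary layout and the short row names are only for illustration.

```python
# AP of the self-trained DN-DDETR rows, copied from the two tables.
AP = {
    12: [("MQS", 48.2), ("MQS-LFT", 48.1), ("CDN-MQS-LFT", 48.2)],
    36: [("MQS", 49.9), ("MQS-LFT", 50.3), ("CDN-MQS-LFT", 50.7)],
}


def step_gains(rows):
    """AP change from adding each component on top of the previous row."""
    return [
        (name, round(ap - prev_ap, 1))
        for (_, prev_ap), (name, ap) in zip(rows, rows[1:])
    ]


for epochs, rows in AP.items():
    print(epochs, step_gains(rows))
```

At 12 epochs the three variants stay within 0.1 AP of each other, while at 36 epochs adding LFT and then CDN each gives about +0.4 AP.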
That's great. It looks like you are comparing with the old DINO in your table. If you initialize the parameters as in the new DINO, you can reach around 49.0 AP in 12 epochs.
@FengLi-ust Could you let me know which parameters differ between the newer and the older DINO? I followed the settings of the older paper, but I cannot find what is different. To me, the parameters look exactly the same.
With my latest 12-epoch setting, I got:
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.487
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.664
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.530
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.309
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.520
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.631
It looks similar, but the small-object AP is comparatively lower (0.309 vs. 0.32), so I am not sure whether the settings are the same or similar.
Namespace(amp=False, aux_loss=True, backbone='resnet50', backbone_freeze_keywords=None,
batch_norm_type='FrozenBatchNorm2d', batch_size=2, bbox_loss_coef=5, box_noise_scale=0.4,
clip_max_norm=0.1, cls_loss_coef=1, coco_panoptic_path=None, coco_path='data/coco', contrastive=True,
dataset_file='coco', debug=False, dec_layers=6, dec_n_points=4, device='cuda',
dice_loss_coef=1, dilation=False, dim_feedforward=2048,
dist_backend='nccl', dist_url='env://', distributed=True, drop_lr_now=False,
dropout=0.0, enc_layers=6, enc_n_points=4, eos_coef=0.1, epochs=12, eval=False,
find_unused_params=False, finetune_ignore=None, fix_size=False, focal_alpha=0.25,
frozen_weights=None, giou_loss_coef=2, gpu=0, hidden_dim=256,
label_noise_scale=0.5, local_rank=0, lr=0.0001, lr_backbone=1e-05, lr_drop=10,
mask_loss_coef=1, masks=False, modelname='dn_dab_deformable_detr', nheads=8, note='',
num_feature_levels=4, num_patterns=0, num_queries=900, num_results=300, num_select=300,
num_workers=10, output_dir='exps/r50_dn_dab_deformable_detr_two_stage_refactor_12epochs',
pe_temperatureH=20, pe_temperatureW=20, position_embedding='sine', pre_norm=False,
pretrain_model_path=None, random_refpoints_xy=False, rank=0, remove_difficult=False,
return_interm_layers=False,
save_checkpoint_interval=10, save_log=False, save_results=False, scalar=200, seed=42,
set_cost_bbox=5, set_cost_class=2,set_cost_giou=2,
start_epoch=0, transformer_activation='relu', tsst=False, two_stage=True,
use_dn=True, weight_decay=0.0001, world_size=8)
I just had a quick look. lr_drop should be set to 11.
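Why `lr_drop` matters in a 12-epoch run: assuming the training script steps a `StepLR`-style scheduler once per epoch with `step_size=lr_drop` and the default decay factor of 0.1 (DETR's `main.py` uses `torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)`; whether this repo does exactly the same is an assumption), the learning rate per epoch can be sketched as below. `lr_at_epoch` and `epochs_at_reduced_lr` are hypothetical helpers, not part of the repo.

```python
def lr_at_epoch(epoch: int, base_lr: float = 1e-4, lr_drop: int = 11, gamma: float = 0.1) -> float:
    """Learning rate during 0-indexed `epoch` with a step decay every `lr_drop` epochs."""
    return base_lr * gamma ** (epoch // lr_drop)


def epochs_at_reduced_lr(total_epochs: int, lr_drop: int) -> int:
    """Number of epochs trained after the first decay."""
    return sum(1 for e in range(total_epochs) if e >= lr_drop)


for drop in (10, 11):
    print(drop, [lr_at_epoch(e, lr_drop=drop) for e in range(12)])
```

With `lr_drop=10` (the value in the Namespace above) the last two of twelve epochs run at the reduced rate; with `lr_drop=11` only the final epoch does.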
@Vallum Hey, you can open a pull request so I can merge your code. We can also discuss any problems you run into.
@FengLi-ust Thank you for the response. Let me check the result and get the code ready.
@Vallum Thank you for your nice work. I just want to ask a quick question.
In the code that selects the top-K proposals from the encoder's class outputs, you wrote:
topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1]
But I think enc_outputs_class[..., 0]
only considers class 0 (and what class 0 means depends on the dataset).
In my opinion, topk_proposals = torch.topk(enc_outputs_class.max(-1)[0], topk, dim=1)[1]
would be better, since it ranks proposals by their best score over all classes (as in DINO).
I'd like to hear your opinion. Thanks in advance :)
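The difference between the two selection rules shows up on a toy example. The sketch below uses plain Python lists of per-proposal class logits (proposals x classes, batch dimension dropped) and ranks proposals either by the class-0 logit, as `enc_outputs_class[..., 0]` does, or by the maximum logit over all classes, as `enc_outputs_class.max(-1)[0]` does. The helper names are hypothetical.

```python
from typing import List


def topk_indices(scores: List[float], k: int) -> List[int]:
    """Indices of the k largest scores, highest first (like torch.topk(...)[1])."""
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]


def select_by_class0(logits: List[List[float]], k: int) -> List[int]:
    return topk_indices([row[0] for row in logits], k)


def select_by_max_class(logits: List[List[float]], k: int) -> List[int]:
    return topk_indices([max(row) for row in logits], k)


# 3 proposals x 3 classes; proposal 2 is a confident class-2 detection.
logits = [
    [0.9, -2.0, -2.0],
    [0.1, -1.0, -1.0],
    [-3.0, -3.0, 4.0],
]
print(select_by_class0(logits, 2))     # [0, 1]
print(select_by_max_class(logits, 2))  # [2, 0]
```

Which rule fits depends on how the encoder head is supervised. In the original two-stage Deformable DETR, the criterion sets all encoder-output target labels to 0 (class-agnostic proposals), so channel 0 acts as an objectness score; if the encoder head is trained with the real class labels, as in DINO, ranking by the max over classes is the natural choice. Which case applies here depends on this repo's criterion.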
Hi authors,
Thank you for open-sourcing your fantastic project.
I was very impressed by your successive projects DN-DETR and DINO,
so I have merged the DINO components into this earlier Deformable-DETR-based DN-DETR; the result differs a little from the official DINO.
Would you, by any chance, be interested in merging DINO into this DN-DETR repo?
If so, please let me know and I will prepare the code for sharing. Since you already have your own official DINO repo, you may not want to mix DN-DETR with another DINO implementation. That's OK; in that case, I will consider releasing my implementation another way.
Thanks.