megvii-research / AnchorDETR

An official implementation of the Anchor DETR.

Question about extending to panoptic segmentation #15

Closed luohao123 closed 2 years ago

luohao123 commented 3 years ago

Hi, AnchorDETR largely fixes DETR's slow convergence problem and also boosts AP on small objects.

However, I notice that AnchorDETR uses 900 output candidates (num_queries). This might not be a problem for detection, but once it is extended to panoptic segmentation, 900 becomes a problem. The main reasons are:

  1. Panoptic segmentation in DETR treats stuff and things uniformly as things, which means it outputs a giant tensor, e.g. 900x400x400, covering every possible instance and semantic segment in the image; that is a very large tensor if the input resolution is not small (see the rough estimate after this list);
  2. 900 queries may speed up training for detection, but not for instance or semantic segmentation.
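
As a rough sense of scale for point 1, a back-of-the-envelope sketch (assuming fp32 mask logits at the 900x400x400 shape mentioned above):

num_queries, h, w = 900, 400, 400                    # shape quoted in point 1
bytes_fp32 = num_queries * h * w * 4                  # 4 bytes per fp32 element
print(f"{bytes_fp32 / 1024**2:.0f} MiB per image")    # ~549 MiB, before any upsampling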

Are there any experiments on this, or any way to make panoptic segmentation possible in the AnchorDETR way? Any suggestion would be very appreciated.

tangjiuqi097 commented 3 years ago

@luohao123 Hi,

For the instance segmentation task, you can select the queries matched with the ground truth to predict instance masks during training, and select the queries with the top-100 confidence scores at inference. Unfortunately, I am not familiar with the panoptic segmentation task, but you may also try to select a subset of queries there.
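
A minimal sketch of the inference-time selection described above (the function and argument names here are illustrative, not part of the AnchorDETR API):

import torch

def select_topk_queries(pred_logits, query_feats, k=100):
    """Keep only the k most confident queries and use those for mask prediction.

    pred_logits: [B, num_queries, num_classes] classification logits
    query_feats: [B, num_queries, C] decoder output features for each query
    """
    scores = pred_logits.sigmoid().max(dim=-1).values               # [B, num_queries]
    topk_scores, topk_idx = scores.topk(k, dim=-1)                   # [B, k]
    gather_idx = topk_idx.unsqueeze(-1).expand(-1, -1, query_feats.shape[-1])
    topk_feats = query_feats.gather(1, gather_idx)                   # [B, k, C]
    return topk_scores, topk_idx, topk_feats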

luohao123 commented 3 years ago

@tangjiuqi097 I recently tried training AnchorDETR vs DETR. When I compare the two, I find that AnchorDETR does not converge as fast as it should. Here is the screenshot:

For AnchorDETR:

[screenshot]

At iter 249999, the AP is only 29,

while for DETR at the same number of iterations:

[screenshot]

at iter 22999 the AP reaches 32. Noticeably, though, AnchorDETR gets a higher AP on small objects even if its overall AP is not higher than DETR's.

Do you know why this happens? I am using exactly the same training hyperparameters as DETR for lr, max_iter, warmup, etc.; namely, this config in d2 YAML format:

SOLVER:
  # AMP:
  # ENABLED: true
  IMS_PER_BATCH: 56
  BASE_LR: 0.0001
  STEPS: (369600,)
  MAX_ITER: 554400
  WARMUP_FACTOR: 1.0
  WARMUP_ITERS: 10
  WEIGHT_DECAY: 0.0001
  OPTIMIZER: "ADAMW"
  BACKBONE_MULTIPLIER: 0.1
  CLIP_GRADIENTS:
    ENABLED: True
    CLIP_TYPE: "full_model"
    # CLIP_TYPE: "norm"
    CLIP_VALUE: 0.01
    NORM_TYPE: 2.0
INPUT:
  MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800, 832)
  CROP:
    ENABLED: True
    TYPE: "absolute_range"
    SIZE: (384, 600)
  FORMAT: "RGB"
TEST:
  EVAL_PERIOD: 10000
DATALOADER:
  FILTER_EMPTY_ANNOTATIONS: False
  NUM_WORKERS: 2
VERSION: 2

Do you know why this happens?

tangjiuqi097 commented 3 years ago

@luohao123 Hi, could you provide more details about the differences from our code?

luohao123 commented 3 years ago

@tangjiuqi097 Hi, all the model code is the same as anchor_detr.py and the configurations are the same; the only difference is that I am using d2 for the dataloader and evaluation, and the lr scheduler is inherited from the DETR d2 version.

Does AnchorDETR differ much from DETR in its lr scheduler?

tangjiuqi097 commented 3 years ago

@luohao123 Could you provide the log.txt file? I don't think the lr scheduler would have such a large effect.

luohao123 commented 3 years ago

log.txt — here is the log, resumed from iteration 70000.

tangjiuqi097 commented 3 years ago

@luohao123 Hi, there are some hyper-parameters that you should pay attention to, i.e., DROPOUT 0.0 and DIM_FEEDFORWARD 1024. You may also scale up BASE_LR by a factor in [1, 7], e.g. BASE_LR 0.00025, since you use 7 times the batch size. But I don't think these hyper-parameters are the reason your performance is so far from our code.
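
As a minimal sketch of those overrides in the d2-style config used in this thread (the cfg keys are assumed from the snippets here, not verified against the repo):

def apply_suggested_hparams(cfg):
    # Hypothetical helper; keys follow the Detectron2-style config in this thread.
    cfg.MODEL.DETR.DROPOUT = 0.0            # AnchorDETR setting, vs 0.1 in DETR
    cfg.MODEL.DETR.DIM_FEEDFORWARD = 1024   # AnchorDETR setting, vs 2048 in DETR
    cfg.SOLVER.BASE_LR = 0.00025            # scaled for ~7x the reference batch size
    return cfg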

I notice that you have modified the number of predicted categories from 91 to 81. [screenshot] That's OK, but you should make sure the num_class in SetCriterion has the same value. Otherwise it will cause a bug that leads to significant performance degradation. I am not sure whether you have hit this problem, as the default NUM_CLASS in detr is 80 and I cannot find that setting in the log file you provided.

luohao123 commented 3 years ago

@tangjiuqi097 Thank you for your analysis! Here is more info:

  1. Some params I used (all the same as DETR, plus some params needed by AnchorDETR):
cfg.MODEL.DETR.NHEADS = 8
cfg.MODEL.DETR.DROPOUT = 0.1
cfg.MODEL.DETR.DIM_FEEDFORWARD = 2048
cfg.MODEL.DETR.ENC_LAYERS = 6
cfg.MODEL.DETR.DEC_LAYERS = 6
cfg.MODEL.DETR.PRE_NORM = False
cfg.MODEL.DETR.HIDDEN_DIM = 256
cfg.MODEL.DETR.NUM_OBJECT_QUERIES = 100
cfg.MODEL.DETR.FROZEN_WEIGHTS = ''
cfg.MODEL.DETR.NUM_FEATURE_LEVELS = 1  # can also be 3
# for AnchorDETR
cfg.MODEL.DETR.NUM_QUERY_POSITION = 300
cfg.MODEL.DETR.NUM_QUERY_PATTERN = 3
cfg.MODEL.DETR.SPATIAL_PRIOR = 'learned'

Should DIM_FEEDFORWARD be changed from 2048 to 1024, and DROPOUT from 0.1 to 0? The batch size is the same as for DETR, at most 64 across 8 RTX 3090 GPUs.

  2. I changed the classes in SetCriterion as well:
self.num_classes = cfg.MODEL.DETR.NUM_CLASSES
....
transformer = Transformer(
            num_classes=self.num_classes+1,
            d_model=hidden_dim,
            dropout=dropout,
            nhead=nheads,
            num_feature_levels=num_feature_levels,
            dim_feedforward=dim_feedforward,
            num_encoder_layers=enc_layers,
            num_decoder_layers=dec_layers,
            activation="relu",
            num_query_position=num_query_position,
            num_query_pattern=num_query_pattern,
            spatial_prior=spatial_prior,
        )
....
self.criterion = SetCriterion(
            self.num_classes,
            matcher=matcher,
            weight_dict=weight_dict,
            eos_coef=no_object_weight,
            losses=losses,
        )

The rest of the code remains unchanged; I copied the Transformer part directly from AnchorDETR, and the only change is that num_classes is passed in as a parameter.

Is this the main reason? I train DETR with 80 classes as well, and it works OK.

tangjiuqi097 commented 3 years ago

@luohao123 Hi, the num_class in Transformer and SetCriterion should have the same value, i.e.,

self.criterion = SetCriterion(
            self.num_classes+1,
            matcher=matcher,
            weight_dict=weight_dict,
            eos_coef=no_object_weight,
            losses=losses,
        )

I think the performance will be much better if you fix it.

luohao123 commented 3 years ago

@tangjiuqi097 Hi, it seems that with the +1 added it can still train; I am trying it now. Why does one class fewer affect the result? Does it mean the background is wrongly treated as one of the real classes?

tangjiuqi097 commented 3 years ago

@luohao123 Hi, you can refer to these lines of code. The one-hot target of the focal loss should not include the background class.
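
As a minimal sketch of what that looks like (illustrative names, following the Deformable-DETR-style focal-loss target; see the linked code for the actual implementation):

import torch

def build_focal_onehot(src_logits, target_classes):
    """Build the one-hot classification target for the sigmoid focal loss.

    src_logits:     [B, N, num_classes] predicted class logits (no background channel)
    target_classes: [B, N] integer labels, with unmatched queries set to num_classes
    """
    B, N, num_classes = src_logits.shape
    onehot = torch.zeros(B, N, num_classes + 1,
                         dtype=src_logits.dtype, device=src_logits.device)
    onehot.scatter_(2, target_classes.unsqueeze(-1), 1)
    # Drop the extra background channel: background queries become all-zero rows,
    # so the focal loss never sees a dedicated "no-object" class.
    return onehot[:, :, :-1]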

luohao123 commented 3 years ago

@tangjiuqi097 Hi, I added num_classes + 1, and now it doesn't converge:

[screenshot]

Normally the mAP should be around 5 after the first 10000 iterations.

tangjiuqi097 commented 3 years ago

@luohao123 Hi, could you provide more details? I cannot find the reason based on your message.

luohao123 commented 3 years ago

@tangjiuqi097 We found that your SetCriterion is not the same as the original DETR one; I used the original one, so the problem might be caused by the misalignment between them.

In DETR, the definition is:

self.criterion = SetCriterion(
            self.num_classes, matcher=matcher, weight_dict=weight_dict, eos_coef=no_object_weight, losses=losses,
        )

In AnchorDETR:

criterion = SetCriterion(num_classes, matcher, weight_dict, losses, focal_alpha=args.focal_alpha)

It seems some things were simplified.

However, when I removed the original SetCriterion and used exactly the same SetCriterion as AnchorDETR, I found the loss quite abnormal:

iter: 4  total_loss: 5491  loss_ce: 1.131  loss_bbox: 2.05  loss_giou: 1.639  loss_ce_0: 1.022  loss_bbox_0: 2.05  loss_giou_0: 1.639  loss_ce_1: 1.045  loss_bbox_1: 2.05  loss_giou_1: 1.639  loss_ce_2: 1.074  loss_bbox_2: 2.05  loss_giou_2: 1.639  loss_ce_3: 1.1  loss_bbox_3: 2.05  loss_giou_3: 1.639  loss_ce_4: 1.12  loss_bbox_4: 2.05  loss_giou_4: 1.639  time: 0.8473  data_time: 2.5367  lr: 0.0001  max_mem: 10269M

The total_loss is very high at the beginning (it should be about 139 normally). And when I logged the full loss dictionary, I found:

{'loss_ce': tensor(1.0737, device='cuda:1', grad_fn=<MulBackward0>), 'class_error': tensor(100., device='cuda:1'), 'loss_bbox': tensor(1.7441, device='cuda:1', grad_fn=<MulBackward0>), 'loss_giou': tensor(1.7019, device='cuda:1', grad_fn=<MulBackward0>), 'cardinality_error': tensor(893.3000, device='cuda:1'), 'loss_ce_0': tensor(1.0307, device='cuda:1', grad_fn=<MulBackward0>), 'loss_bbox_0': tensor(1.7441, device='cuda:1', grad_fn=<MulBackward0>), 'loss_giou_0': tensor(1.7019, device='cuda:1', grad_fn=<MulBackward0>), 'cardinality_error_0': tensor(893.3000, device='cuda:1'), 'loss_ce_1': tensor(1.0356, device='cuda:1', grad_fn=<MulBackward0>), 'loss_bbox_1': tensor(1.7441, device='cuda:1', grad_fn=<MulBackward0>), 'loss_giou_1': tensor(1.7019, device='cuda:1', grad_fn=<MulBackward0>), 'cardinality_error_1': tensor(893.3000, device='cuda:1'), 'loss_ce_2': tensor(1.0469, device='cuda:1', grad_fn=<MulBackward0>), 'loss_bbox_2': tensor(1.7441, device='cuda:1', grad_fn=<MulBackward0>), 'loss_giou_2': tensor(1.7019, device='cuda:1', grad_fn=<MulBackward0>), 'cardinality_error_2': tensor(893.3000, device='cuda:1'), 'loss_ce_3': tensor(1.0575, device='cuda:1', grad_fn=<MulBackward0>), 'loss_bbox_3': tensor(1.7441, device='cuda:1', grad_fn=<MulBackward0>), 'loss_giou_3': tensor(1.7019, device='cuda:1', grad_fn=<MulBackward0>), 'cardinality_error_3': tensor(893.3000, device='cuda:1'), 'loss_ce_4': tensor(1.0655, device='cuda:1', grad_fn=<MulBackward0>), 'loss_bbox_4': tensor(1.7441, device='cuda:1', grad_fn=<MulBackward0>), 'loss_giou_4': tensor(1.7019, device='cuda:1', grad_fn=<MulBackward0>), 'cardinality_error_4': tensor(893.3000, device='cuda:1')}

The cardinality_error values are very high...

I think cardinality_error is not included in backpropagation, but this abnormal value makes total_loss very high.

And I believe it is caused by a class misalignment (class number) when calculating cardinality_error:

    @torch.no_grad()
    def loss_cardinality(self, outputs, targets, indices, num_boxes):
        """ Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes
        This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients
        """
        pred_logits = outputs['pred_logits']
        device = pred_logits.device
        tgt_lengths = torch.as_tensor([len(v["labels"]) for v in targets], device=device)
        # Count the number of predictions that are NOT "no-object" (which is the last class)
        card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
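        # NOTE: with a focal-loss classifier there is no separate "no-object" logit, so this
        # comparison against the last channel (written for DETR's softmax head) counts nearly
        # every query as an object; the error then approaches num_queries minus the GT count.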
        card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
        losses = {'cardinality_error': card_err}
        return losses

What could be wrong?

tangjiuqi097 commented 3 years ago

@luohao123 Yes, class_error and cardinality_error are not included in backpropagation, so they will not affect performance. This code is inherited from Deformable DETR, and I do not pay much attention to these two values as they are not important. You can remove them to find the true total_loss. For example, you can modify these lines to:

        if self.training:
            gt_instances = [x["instances"].to(self.device) for x in batched_inputs]

            targets = self.prepare_targets(gt_instances)
            loss_dict = self.criterion(output, targets)
            weight_dict = self.criterion.weight_dict
            valid_loss_dict = {}
            for k in loss_dict.keys():
                if k in weight_dict:
                    valid_loss_dict[k] = loss_dict[k] * weight_dict[k]
            return valid_loss_dict

The total loss will be about 30 after 100 iterations.

I think you can try the code with the updated SetCriterion. BTW, note that the weight for the focal (classification) loss is 2 here, not 1 as in DETR.
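
For reference, a sketch of what such a weight_dict looks like (illustrative values only: loss_ce 2 as noted above, loss_bbox 5 and loss_giou 2 as in the Deformable-DETR-style defaults; not copied verbatim from the repo):

weight_dict = {
    "loss_ce": 2,     # focal/classification loss weight (2 here, vs 1 in DETR)
    "loss_bbox": 5,   # L1 box loss weight
    "loss_giou": 2,   # GIoU loss weight
}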

luohao123 commented 3 years ago

@tangjiuqi097 Thanks for your advice. I omitted the cardinality error and class error,

and now the question becomes... the loss seems too small...

[11/02 16:27:03 d2.utils.events]:  eta: 4 days, 22:24:55  iter: 99  total_loss: 19.9  loss_ce: 0.8344  loss_bbox: 1.132  loss_giou: 1.345  loss_ce_0: 0.8952  loss_bbox_0: 1.115  loss_giou_0: 1.364  loss_ce_1: 0.8575  loss_bbox_1: 1.107  loss_giou_1: 1.356  loss_ce_2: 0.8427  loss_bbox_2: 1.109  loss_giou_2: 1.348  loss_ce_3: 0.8363  loss_bbox_3: 1.11  loss_giou_3: 1.344  loss_ce_4: 0.8341  loss_bbox_4: 1.112  loss_giou_4: 1.341  time: 0.7858  data_time: 0.1661  lr: 0.0001  max_mem: 11866M
[11/02 16:27:19 d2.utils.events]:  eta: 4 days, 22:44:34  iter: 119  total_loss: 19.38  loss_ce: 0.824  loss_bbox: 1.075  loss_giou: 1.327  loss_ce_0: 0.8514  loss_bbox_0: 1.074  loss_giou_0: 1.342  loss_ce_1: 0.8272  loss_bbox_1: 1.055  loss_giou_1: 1.331  loss_ce_2: 0.8197  loss_bbox_2: 1.055  loss_giou_2: 1.326  loss_ce_3: 0.8175  loss_bbox_3: 1.058  loss_giou_3: 1.327  loss_ce_4: 0.8197  loss_bbox_4: 1.064  loss_giou_4: 1.326  time: 0.7845  data_time: 0.1615  lr: 0.0001  max_mem: 11866M
[11/02 16:27:35 d2.utils.events]:  eta: 4 days, 22:47:06  iter: 139  total_loss: 18.88  loss_ce: 0.7957  loss_bbox: 1.029  loss_giou: 1.322  loss_ce_0: 0.8278  loss_bbox_0: 1.014  loss_giou_0: 1.333  loss_ce_1: 0.8019  loss_bbox_1: 1.01  loss_giou_1: 1.323  loss_ce_2: 0.7921  loss_bbox_2: 1.005  loss_giou_2: 1.322  loss_ce_3: 0.7914  loss_bbox_3: 1.012  loss_giou_3: 1.323  loss_ce_4: 0.7943  loss_bbox_4: 1.021  loss_giou_4: 1.32  time: 0.7844  data_time: 0.1710  lr: 0.0001  max_mem: 11866M

Now the loss is about 19 at iter 100.

I am using lr 0.0001 on 8 GPUs with batch size 80 and fp16 training.

What do loss_ce_1, loss_ce_2, etc. mean, by the way?

And why is the cardinality error so high, since it is just another view of the classification error, isn't it?

tangjiuqi097 commented 3 years ago

@luohao123 Hi, you can check the loss weights again; I am not sure whether they are normal in your setting.

loss too small

The cardinality error doesn't make sense for the focal loss, and I will remove it in the near future.

luohao123 commented 3 years ago

@tangjiuqi097 Here are my loss weight settings:

weight_dict = {"loss_ce": 1, "loss_bbox": l1_weight}
        weight_dict["loss_giou"] = giou_weight
        if deep_supervision:
            aux_weight_dict = {}
            for i in range(dec_layers - 1):
                aux_weight_dict.update(
                    {k + f"_{i}": v
                     for k, v in weight_dict.items()})
            weight_dict.update(aux_weight_dict)
        losses = ["labels", "boxes", "cardinality"]
        if self.mask_on:
            losses += ["masks"]
        self.criterion = SetCriterion(
            self.num_classes+1,
            matcher=matcher,
            weight_dict=weight_dict,
            # eos_coef=no_object_weight,
            losses=losses,
        )

where the GIoU and L1 weights are:

GIOU_WEIGHT: 2.0
L1_WEIGHT: 5.0

Are there any differences from AnchorDETR here?

luohao123 commented 3 years ago

@tangjiuqi097 PS:

I just changed the loss_ce weight from 1 to 2, and now the loss looks like this:

[11/02 17:23:51 d2.utils.events]:  eta: 4 days, 22:02:07  iter: 99  total_loss: 24.36  loss_ce: 1.623  loss_bbox: 1.09  loss_giou: 1.336  loss_ce_0: 1.686  loss_bbox_0: 1.077  loss_giou_0: 1.358  loss_ce_1: 1.62  loss_bbox_1: 1.072  loss_giou_1: 1.361  loss_ce_2: 1.6  loss_bbox_2: 1.067  loss_giou_2: 1.354  loss_ce_3: 1.597  loss_bbox_3: 1.067  loss_giou_3: 1.347  loss_ce_4: 1.606  loss_bbox_4: 1.077  loss_giou_4: 1.338  time: 0.7693  data_time: 0.1579  lr: 0.0001  max_mem: 11724M
[11/02 17:24:08 d2.utils.events]:  eta: 4 days, 21:20:20  iter: 119  total_loss: 23.48  loss_ce: 1.578  loss_bbox: 1.009  loss_giou: 1.313  loss_ce_0: 1.625  loss_bbox_0: 1.036  loss_giou_0: 1.35  loss_ce_1: 1.574  loss_bbox_1: 1.018  loss_giou_1: 1.345  loss_ce_2: 1.57  loss_bbox_2: 1.001  loss_giou_2: 1.331  loss_ce_3: 1.566  loss_bbox_3: 1.001  loss_giou_3: 1.322  loss_ce_4: 1.567  loss_bbox_4: 1.002  loss_giou_4: 1.315  time: 0.7781  data_time: 0.1539  lr: 0.0001  max_mem: 11724M
[11/02 17:24:23 d2.utils.events]:  eta: 4 days, 21:44:43  iter: 139  total_loss: 23.14  loss_ce: 1.527  loss_bbox: 0.9891  loss_giou: 1.33  loss_ce_0: 1.583  loss_bbox_0: 0.9865  loss_giou_0: 1.356  loss_ce_1: 1.534  loss_bbox_1: 0.9735  loss_giou_1: 1.339  loss_ce_2: 1.523  loss_bbox_2: 0.9813  loss_giou_2: 1.334  loss_ce_3: 1.522  loss_bbox_3: 0.9808  loss_giou_3: 1.335  loss_ce_4: 1.523  loss_bbox_4: 0.9787  loss_giou_4: 1.335  time: 0.7777  data_time: 0.1524  lr: 0.0001  max_mem: 11724M

Is that normal?

luohao123 commented 3 years ago

@tangjiuqi097 It worked! The convergence speed is so fast: after one night of training we reach 24 AP, while DETR only gets roughly 10 in the same time.

[screenshot]

github-actions[bot] commented 2 years ago

This issue is not active for a long time and it will be closed in 5 days. Feel free to re-open it if you have further concerns.