IDEA-Research / DN-DETR

[CVPR 2022 Oral] Official implementation of DN-DETR
Apache License 2.0

util/misc.py collate_fn function #22

Open qq747688898 opened 2 years ago

qq747688898 commented 2 years ago

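  def collate_fn(batch):
      import ipdb; ipdb.set_trace()
      batch = list(zip(*batch))
      batch[0] = nested_tensor_from_tensor_list(batch[0])
      return tuple(batch)

Here I see that nested_tensor_from_tensor_list pads batch[0], i.e. the training images of each batch, to the largest size in the batch so that all images in the batch have the same size. But don't the boxes need to be corrected here as well? It looks like the boxes' cx, cy, w, h still use the image coordinates from before the padding.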

FengLi-ust commented 2 years ago

Hey, box prediction is supposed to be based on the images before batch padding. Therefore, we only need to modify the boxes during augmentation. Thank you.
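To illustrate why padding alone does not move the boxes, here is a minimal pure-Python sketch of batch padding. It is a simplified stand-in, not the repository's nested_tensor_from_tensor_list: the name pad_batch is hypothetical, images are single-channel nested lists, and the sketch assumes the padding is anchored at the top-left corner with a mask that marks padded pixels as True, as in DETR-style helpers.

```python
def pad_batch(images):
    """Pad H x W nested-list images to the largest height and width in the batch.

    Hypothetical helper for illustration only. Each image is copied into the
    top-left corner of a zero canvas; the mask is True where a pixel is padding.
    """
    max_h = max(len(img) for img in images)
    max_w = max(len(img[0]) for img in images)
    padded, masks = [], []
    for img in images:
        h, w = len(img), len(img[0])
        canvas = [[0] * max_w for _ in range(max_h)]
        mask = [[True] * max_w for _ in range(max_h)]
        for y in range(h):
            for x in range(w):
                canvas[y][x] = img[y][x]  # pixel keeps its original (y, x) position
                mask[y][x] = False        # real image content, not padding
        padded.append(canvas)
        masks.append(mask)
    return padded, masks


images = [[[1, 2], [3, 4]], [[5, 6, 7]]]
padded, masks = pad_batch(images)
```

Because every pixel keeps its (y, x) position, absolute pixel coordinates of the boxes are unchanged by the padding. Only the conversion from normalized coordinates has to use the unpadded height and width.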

qq747688898 commented 2 years ago

I drew the target boxes on the sample and got this image. Is this normal?

[Image: 企业微信20220623-174457@2x (screenshot of the sample with the target boxes drawn on it)]

The visualization code is:

  dataset_train = build_dataset(image_set='train', args=args)
  sampler_train = DynamicDetectorSampler(dataset_train, batch_size=args.batch_size)

  batch_sampler_train = torch.utils.data.BatchSampler(sampler_train, 2, drop_last=True)

  data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train,
                                  collate_fn=utils.collate_fn, num_workers=args.num_workers)

  for samples, targets in data_loader_train:
    for i in range(len(targets)):
      mean = np.array([[[0.485, 0.456, 0.406]]])
      std  = np.array([[[0.229, 0.224, 0.225]]])
      img = np.transpose(samples.tensors[i].numpy(), (1, 2, 0)) * std + mean
      img = np.uint8(img * 255)
      img = np.stack([img[:, :, 2], img[:, :, 1], img[:, :, 0]], 2)
      h, w, c = img.shape
      boxes = targets[i]["boxes"].numpy().tolist()
      for box in boxes:
        cx, cy, width, height = box
        cx     = int(cx * w)
        width  = int(width * w)
        cy     = int(cy * h)
        height = int(height * h)
        img = cv2.rectangle(img, (cx - width // 2, cy - height // 2), (cx + width // 2, cy + height // 2), (0, 0, 255), 4)

      cv2.imshow("test", img)
      cv2.waitKey()

FengLi-ust commented 2 years ago

Hey, you can refer to the PostProcess class to see how post-processing is done. For example, we post-process with img_h, img_w = target_sizes.unbind(1), where img_h and img_w are the original sizes before padding, taken from the target.
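As a rough sketch of what this means for the visualization above, the normalized (cx, cy, w, h) boxes can be scaled by the unpadded image size instead of the padded tensor size. The function name below is hypothetical and the code is pure Python for illustration; it is not taken from the repository.

```python
def cxcywh_norm_to_xyxy_pixels(box, img_h, img_w):
    """Convert a normalized (cx, cy, w, h) box to integer pixel corners.

    img_h and img_w must be the size of the image BEFORE batch padding,
    because that is the size the boxes were normalized by.
    """
    cx, cy, w, h = box
    x0 = (cx - 0.5 * w) * img_w
    y0 = (cy - 0.5 * h) * img_h
    x1 = (cx + 0.5 * w) * img_w
    y1 = (cy + 0.5 * h) * img_h
    return tuple(int(round(v)) for v in (x0, y0, x1, y1))


# A centered box covering half the width and half the height of a 100 x 200 image.
corners = cxcywh_norm_to_xyxy_pixels((0.5, 0.5, 0.5, 0.5), img_h=100, img_w=200)
```

In the drawing loop, the height and width would then come from the target rather than from the shape of the padded image, for example from targets[i]["size"], assuming the dataset stores the post-augmentation height and width under that key as DETR-style datasets do. Since the padding is added on the bottom and right, the resulting corners can be drawn directly on the padded image.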