shaunyuan22 / SODA-mmrotate

SODA-A Small Object Detection Toolbox and Benchmark
https://shaunyuan22.github.io/SODA/
Apache License 2.0

Changing some code can speed up evaluation at test time. #7

Closed Songyeyaosong closed 1 year ago

Songyeyaosong commented 1 year ago

What's the feature?

Hi, I was having problems with very long test times, so I changed some code to solve this. Hope this is useful to anyone facing the same problem.

In mmrotate/datasets/sodaa.py, lines 520 to 522:

bboxes, scores = cls_dets[:, :-1], cls_dets[:, [-1]]
bboxes = torch.from_numpy(bboxes).to(torch.float32).contiguous()
scores = torch.from_numpy(np.squeeze(scores, 1)).to(torch.float32).contiguous()

I changed it to this:

bboxes, scores = cls_dets[:, :-1], cls_dets[:, [-1]]  # this split line stays the same
bboxes = torch.from_numpy(bboxes).to(torch.float32).contiguous()
scores = torch.from_numpy(np.squeeze(scores, 1)).to(torch.float32).contiguous()
# move both tensors to the GPU so nms_rotated runs its CUDA kernel
bboxes = bboxes.cuda()
scores = scores.cuda()
results, inds = nms_rotated(bboxes, scores, iou_thr)

This uses the GPU to speed up the merging step.
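
If the machine running the test may not always have a GPU, a small guard keeps a CPU fallback instead of failing on .cuda(). This is only a sketch that reuses bboxes, scores and iou_thr from the snippet above and assumes nms_rotated is the mmcv.ops operator:

import torch
from mmcv.ops import nms_rotated

# run rotated NMS on the GPU when available, otherwise stay on the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
bboxes, scores = bboxes.to(device), scores.to(device)
results, inds = nms_rotated(bboxes, scores, iou_thr)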

In mmrotate/datasets/sodaa_eval/sodaa_eval.py, lines 262 to 265:

ious = box_iou_rotated(
    torch.from_numpy(np.array(d)).float(),
    torch.from_numpy(np.array(g)).float()).numpy()
return ious

I changed it to this:

ious = box_iou_rotated(
    torch.from_numpy(np.array(d)).float().cuda(),
    torch.from_numpy(np.array(g)).float().cuda()).cpu().numpy()
return ious

This also uses the GPU to speed up the IoU computation.
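
The same kind of guard works here and also gives you a CPU fallback if the CUDA path misbehaves in your environment. A sketch, assuming d and g are the detection and ground-truth box lists from the surrounding computeIoU code:

import numpy as np
import torch
from mmcv.ops import box_iou_rotated

# rotated IoU on the GPU when available, otherwise on the CPU
d_t = torch.from_numpy(np.array(d)).float()
g_t = torch.from_numpy(np.array(g)).float()
if torch.cuda.is_available():
    ious = box_iou_rotated(d_t.cuda(), g_t.cuda()).cpu().numpy()
else:
    ious = box_iou_rotated(d_t, g_t).numpy()
return ious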

And if you want to make it even faster, you can change the metric settings, such as iouThrs and areaRng, so that only the metrics you actually need are evaluated.
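
A minimal sketch of what such a restriction could look like, assuming the evaluator object (here called evaluator, a hypothetical name) exposes COCO-style params; the exact attribute names in sodaa_eval.py may differ:

import numpy as np

# hypothetical example: evaluate only IoU=0.50/0.75 and a single "all areas" range
evaluator.params.iouThrs = np.array([0.50, 0.75])
evaluator.params.areaRng = [[0, 1000000]]
evaluator.params.areaRngLbl = ['all']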

For example, if you set areaRng to [[0, 1000000]], all areas are evaluated as a single range. In that case you can change the evaluateImg function to:

def evaluateImg(self, imgId, catId, aRng, maxDet):

    p = self.params
    if p.useCats:
        gt = self._gts[imgId, catId]
        dt = self._dts[imgId, catId]
    else:
        gt = [_ for cId in p.catIds for _ in self._gts[imgId, cId]]
        dt = [_ for cId in p.catIds for _ in self._dts[imgId, cId]]

    have_gt = len(gt) > 0
    have_dt = len(dt) > 0
    if not have_gt and not have_dt:
        return None

    gt_df = pd.DataFrame(gt)
    # NOTE: this assumes dt (and the rows of self.ious) are already ordered by descending score
    dt_df = pd.DataFrame(dt).head(maxDet)
    if have_gt:
        gt_df['_ignore'] = (gt_df['ignore'] | ((gt_df['area'] < aRng[0]) | (gt_df['area'] > aRng[1]))).astype(int)

    ious = self.ious[imgId, catId]
    ious = ious[:maxDet, :].copy() if len(ious) > 0 else ious

    T = len(p.iouThrs)
    G = len(gt_df)
    D = len(dt_df)
    gtm = np.zeros((T, G))
    dtm = np.zeros((T, D))
    gtIg = np.array(gt_df['_ignore']) if have_gt else np.array([], dtype=np.int64)
    dtIg = np.zeros((T, D))

    if len(ious) > 0:
        for tind, t in enumerate(p.iouThrs):

            if have_dt and have_gt:
                # for every detection, take the gt with the highest IoU and keep it only if IoU >= t
                iou_max, inx = np.max(ious, axis=1), np.argmax(ious, axis=1)
                iou_mask = (iou_max >= t).astype(int)
                dtm[tind, :] = gt_df['id'][inx] * iou_mask

                # keep only the first (i.e. highest-scoring) detection matched to each gt id
                dt_len = len(dtm[tind, :])
                unique_elements, first_indices = np.unique(dtm[tind, :], return_index=True)
                positive_indices = first_indices[unique_elements > 0]
                first_mask = np.zeros(dt_len, dtype=int)
                first_mask[positive_indices] = 1
                dtm[tind, :] = dtm[tind, :] * first_mask

                # record the reverse mapping: which detection id each matched gt was assigned
                matched_dt_inx = np.where(dtm[tind, :] > 0)[0]
                matches_gt_inx = inx[matched_dt_inx]
                gtm[tind, matches_gt_inx] = dt_df['id'][matched_dt_inx]

    dtIds = dt_df['id'].to_list() if have_dt else []
    gtIds = gt_df['id'].to_list() if have_gt else []
    dtScores = dt_df['score'].to_list() if have_dt else []

    return {
        'image_id': imgId,
        'category_id': catId,
        'aRng': aRng,
        'maxDet': maxDet,
        'dtIds': dtIds,
        'gtIds': gtIds,
        'dtMatches': dtm,
        'gtMatches': gtm,
        'dtScores': dtScores,
        'gtIgnore': gtIg,
        'dtIgnore': dtIg,
    }

Because with this setting every bbox and GT is included in the evaluation, you can discard the inner for loop and use numpy and pandas to vectorize it, which speeds up the evaluation.
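
The core of the vectorized matching above is the np.unique(return_index=True) trick that keeps only the first detection matched to each ground-truth id (which is the highest-scoring one when detections are ordered by descending score). A tiny standalone illustration:

import numpy as np

# plays the role of dtm[tind, :]: per-detection matched gt id, 0 = no match
matched_gt_ids = np.array([0, 7, 7, 3, 0, 3])
unique_ids, first_idx = np.unique(matched_gt_ids, return_index=True)
keep = np.zeros(len(matched_gt_ids), dtype=int)
keep[first_idx[unique_ids > 0]] = 1          # keep the first occurrence of each non-zero id
print(matched_gt_ids * keep)                 # -> [0 7 0 3 0 0]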

I also found that using nproc=1 is the fastest way to evaluate the results; with a single process the evaluation finishes in a very short time, around 10 seconds.

But if you do want to evaluate with multiple processes, note that the original code does not parallelize the IoU computation, so you can make the following changes.

In mmrotate/datasets/sodaa_eval/sodaa_eval.py, lines 179 to 180:

self.ious = {(imgId, catId): computeIoU(imgId, catId)
             for imgId in p.imgIds for catId in catIds}

If you want to use multiple processes for evaluation, you can change it to this:

# needs: from functools import partial and from multiprocessing import Pool at the top of the file
if self.nproc > 1:
    computeIoU = partial(self.computeIoU)

    # build the (imgId, catId) work list and split it into the two argument lists
    img_cat_lst = [[imgId, catId] for imgId in p.imgIds for catId in catIds]
    imgId_lst, catId_lst = [], []
    for lst in img_cat_lst:
        imgId_lst.append(lst[0])
        catId_lst.append(lst[1])

    # compute the IoU matrices in parallel; starmap blocks until all results are ready
    pool = Pool(self.nproc)
    iou_lst = pool.starmap(computeIoU, zip(imgId_lst, catId_lst))
    pool.close()

    for img_cat, iou in zip(img_cat_lst, iou_lst):
        self.ious.update({tuple(img_cat): iou})
else:
    self.ious = {(imgId, catId): computeIoU(imgId, catId)
                 for imgId in p.imgIds for catId in catIds}
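
As a side note, the parallel branch can be written a bit more compactly with a context-managed pool. This is only a sketch that reuses self.nproc, self.computeIoU, p.imgIds and catIds from above, and like any multiprocessing code it should run from a __main__-guarded entry point on platforms that use the spawn start method:

from multiprocessing import Pool

# same parallel IoU computation as above, with the pool closed automatically
img_cat_lst = [(imgId, catId) for imgId in p.imgIds for catId in catIds]
with Pool(self.nproc) as pool:
    iou_lst = pool.starmap(self.computeIoU, img_cat_lst)
self.ious = dict(zip(img_cat_lst, iou_lst))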

Hope this will be helpful for you :)

Any other context?

No response

shaunyuan22 commented 1 year ago

We greatly appreciate your insightful recommendations for improving the project's code. We will incorporate your suggestions and update the code to improve the speed and stability of the testing procedure. Once again, thank you for your constructive feedback.

HuQ1an commented 10 months ago

@Songyeyaosong hello! Thanks for providing the code to speed up the test phase. However, when I tried to follow your change for speeding up the IoU computation, I got this error:

File "/home/anaconda3/envs/openmmlab/lib/python3.9/site-packages/mmcv/ops/box_iou_rotated.py", line 144, in box_iou_rotated
    ext_module.box_iou_rotated(
RuntimeError: CUDA error: invalid configuration argument

Did you encounter this issue before?

This error happens after replacing

ious = box_iou_rotated(
    torch.from_numpy(np.array(d)).float(),
    torch.from_numpy(np.array(g)).float()).numpy()
return ious

with

ious = box_iou_rotated(
    torch.from_numpy(np.array(d)).float().cuda(),
    torch.from_numpy(np.array(g)).float().cuda()).cpu().numpy()
return ious

Songyeyaosong commented 10 months ago

Hello! I did not encounter this issue before. Maybe you can try my Python and PyTorch versions: Python 3.8 and PyTorch 1.17.1+cu117 :). I'm sorry I can't provide more help on this issue.

Christy99cc commented 9 months ago

@Songyeyaosong Hello! I have searched the PyTorch official website (https://pytorch.org/get-started/previous-versions/) for the specific version you mentioned, but unfortunately I couldn't find it. If you have any alternative suggestions or sources where I might locate this version, please let me know. Your assistance in this matter is greatly appreciated. Thank you!

Songyeyaosong commented 7 months ago

try this: pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1+cu117 -f https://download.pytorch.org/whl/torch_stable.html

ShiZican commented 3 months ago

Hello, thank you for your code! When I used it, I found that the evaluation results were quite different from those reported in the SODA paper. Does this happen to you? For example, when I run Oriented R-CNN I get APs {'mAP_AP': 0.398, 'mAP_AP_50': 0.813, 'mAP_AP_75': 0.324, 'mAP_AP_eS': 0.635, 'mAP_AP_rS': 0.59, 'mAP_AP_gS': 0.555, 'mAP_AP_Normal': 0.636, 'mAP_mAP_copypaste': '0.398 0.813 0.324 0.635 0.590 0.555 0.636 '}, but the SODA paper reports APs {34.4, 70.7, 28.6, 12.5, 28.6, 44.5, 36.7}.

liuaj22 commented 2 months ago

Hi, did you apply any other tricks, such as multi-scale training/testing and data augmentation, to get a mAP of 39.8? And why did you get better performance for eS objects than gS objects?

ShiZican commented 2 months ago

No, I just ran the Oriented R-CNN provided in the SODA code and changed the evaluation code to what "Songyeyaosong" provided in this issue. I ran into this problem and am asking for help.