open-mmlab / mmdetection

OpenMMLab Detection Toolbox and Benchmark
https://mmdetection.readthedocs.io
Apache License 2.0

dist_train keep waiting when filter_empty_gt=False #2193

panjianning closed this issue 4 years ago.

panjianning commented 4 years ago

Environment:

My config file: cascade_rcnn_dconv_c3-c5_r50_fpn_1x.py, with only the annotation path modified.

  1. When I set filter_empty_gt=False (20% of my images are background only), there are no error messages, but the training process just keeps waiting, so I have to stop it with a KeyboardInterrupt:

[screenshot]

  2. When I set filter_empty_gt=True, everything is OK.
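For reference, filter_empty_gt is a dataset argument, so in a config it sits inside the training dataset dict. A minimal sketch, assuming the mmdetection v1.x-style config layout of this period; the data_root and annotation paths are placeholders, and required fields such as the data pipeline are omitted:

```python
# Hypothetical excerpt of an mmdetection v1.x-style config. Paths are placeholders
# and required fields such as the data pipeline are omitted for brevity.
dataset_type = 'CocoDataset'
data_root = 'data/my_dataset/'  # placeholder
data = dict(
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/train.json',  # placeholder
        img_prefix=data_root + 'train/',  # placeholder
        # False keeps images that have no ground-truth boxes
        # (background-only images) in the training set.
        filter_empty_gt=False,
    ),
)
```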
ZwwWayne commented 4 years ago

Hi @PanJianning, just to confirm the problem: you are training cascade_rcnn_dconv_c3-c5_r50_fpn_1x.py on your own dataset, and you hit the bug when setting filter_empty_gt=False? Did you implement logic to handle empty GT in your custom dataset? The bug might come from your implementation.

panjianning commented 4 years ago

Thanks for your reply. I use CocoDataset and changed nothing except the annotation path. BTW, single-GPU training works fine with filter_empty_gt=False.

panjianning commented 4 years ago

@ZwwWayne I printed gt_labels in the forward_train method of the Cascade R-CNN detector and noticed that training hangs when a batch contains only background images.

[screenshot]

panjianning commented 4 years ago

Update

My NCCL version is 2.5.7. After setting the following environment variable, the deadlock disappeared, but I got a negative loss...

export NCCL_LL_THRESHOLD=0

@ZwwWayne @hellock When a batch contains only background images, the losses returned by Cascade R-CNN are missing the keys s_i.bbox_loss (i = 0, 1, ..., num_stages - 1). This is expected, since bbox regression is impossible without gt_bboxes, but it makes my distributed training process hang.
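Why would missing keys hang rather than crash? Collective calls such as torch.distributed.all_reduce are paired across ranks by the order in which each rank issues them, not by the loss name attached to them. The pure-Python sketch below is a simplified model of that pairing (it does not call torch.distributed, and the loss names are illustrative): a rank with fewer loss keys issues fewer all_reduce calls, so its peers block on a call that never arrives.

```python
from typing import Dict, List, Tuple


def simulate_positional_all_reduce(
        calls_per_rank: Dict[int, List[str]]) -> Tuple[List[Tuple[str, ...]], List[int]]:
    """Model how ranks' collective calls pair up by issue order.

    calls_per_rank maps a rank to the loss names it all_reduces, in order.
    Returns (matched, blocked): the names paired at each call index (one entry
    per rank, ordered by rank), and the ranks left waiting on a call that no
    other rank ever issues.
    """
    ranks = sorted(calls_per_rank)
    completed = min(len(calls_per_rank[r]) for r in ranks)
    # Call i on every rank is paired with call i on every other rank,
    # regardless of which loss name each rank attached to it.
    matched = [tuple(calls_per_rank[r][i] for r in ranks) for i in range(completed)]
    # Any rank that issues more calls waits on its first extra one.
    blocked = [r for r in ranks if len(calls_per_rank[r]) > completed]
    return matched, blocked


# Illustrative names: rank 0's batch has GT boxes, rank 1's is background only.
matched, blocked = simulate_positional_all_reduce({
    0: ['loss_cls', 'loss_bbox', 'loss'],
    1: ['loss_cls', 'loss'],
})
print(matched)  # [('loss_cls', 'loss_cls'), ('loss_bbox', 'loss')]
print(blocked)  # [0]
```

Note that the second pair silently mixes 'loss_bbox' with 'loss' before rank 0 gets stuck. In real training, later collectives may pair with the leftover call instead, so the outcome can be a hang or corrupted values; the sketch only models the simple case.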

[screenshot]

After I added this line to cascade_rcnn.py, the deadlock disappeared. But will it mess up backpropagation?

[screenshot]

clw5180 commented 4 years ago

I also tried filter_empty_gt=False, but it doesn't seem to help. Hope you get a better score on the leaderboard. @PanJianning

panjianning commented 4 years ago

Maybe it has something to do with this issue and this line

clw5180 commented 4 years ago

I can add negative samples to training, and the training set definitely gets larger, but the mAP is just so-so. @PanJianning

panjianning commented 4 years ago

I only hit the issue with distributed training. Setting filter_empty_gt=False gave me a better score in the first round, so I always set it to False in the second round.

zoushun commented 4 years ago

@hellock When I set filter_empty_gt=False in the config file, distributed training failed with the following error from transforms.py: "need at least one array to stack". This is caused by Resize._resize_masks() and RandomFlip.__call__ when an image has no GT annotations. I modified Resize._resize_masks() as follows:

            # original code:
            # results[key] = np.stack(masks)
            if masks:
                results[key] = np.stack(masks)
            else:
                # No GT masks (background-only image): store an empty (0, H, W)
                # array. Assumes _resize_img() has already set results['img_shape']
                # to the resized shape, which covers keep_ratio=True and False.
                h, w = results['img_shape'][:2]
                results[key] = np.empty((0, h, w), dtype=np.uint8)

I also modified RandomFlip.__call__ as follows:

                # inside the loop over results['mask_fields'] (flipping masks)
                # original code:
                # results[key] = np.stack([
                #     mmcv.imflip(mask, direction=results['flip_direction'])
                #     for mask in results[key]
                # ])
                masks = [
                    mmcv.imflip(mask, direction=results['flip_direction'])
                    for mask in results[key]
                ]
                if len(masks) != 0:
                    results[key] = np.stack(masks)
                else:
                    # No GT masks: keep an empty array in (0, H, W) layout,
                    # matching the (N, H, W) shape of non-empty masks
                    h, w = results['img_shape'][:2]
                    results[key] = np.empty((0, h, w), dtype=np.uint8)
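Both patches apply the same rule, so they could share one helper. A minimal sketch with a hypothetical helper name, stack_masks (not part of mmdetection), assuming masks are (H, W) uint8 arrays and img_shape is the (H, W[, C]) shape of the already-transformed image:

```python
import numpy as np


def stack_masks(masks, img_shape):
    """Stack per-instance (H, W) masks into an (N, H, W) array.

    np.stack([]) raises "need at least one array to stack", so for a
    background-only image (no instances) return an empty (0, H, W) uint8
    array instead. img_shape is assumed to be (H, W) or (H, W, C).
    """
    if len(masks) > 0:
        return np.stack(masks)
    h, w = img_shape[:2]
    return np.empty((0, h, w), dtype=np.uint8)
```

Keeping the empty array in (0, H, W) rather than (0, W, H) layout means its trailing dimensions agree with those of non-empty masks, which may matter for any later transform that reads the mask array's shape.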

The modification fixed the "need at least one array to stack" error, but training still crashed with the message: "RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one." So I set find_unused_parameters=True and NCCL_LL_THRESHOLD=0, as suggested in some discussions. Training then looked fine, but after several iterations it got stuck again with some GPUs at 100% utilization, i.e. a deadlock. I think this is caused by the all_reduce operations in the parse_losses function in train.py:

for loss_name, loss_value in log_vars.items():
    # reduce loss when distributed training
    if dist.is_available() and dist.is_initialized():
        # loss_value: tensor(0.6956, device='cuda:0', grad_fn= < AddBackward0 >)
        # loss_value.data: tensor(0.6956, device='cuda:0')
        # loss_value.data.item: 0.6956
        loss_value = loss_value.data.clone()
        dist.all_reduce(loss_value.div_(dist.get_world_size()))
    log_vars[loss_name] = loss_value.item()

When a batch on some GPU has no GT, 'loss_bbox' and 'loss_mask' are missing from log_vars on that GPU, while the GPUs whose batches do have GTs still call all_reduce for 'loss_bbox' and 'loss_mask'. Collectives are matched by call order rather than by name, so the ranks issue different numbers of all_reduce calls and the GPUs with GTs wait forever. I think setting find_unused_parameters=True and NCCL_LL_THRESHOLD=0, as suggested in some discussions, just hides the problem rather than solving it. I fixed it with the following modification:

    if len(log_vars) == 7:
        for loss_name, loss_value in log_vars.items():
            # reduce loss when distributed training
            if dist.is_available() and dist.is_initialized():
                # loss_value: tensor(0.6956, device='cuda:0', grad_fn= < AddBackward0 >)
                # loss_value.data: tensor(0.6956, device='cuda:0')
                # loss_value.data.item: 0.6956
                loss_value = loss_value.data.clone()
                dist.all_reduce(loss_value.div_(dist.get_world_size()))
            log_vars[loss_name] = loss_value.item()
    elif len(log_vars) == 5:
        # By default there are 7 entries
        loss_names = list(log_vars.keys())
        loss_values = list(log_vars.values())
        # Rebuild all 7 entries: the 5 shared ones plus zero-filled 'loss_bbox' and 'loss_mask'
        for i in range(7):
            if i <= 3:
                loss_name = loss_names[i]
                loss_value = loss_values[i]
            elif i == 4:
                loss_name = 'loss_bbox'
                loss_value = torch.tensor(0.)
                if dist.is_available() and dist.is_initialized():
                    loss_value = loss_value.cuda()
            elif i == 5:
                loss_name = 'loss_mask'
                loss_value = torch.tensor(0.)
                if dist.is_available() and dist.is_initialized():
                    loss_value = loss_value.cuda()
            else:
                loss_name = loss_names[-1]
                loss_value = loss_values[-1]
            # print(f'lossname:{loss_name}, lossvalue:{loss_value}')
            if dist.is_available() and dist.is_initialized():
                loss_value = loss_value.data.clone()
                dist.all_reduce(loss_value.div_(dist.get_world_size()))
            log_vars[loss_name] = loss_value.item()
    else:
        raise Exception('Not possible...')

The above modification is just a temporary workaround, because the code in train.py is shared by all kinds of detectors. I think the proper fix is to modify the detector classes so that they return zero losses when a batch has no GTs; performing all_reduce on different process groups should also work (I think...).
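A less brittle variant of the len(log_vars) == 7 / == 5 branching is to have every rank reduce the same fixed list of loss names, filling in a zero for any name its batch did not produce. A minimal sketch, assuming the expected names are known up front for a given detector; align_log_vars and EXPECTED_KEYS are hypothetical names, not mmdetection API. The alignment is pure Python, and the torch.distributed part appears only as a comment:

```python
from collections import OrderedDict
from typing import Any, Iterable, Mapping


def align_log_vars(log_vars: Mapping[str, Any],
                   expected_keys: Iterable[str],
                   fill_value: Any = 0.0) -> "OrderedDict[str, Any]":
    """Return log_vars re-keyed to exactly `expected_keys`, in that order.

    Entries missing on this rank (e.g. 'loss_bbox'/'loss_mask' for a
    background-only batch) get `fill_value`, so every rank ends up issuing
    the same all_reduce calls in the same order.
    """
    expected = list(expected_keys)
    unexpected = set(log_vars) - set(expected)
    if unexpected:
        # An extra key on one rank would reintroduce the call-count mismatch.
        raise KeyError(f"unexpected loss keys: {sorted(unexpected)}")
    return OrderedDict((key, log_vars.get(key, fill_value)) for key in expected)


# Hypothetical usage inside a parse_losses-style function (not executed here):
#   zero = torch.zeros((), device='cuda')
#   for name, value in align_log_vars(log_vars, EXPECTED_KEYS, zero).items():
#       value = value.detach().clone()
#       dist.all_reduce(value.div_(dist.get_world_size()))
#       log_vars[name] = value.item()
```

One side effect: a zero-filled entry still counts in the division by world size, so the logged loss_bbox becomes an average that includes ranks without GTs. Since log_vars is only used for logging, this alignment does not touch the backward pass; the DDP "Expected to have finished reduction" error is likely a separate problem.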

milleniums commented 4 years ago

So, will it mess up backpropagation?

panjianning commented 4 years ago

I got a worse score with this modification.

yhcao6 commented 4 years ago

This should be fixed by https://github.com/open-mmlab/mmdetection/pull/2280. Feel free to reopen it.