open-mmlab / mmdetection

OpenMMLab Detection Toolbox and Benchmark
https://mmdetection.readthedocs.io
Apache License 2.0
29.72k stars 9.48k forks

Training Error when using Negative / background only Images #10256

Open aymanaboghonim opened 1 year ago

aymanaboghonim commented 1 year ago

Checklist

  1. I have searched related issues but cannot get the expected help.
  2. I have read the FAQ documentation but cannot get the expected help.
  3. The bug has not been fixed in the latest version.

Describe the bug I could not start training with negative images (images that contain no objects, only background). When I start training, it throws the error message below (screenshot omitted). When I filter out the empty images by setting `filter_empty_gt=True`, training starts normally, which strongly indicates the problem is in the handling of negative images, not an installation/config issue.
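
For context, the workaround above uses the `filter_empty_gt` option on mmdet 2.x datasets; a minimal config sketch (the dataset type and paths below are placeholders for illustration, not my actual setup):

```python
# Sketch of the train-dataset section of an mmdet 2.x config.
# 'CocoDataset' and the paths are placeholders.
data = dict(
    train=dict(
        type='CocoDataset',
        ann_file='annotations/train.json',
        img_prefix='images/',
        # True  -> drop background-only images (training then starts normally)
        # False -> keep them (triggers the error reported here)
        filter_empty_gt=True,
    ))
```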

Reproduction

  1. What command or script did you run?

```python
# Build the dataset
datasets = [build_dataset(cfg.data.train)]

# Build the detector
model = build_detector(cfg.model)

# Add an attribute for visualization convenience
model.CLASSES = datasets[0].CLASSES

# Create work_dir and start training
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
train_detector(model, datasets, cfg, distributed=False, validate=True)
```

2. Did you make any modifications to the code or config? Did you understand what you modified?
3. What dataset did you use?

**Environment**

1. Please run `python mmdet/utils/collect_env.py` to collect necessary environment information and paste it here.
{'sys.platform': 'linux',
 'Python': '3.7.10 | packaged by conda-forge | (default, Oct 13 2021, 20:51:14) [GCC 9.4.0]',
 'CUDA available': True,
 'GPU 0': 'Tesla T4',
 'CUDA_HOME': '/usr/local/cuda',
 'NVCC': 'Cuda compilation tools, release 11.0, V11.0.221',
 'GCC': 'gcc (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0',
 'PyTorch': '1.9.0',
 'PyTorch compiling details': 'PyTorch built with:\n  - GCC 7.3\n  - C++ Version: 201402\n  - Intel(R) oneAPI Math Kernel Library Version 2021.4-Product Build 20210904 for Intel(R) 64 architecture applications\n  - Intel(R) MKL-DNN v2.1.2 (Git Hash 98be7e8afa711dc9b66c8ff3504129cb82013cdb)\n  - OpenMP 201511 (a.k.a. OpenMP 4.5)\n  - NNPACK is enabled\n  - CPU capability usage: AVX2\n  - CUDA Runtime 11.1\n  - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_37,code=compute_37\n  - CuDNN 8.0.5\n  - Magma 2.5.2\n  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.1, CUDNN_VERSION=8.0.5, CXX_COMPILER=/opt/rh/devtoolset-7/root/usr/bin/c++, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.9.0, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, 
USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, \n',
 'TorchVision': '0.10.0+cu111',
 'OpenCV': '4.7.0',
 'MMCV': '1.7.0',
 'MMCV Compiler': 'GCC 7.3',
 'MMCV CUDA Compiler': '11.1',
 'MMDetection': '2.28.0+'}
2. You may add additional information that may be helpful for locating the problem, such as
   - How you installed PyTorch \[e.g., pip, conda, source\]
   - Other environment variables that may be related (such as `$PATH`, `$LD_LIBRARY_PATH`, `$PYTHONPATH`, etc.)

**Error traceback**
If applicable, paste the error traceback here.

```
NotImplementedError                       Traceback (most recent call last)
/tmp/ipykernel_1/667614657.py in <module>
     19 # Create work_dir
     20 mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
---> 21 train_detector(model, datasets, cfg, distributed=False, validate=True)

/opt/conda/lib/python3.7/site-packages/mmdet/apis/train.py in train_detector(model, dataset, cfg, distributed, validate, timestamp, meta)
    244     elif cfg.load_from:
    245         runner.load_checkpoint(cfg.load_from)
--> 246     runner.run(data_loaders, cfg.workflow)

/opt/conda/lib/python3.7/site-packages/mmcv/runner/epoch_based_runner.py in run(self, data_loaders, workflow, max_epochs, **kwargs)
    134                     if mode == 'train' and self.epoch >= self._max_epochs:
    135                         break
--> 136                     epoch_runner(data_loaders[i], **kwargs)
    137 
    138         time.sleep(1)  # wait for some hooks like loggers to finish

/opt/conda/lib/python3.7/site-packages/mmcv/runner/epoch_based_runner.py in train(self, data_loader, **kwargs)
     51             self._inner_iter = i
     52             self.call_hook('before_train_iter')
---> 53             self.run_iter(data_batch, train_mode=True, **kwargs)
     54             self.call_hook('after_train_iter')
     55             del self.data_batch

/opt/conda/lib/python3.7/site-packages/mmcv/runner/epoch_based_runner.py in run_iter(self, data_batch, train_mode, **kwargs)
     30         elif train_mode:
     31             outputs = self.model.train_step(data_batch, self.optimizer,
---> 32                                             **kwargs)
     33         else:
     34             outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)

/opt/conda/lib/python3.7/site-packages/mmcv/parallel/data_parallel.py in train_step(self, *inputs, **kwargs)
     75 
     76         inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
---> 77         return self.module.train_step(*inputs[0], **kwargs[0])
     78 
     79     def val_step(self, *inputs, **kwargs):

/opt/conda/lib/python3.7/site-packages/mmdet/models/detectors/base.py in train_step(self, data, optimizer)
    246                   averaging the logs.
    247         """
--> 248         losses = self(**data)
    249         loss, log_vars = self._parse_losses(losses)
    250 

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

/opt/conda/lib/python3.7/site-packages/mmcv/runner/fp16_utils.py in new_func(*args, **kwargs)
    117                                 f'method of those classes {supported_types}')
    118             if not (hasattr(args[0], 'fp16_enabled') and args[0].fp16_enabled):
--> 119                 return old_func(*args, **kwargs)
    120 
    121             # get the arg spec of the decorated method

/opt/conda/lib/python3.7/site-packages/mmdet/models/detectors/base.py in forward(self, img, img_metas, return_loss, **kwargs)
    170 
    171         if return_loss:
--> 172             return self.forward_train(img, img_metas, **kwargs)
    173         else:
    174             return self.forward_test(img, img_metas, **kwargs)

/opt/conda/lib/python3.7/site-packages/mmdet/models/detectors/two_stage.py in forward_train(self, img, img_metas, gt_bboxes, gt_labels, gt_bboxes_ignore, gt_masks, proposals, **kwargs)
    148                                                  gt_bboxes, gt_labels,
    149                                                  gt_bboxes_ignore, gt_masks,
--> 150                                                  **kwargs)
    151         losses.update(roi_losses)
    152 

/opt/conda/lib/python3.7/site-packages/mmdet/models/roi_heads/standard_roi_head.py in forward_train(self, x, img_metas, proposal_list, gt_bboxes, gt_labels, gt_bboxes_ignore, gt_masks, **kwargs)
    111             mask_results = self._mask_forward_train(x, sampling_results,
    112                                                     bbox_results['bbox_feats'],
--> 113                                                     gt_masks, img_metas)
    114             losses.update(mask_results['loss_mask'])
    115 

/opt/conda/lib/python3.7/site-packages/mmdet/models/roi_heads/point_rend_roi_head.py in _mask_forward_train(self, x, sampling_results, bbox_feats, gt_masks, img_metas)
     34         mask_results = super()._mask_forward_train(x, sampling_results,
     35                                                    bbox_feats, gt_masks,
---> 36                                                    img_metas)
     37         if mask_results['loss_mask'] is not None:
     38             loss_point = self._mask_point_forward_train(

/opt/conda/lib/python3.7/site-packages/mmdet/models/roi_heads/standard_roi_head.py in _mask_forward_train(self, x, sampling_results, bbox_feats, gt_masks, img_metas)
    150         if not self.share_roi_extractor:
    151             pos_rois = bbox2roi([res.pos_bboxes for res in sampling_results])
--> 152             mask_results = self._mask_forward(x, pos_rois)
    153         else:
    154             pos_inds = []

/opt/conda/lib/python3.7/site-packages/mmdet/models/roi_heads/standard_roi_head.py in _mask_forward(self, x, rois, pos_inds, bbox_feats)
    185         if rois is not None:
    186             mask_feats = self.mask_roi_extractor(
--> 187                 x[:self.mask_roi_extractor.num_inputs], rois)
    188             if self.with_shared_head:
    189                 mask_feats = self.shared_head(mask_feats)

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

/opt/conda/lib/python3.7/site-packages/mmcv/runner/fp16_utils.py in new_func(*args, **kwargs)
    206                                 'method of nn.Module')
    207             if not (hasattr(args[0], 'fp16_enabled') and args[0].fp16_enabled):
--> 208                 return old_func(*args, **kwargs)
    209             # get the arg spec of the decorated method
    210             args_info = getfullargspec(old_func)

/opt/conda/lib/python3.7/site-packages/mmdet/models/roi_heads/roi_extractors/generic_roi_extractor.py in forward(self, feats, rois, roi_scale_factor)
     45         """Forward function."""
     46         if len(feats) == 1:
---> 47             return self.roi_layers[0](feats[0], rois)
     48 
     49         out_size = self.roi_layers[0].output_size

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

/opt/conda/lib/python3.7/site-packages/mmcv/ops/point_sample.py in forward(self, features, rois)
    347                     point_feats.append(point_feat)
    348 
--> 349             point_feats = torch.cat(point_feats, dim=0)
    350 
    351         channels = features.size(1)

NotImplementedError: There were no tensor arguments to this function (e.g., you passed an empty list of Tensors), but no fallback function is registered for schema aten::_cat.  This usually means that this function requires a non-empty list of Tensors, or that you (the operator writer) forgot to register a fallback function.  Available functions are [CPU, CUDA, QuantizedCPU, BackendSelect, Named, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, UNKNOWN_TENSOR_TYPE_ID, AutogradMLC, AutogradHPU, AutogradNestedTensor, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, Tracer, Autocast, Batched, VmapMode].

CPU: registered at /opt/conda/conda-bld/pytorch_1623448265233/work/build/aten/src/ATen/RegisterCPU.cpp:16286 [kernel]
CUDA: registered at /opt/conda/conda-bld/pytorch_1623448265233/work/build/aten/src/ATen/RegisterCUDA.cpp:20674 [kernel]
QuantizedCPU: registered at /opt/conda/conda-bld/pytorch_1623448265233/work/build/aten/src/ATen/RegisterQuantizedCPU.cpp:1025 [kernel]
BackendSelect: fallthrough registered at /opt/conda/conda-bld/pytorch_1623448265233/work/aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
Named: registered at /opt/conda/conda-bld/pytorch_1623448265233/work/aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback]
ADInplaceOrView: fallthrough registered at /opt/conda/conda-bld/pytorch_1623448265233/work/aten/src/ATen/core/VariableFallbackKernel.cpp:60 [backend fallback]
AutogradOther: registered at /opt/conda/conda-bld/pytorch_1623448265233/work/torch/csrc/autograd/generated/VariableType_2.cpp:9928 [autograd kernel]
AutogradCPU: registered at /opt/conda/conda-bld/pytorch_1623448265233/work/torch/csrc/autograd/generated/VariableType_2.cpp:9928 [autograd kernel]
AutogradCUDA: registered at /opt/conda/conda-bld/pytorch_1623448265233/work/torch/csrc/autograd/generated/VariableType_2.cpp:9928 [autograd kernel]
AutogradXLA: registered at /opt/conda/conda-bld/pytorch_1623448265233/work/torch/csrc/autograd/generated/VariableType_2.cpp:9928 [autograd kernel]
UNKNOWN_TENSOR_TYPE_ID: registered at /opt/conda/conda-bld/pytorch_1623448265233/work/torch/csrc/autograd/generated/VariableType_2.cpp:9928 [autograd kernel]
AutogradMLC: registered at /opt/conda/conda-bld/pytorch_1623448265233/work/torch/csrc/autograd/generated/VariableType_2.cpp:9928 [autograd kernel]
AutogradHPU: registered at /opt/conda/conda-bld/pytorch_1623448265233/work/torch/csrc/autograd/generated/VariableType_2.cpp:9928 [autograd kernel]
AutogradNestedTensor: registered at /opt/conda/conda-bld/pytorch_1623448265233/work/torch/csrc/autograd/generated/VariableType_2.cpp:9928 [autograd kernel]
AutogradPrivateUse1: registered at /opt/conda/conda-bld/pytorch_1623448265233/work/torch/csrc/autograd/generated/VariableType_2.cpp:9928 [autograd kernel]
AutogradPrivateUse2: registered at /opt/conda/conda-bld/pytorch_1623448265233/work/torch/csrc/autograd/generated/VariableType_2.cpp:9928 [autograd kernel]
AutogradPrivateUse3: registered at /opt/conda/conda-bld/pytorch_1623448265233/work/torch/csrc/autograd/generated/VariableType_2.cpp:9928 [autograd kernel]
Tracer: registered at /opt/conda/conda-bld/pytorch_1623448265233/work/torch/csrc/autograd/generated/TraceType_2.cpp:9621 [kernel]
Autocast: registered at /opt/conda/conda-bld/pytorch_1623448265233/work/aten/src/ATen/autocast_mode.cpp:259 [kernel]
Batched: registered at /opt/conda/conda-bld/pytorch_1623448265233/work/aten/src/ATen/BatchingRegistrations.cpp:1019 [backend fallback]
VmapMode: fallthrough registered at /opt/conda/conda-bld/pytorch_1623448265233/work/aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]
```

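The root cause is visible in the last frame: on a background-only image there are no positive RoIs, so `point_feats` ends up as an empty list and `torch.cat` refuses it. A minimal standalone reproduction of that failure mode:

```python
import torch

# torch.cat requires a non-empty list of tensors; with no positive RoIs
# the per-level feature list is empty and the call raises.
point_feats = []  # what PointRend ends up with on a background-only image
try:
    torch.cat(point_feats, dim=0)
except Exception as e:  # NotImplementedError here; RuntimeError on newer PyTorch
    print(type(e).__name__)
```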
**Bug fix** If you have already identified the reason, you can provide the information here. If you are willing to create a PR to fix it, please also leave a comment here and that would be much appreciated!

aymanaboghonim commented 1 year ago

Can anyone help? @Keiku @anurag1paul @hachreak @hhaAndroid

hhaAndroid commented 1 year ago

@aymanaboghonim Yes, it is possible that some algorithms do not handle this very robustly. Can you help find out what exactly needs to be modified?

aymanaboghonim commented 1 year ago

@hhaAndroid Sorry, I could not catch the bug, but I tried the same algorithm with the same settings a few months ago and it worked properly with negative images.

aymanaboghonim commented 1 year ago

Is there any fix for this bug?

aymanaboghonim commented 1 year ago

Hi @hhaAndroid, I faced the same issue with PointRend in detectron2, but I could fix it by following this solution: https://github.com/facebookresearch/detectron2/issues/4383#issuecomment-1274648871. I tried to modify `/mmcv/ops/point_sample.py`, which throws the error, but I failed. Can you help me fix this error based on the given solution? Here is my attempt (screenshot omitted):

```python
# edit to include negative examples
sR, sP, s2 = rel_roi_points.shape
assert s2 == 2, rel_roi_points.shape
if point_feats:
    point_feats = torch.cat(point_feats, dim=0)
else:
    point_feats = torch.zeros((sR, sP),
                              dtype=rel_roi_points.dtype,
                              layout=rel_roi_points.layout,
                              device=rel_roi_points.device)
```
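For reference, the general guard pattern from the linked detectron2 fix can be sketched independently of `point_sample.py` (the function name and shapes here are illustrative, not the actual mmcv code):

```python
import torch

def cat_or_empty(feats, empty_shape, like):
    """Concatenate per-RoI feature tensors; return zeros when the batch
    contains no positive RoIs (e.g. background-only images)."""
    if feats:  # non-empty list: the normal torch.cat path
        return torch.cat(feats, dim=0)
    # empty list: produce an empty/zero tensor with matching dtype/device
    return torch.zeros(empty_shape, dtype=like.dtype, device=like.device)
```

The key point is that the empty-list branch must build a tensor whose trailing dimensions match what downstream code expects, so only the first (RoI) dimension is zero.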