open-mmlab / mmdeploy

OpenMMLab Model Deployment Framework
https://mmdeploy.readthedocs.io/en/latest/
Apache License 2.0

[Bug] Error Exporting RetinaNet for Single Class Case with CrossEntropyLoss in MMDeploy #2787

Status: Open. rchuzh99 opened this issue 2 weeks ago.

rchuzh99 commented 2 weeks ago


Describe the bug

I am facing an issue when exporting RetinaNet from mmdetection (https://github.com/open-mmlab/mmdetection/blob/main/configs/_base_/models/retinanet_r50_fpn.py) for a single-class case. An IndexError is raised because the sliced nms_pre_score has a zero-size class dimension.

I changed the classification loss to cross-entropy (type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), so the effective bbox_head config is as follows:

bbox_head=dict(
    type="RetinaHead",
    num_classes=80,
    in_channels=256,
    stacked_convs=4,
    feat_channels=256,
    anchor_generator=dict(
        type="AnchorGenerator",
        octave_base_scale=4,
        scales_per_octave=3,
        ratios=[0.5, 1.0, 2.0],
        strides=[8, 16, 32, 64, 128],
    ),
    bbox_coder=dict(
        type="DeltaXYWHBBoxCoder",
        target_means=[0.0, 0.0, 0.0, 0.0],
        target_stds=[1.0, 1.0, 1.0, 1.0],
    ),
    # use_sigmoid=False: softmax over num_classes + 1 channels, background last
    loss_cls=dict(type="CrossEntropyLoss", use_sigmoid=False, loss_weight=1.0),
    loss_bbox=dict(type="L1Loss", loss_weight=1.0),
)

To export the model to ONNX, I called the export function from https://github.com/open-mmlab/mmdeploy/blob/bc75c9d6c8940aa03d0e1e5b5962bd930478ba77/mmdeploy/apis/onnx/export.py. As I understand it, before torch.onnx.export is invoked, the model is patched with rewritten functions. In this case, predict_by_feat() is replaced with base_dense_head__predict_by_feat() in https://github.com/open-mmlab/mmdeploy/blob/bc75c9d6c8940aa03d0e1e5b5962bd930478ba77/mmdeploy/codebase/mmdet/models/dense_heads/base_dense_head.py#L26-L27.
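The patching step described above can be illustrated with a framework-free sketch. It assumes nothing about MMDeploy internals: `Head`, `deploy_predict_by_feat` and `rewrite_method` are hypothetical names, and MMDeploy's actual function-rewriter mechanism is more involved. The idea shown is only that the original method is swapped for an export-specific version while export runs, then restored.

```python
from contextlib import contextmanager


class Head:
    """Hypothetical stand-in for a dense head such as RetinaHead."""

    def predict_by_feat(self, scores):
        return "original"


def deploy_predict_by_feat(self, scores):
    # Hypothetical export-time replacement, playing the role of
    # base_dense_head__predict_by_feat in this report.
    return "rewritten"


@contextmanager
def rewrite_method(cls, name, new_func):
    """Temporarily replace cls.<name> with new_func and restore it on exit."""
    original = getattr(cls, name)
    setattr(cls, name, new_func)
    try:
        yield
    finally:
        setattr(cls, name, original)


head = Head()
with rewrite_method(Head, "predict_by_feat", deploy_predict_by_feat):
    inside = head.predict_by_feat(None)   # tracing/export would happen here
outside = head.predict_by_feat(None)
print(inside, outside)  # rewritten original
```

Because the swap happens on the class rather than on the instance, any caller inside the traced forward pass (such as predict() in the traceback below) reaches the replacement.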

After reviewing the code in https://github.com/open-mmlab/mmdeploy/blob/bc75c9d6c8940aa03d0e1e5b5962bd930478ba77/mmdeploy/codebase/mmdet/models/dense_heads/base_dense_head.py#L26-L27, I noticed three places that depend on the use_sigmoid flag configured in CrossEntropyLoss:

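  1. In the constructor of RetinaHead (defined in its parent class AnchorHead), where use_sigmoid decides whether cls_out_channels is num_classes or num_classes + 1: https://github.com/open-mmlab/mmdetection/blob/cfd5d3a985b0249de009b67d04f37263e11cdf3d/mmdet/models/dense_heads/anchor_head.py#L73-L78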

  2. In base_dense_head, the scores are sliced for the first time. I presume this excludes the background column (index num_classes): https://github.com/open-mmlab/mmdeploy/blob/bc75c9d6c8940aa03d0e1e5b5962bd930478ba77/mmdeploy/codebase/mmdet/models/dense_heads/base_dense_head.py#L113-L117

  3. This is the confusing part. The scores are sliced a second time when computing max_scores: https://github.com/open-mmlab/mmdeploy/blob/bc75c9d6c8940aa03d0e1e5b5962bd930478ba77/mmdeploy/codebase/mmdet/models/dense_heads/base_dense_head.py#L141-L146
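The column counts implied by the three places above can be traced with plain arithmetic. This sketch assumes the constructor's channel rule (num_classes channels with sigmoid, num_classes + 1 with softmax, background last) and assumes the first slice removes only the background column, as presumed in item 2. `columns_after_slicing` is a hypothetical helper, not MMDeploy code; it models each slice only by its effect on the last dimension.

```python
def columns_after_slicing(num_classes: int, use_sigmoid: bool) -> dict:
    """Width of the class dimension at each step of the rewritten predict_by_feat."""
    # Step 1: head output channels, as chosen in the AnchorHead constructor.
    channels = num_classes if use_sigmoid else num_classes + 1
    # Step 2: softmax branch only, first [..., :-1] slice drops the background column.
    after_first = channels if use_sigmoid else channels - 1
    # Step 3: softmax branch only, second [..., :-1] slice when computing max_scores.
    after_second = after_first if use_sigmoid else max(after_first - 1, 0)
    return {"channels": channels, "after_first": after_first, "after_second": after_second}


for n in (1, 80):
    print(n, columns_after_slicing(n, use_sigmoid=False))
# 1 {'channels': 2, 'after_first': 1, 'after_second': 0}
# 80 {'channels': 81, 'after_first': 80, 'after_second': 79}
```

A width of 0 matches the reported `max(): Expected reduction dim 2 to have non-zero size`, while with 80 classes the export silently ranks anchors for the pre-NMS top-k on only 79 of them.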

Could you explain the reasoning behind this? It appears that the last object class is excluded when computing max_scores, and with a single class no columns remain at all. Thank you!
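The effect on the pre-NMS ranking score can be checked on a single anchor with a framework-free emulation of the softmax branch. The logits are made-up values, the first slice is assumed to drop only the background column (as presumed in item 2), and `max_score_double_slice` / `max_score_single_slice` are hypothetical names: the first mirrors the two `[..., :-1]` slices, the second is one possible alternative that drops the background only once, not an official fix.

```python
import math


def softmax(logits):
    """Numerically stable softmax over one anchor's class logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def max_score_double_slice(logits):
    scores = softmax(logits)[:-1]   # first slice: drop the background (last channel)
    return max(scores[:-1])         # second slice: also drops the last foreground class


def max_score_single_slice(logits):
    scores = softmax(logits)[:-1]   # drop the background exactly once
    return max(scores)


# Three foreground classes + background; the anchor clearly belongs to the last class.
logits = [0.0, 0.0, 5.0, 0.0]
print(round(max_score_double_slice(logits), 3))  # 0.007
print(round(max_score_single_slice(logits), 3))  # 0.98

# Single-class case: one foreground class + background.
try:
    max_score_double_slice([2.0, 0.0])
    single_class_failed = False
except ValueError:
    # Empty class dimension: plain-Python analogue of the IndexError from torch max().
    single_class_failed = True
print("single-class double slice failed:", single_class_failed)  # True
print(round(max_score_single_slice([2.0, 0.0]), 3))  # 0.881
```

In torch the empty reduction surfaces as the IndexError in the traceback; here it is a ValueError from the built-in max(). Whether dropping the second slice matches the maintainers' intent for the softmax branch is the question this issue raises.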

Reproduction


  from mmdeploy.apis.onnx import export as onnx_export

  # model, img, main_file, input_metas, context_info, input_names,
  # output_names and dynamic_axes are prepared beforehand (not shown here).
  onnx_export(
      model=model,
      args=img,
      output_path_prefix=str(main_file),
      backend="onnxruntime",
      input_metas=input_metas,
      context_info=context_info,
      input_names=input_names,
      output_names=output_names,
      opset_version=11,
      dynamic_axes=dynamic_axes,
      verbose=False,
      keep_initializers_as_inputs=False,
      optimize=True,
  )

Environment

07/02 13:01:39 - mmengine - INFO - 

07/02 13:01:39 - mmengine - INFO - **********Environmental information**********
07/02 13:01:44 - mmengine - INFO - sys.platform: linux
07/02 13:01:44 - mmengine - INFO - Python: 3.10.14 (main, Mar 21 2024, 16:24:04) [GCC 11.2.0]
07/02 13:01:44 - mmengine - INFO - CUDA available: True
07/02 13:01:44 - mmengine - INFO - MUSA available: False
07/02 13:01:44 - mmengine - INFO - numpy_random_seed: 2147483648
07/02 13:01:44 - mmengine - INFO - GPU 0: NVIDIA
07/02 13:01:44 - mmengine - INFO - CUDA_HOME: /usr/local/cuda
07/02 13:01:44 - mmengine - INFO - NVCC: Cuda compilation tools, release 12.3, V12.3.107
07/02 13:01:44 - mmengine - INFO - GCC: gcc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
07/02 13:01:44 - mmengine - INFO - PyTorch: 2.0.1
07/02 13:01:44 - mmengine - INFO - PyTorch compiling details: PyTorch built with:
  - GCC 9.3
  - C++ Version: 201703
  - Intel(R) oneAPI Math Kernel Library Version 2023.1-Product Build 20230303 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v2.7.3 (Git Hash 6dbeffbae1f23cbbeae17adb7b5b13f1f37c080e)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: AVX2
  - CUDA Runtime 11.8
  - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_90,code=sm_90;-gencode;arch=compute_37,code=compute_37
  - CuDNN 8.7
  - Magma 2.6.1
  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.8, CUDNN_VERSION=8.7.0, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=0 -fabi-version=11 -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wunused-local-typedefs -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_DISABLE_GPU_ASSERTS=ON, TORCH_VERSION=2.0.1, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, 

07/02 13:01:44 - mmengine - INFO - TorchVision: 0.15.2
07/02 13:01:44 - mmengine - INFO - OpenCV: 4.9.0
07/02 13:01:44 - mmengine - INFO - MMEngine: 0.10.3
07/02 13:01:44 - mmengine - INFO - MMCV: 2.0.1
07/02 13:01:44 - mmengine - INFO - MMCV Compiler: GCC 9.3
07/02 13:01:44 - mmengine - INFO - MMCV CUDA Compiler: 11.8
07/02 13:01:44 - mmengine - INFO - MMDeploy: 1.3.1+bc75c9d
07/02 13:01:44 - mmengine - INFO - 

07/02 13:01:44 - mmengine - INFO - **********Backend information**********
07/02 13:01:44 - mmengine - INFO - tensorrt:    None
07/02 13:01:44 - mmengine - INFO - ONNXRuntime: None
07/02 13:01:44 - mmengine - INFO - pplnn:       None
07/02 13:01:44 - mmengine - INFO - ncnn:        None
07/02 13:01:45 - mmengine - INFO - snpe:        None
07/02 13:01:45 - mmengine - INFO - openvino:    None
07/02 13:01:45 - mmengine - INFO - torchscript: 2.0.1
07/02 13:01:45 - mmengine - INFO - torchscript custom ops:      NotAvailable
07/02 13:01:45 - mmengine - INFO - rknn-toolkit:        None
07/02 13:01:45 - mmengine - INFO - rknn-toolkit2:       None
07/02 13:01:45 - mmengine - INFO - ascend:      None
07/02 13:01:45 - mmengine - INFO - coreml:      None
07/02 13:01:45 - mmengine - INFO - tvm: None
07/02 13:01:45 - mmengine - INFO - vacc:        None
07/02 13:01:45 - mmengine - INFO - 

07/02 13:01:45 - mmengine - INFO - **********Codebase information**********
07/02 13:01:45 - mmengine - INFO - mmdet:       3.3.0
07/02 13:01:45 - mmengine - INFO - mmseg:       None
07/02 13:01:45 - mmengine - INFO - mmpretrain:  1.2.0
07/02 13:01:45 - mmengine - INFO - mmocr:       None
07/02 13:01:45 - mmengine - INFO - mmagic:      None
07/02 13:01:45 - mmengine - INFO - mmdet3d:     None
07/02 13:01:45 - mmengine - INFO - mmpose:      None
07/02 13:01:45 - mmengine - INFO - mmrotate:    None
07/02 13:01:45 - mmengine - INFO - mmaction:    None
07/02 13:01:45 - mmengine - INFO - mmrazor:     None
07/02 13:01:45 - mmengine - INFO - mmyolo:      None

Error traceback

xxx/lib/python3.10/site-packages/mmdeploy/codebase/mmdet/models/detectors/single_stage.py:85 in single_stage_detector__forward

  82 │   # set the metainfo
  83 │   data_samples = _set_metainfo(data_samples, img_shape)
  84 │
❱ 85 │   return __forward_impl(self, batch_inputs, data_samples=data_samples)
  86

xxx/lib/python3.10/site-packages/mmdeploy/core/optimizers/function_marker.py:266 in g

  263 │   │   │   args = mark_tensors(args, func, func_id, 'input', ctx, attrs,
  264 │   │   │   │   │   │   │   │   is_inspect, args_level)
  265 │   │   │
❱ 266 │   │   │   rets = f(*args, **kwargs)
  267 │   │   │
  268 │   │   │   ctx = Context(output_names)
  269 │   │   │   func_ret = mark_tensors(rets, func, func_id, 'output', ctx, attrs,

xxx/lib/python3.10/site-packages/mmdeploy/codebase/mmdet/models/detectors/single_stage.py:23 in __forward_impl

  20 │   """
  21 │   x = self.extract_feat(batch_inputs)
  22 │
❱ 23 │   output = self.bbox_head.predict(x, data_samples, rescale=False)
  24 │   return output
  25
  26

xxx/lib/python3.10/site-packages/mmdet/models/dense_heads/base_dense_head.py:197 in predict

  194 │   │
  195 │   │   outs = self(x)
  196 │   │
❱ 197 │   │   predictions = self.predict_by_feat(
  198 │   │   │   *outs, batch_img_metas=batch_img_metas, rescale=rescale)
  199 │   │   return predictions
  200

xxx/lib/python3.10/site-packages/mmdeploy/codebase/mmdet/models/dense_heads/base_dense_head.py:145 in base_dense_head__predict_by_feat

  142 │   │   │   if self.use_sigmoid_cls:
  143 │   │   │   │   max_scores, _ = nms_pre_score.max(-1)
  144 │   │   │   else:
❱ 145 │   │   │   │   max_scores, _ = nms_pre_score[..., :-1].max(-1)
  146 │   │   │   _, topk_inds = max_scores.topk(pre_topk)
  147 │   │   │   bbox_pred, scores, score_factors = gather_topk(
  148 │   │   │   │   bbox_pred,
IndexError: max(): Expected reduction dim 2 to have non-zero size.

/cc @RunningLeon @grimoire

rchuzh99 commented 2 weeks ago

#2534 and https://github.com/open-mmlab/mmdeploy/commit/a51ee2c76caa1c8080e51981dccf829a23907791#diff-2871f924ffd987597f7ca6d4f5227f6d04fc49d3ca246a9bacd4b797e394206fR109