open-mmlab / mmdetection3d

OpenMMLab's next-generation platform for general 3D object detection.
https://mmdetection3d.readthedocs.io/en/latest/
Apache License 2.0

[Bug] get_flops.py throws error for Pointpillars #3028

Status: Open. cadip92 opened this issue 3 months ago.

cadip92 commented 3 months ago

Prerequisite

Task

I'm using the official example scripts/configs for the officially supported tasks/models/datasets.

Branch

main branch https://github.com/open-mmlab/mmdetection3d

Environment

sys.platform: linux
Python: 3.8.19 | packaged by conda-forge | (default, Mar 20 2024, 12:47:35) [GCC 12.3.0]
CUDA available: False
MUSA available: False
numpy_random_seed: 2147483648
GCC: gcc (GCC) 11.4.1 20231218 (Red Hat 11.4.1-3)
PyTorch: 2.0.1+cu117
PyTorch compiling details: PyTorch built with:

TorchVision: 0.15.2+cu117
OpenCV: 4.10.0
MMEngine: 0.10.4
MMDetection: 3.3.0
MMDetection3D: 1.4.0+962f093
spconv2.0: False

Reproduces the problem - code sample

python tools/analysis_tools/get_flops.py configs/pointpillars/pointpillars_hv_secfpn_sbn-all_8xb4-2x_nus-3d.py --modality point

Reproduces the problem - command or script

python tools/analysis_tools/get_flops.py configs/pointpillars/pointpillars_hv_secfpn_sbn-all_8xb4-2x_nus-3d.py --modality point

Reproduces the problem - error message

warnings.warn(
Traceback (most recent call last):
  File "tools/analysis_tools/get_flops.py", line 83, in <module>
    main()
  File "tools/analysis_tools/get_flops.py", line 73, in main
    flops, params = get_model_complexity_info(model, input_shape)
  File "/beegfs/chandorkar/miniforge3/envs/mmdet3d/lib/python3.8/site-packages/mmcv/cnn/utils/flops_counter.py", line 107, in get_model_complexity_info
    _ = flops_model(batch)
  File "/beegfs/chandorkar/miniforge3/envs/mmdet3d/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/beegfs/chandorkar/projects/mmdetection3d/mmdet3d/models/detectors/base.py", line 88, in forward
    return self._forward(inputs, data_samples, **kwargs)
TypeError: _forward() takes 1 positional argument but 3 were given
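One possible reading of the last two frames, assuming `self._forward` resolves to a bound method: `forward` in `base.py` passes `inputs` and `data_samples` positionally, Python prepends `self`, and the `_forward` that actually runs accepts only `self`, hence "1 positional argument but 3 were given". The classes below are hypothetical stand-ins (not mmdet3d code) that reproduce the same TypeError in plain Python, only to illustrate that call-shape mismatch.

```python
# Hypothetical stand-ins, NOT the mmdet3d classes: they only mimic the
# call pattern visible in the traceback (forward -> self._forward(...)).


class BaseDetectorSketch:
    def forward(self, inputs, data_samples=None, **kwargs):
        # Same call shape as the base.py frame in the traceback:
        # two positional arguments plus any keyword arguments.
        return self._forward(inputs, data_samples, **kwargs)


class MismatchedDetector(BaseDetectorSketch):
    def _forward(self):
        # Accepts only `self`, so the call above supplies 3 positional
        # arguments (self, inputs, data_samples) where 1 is expected.
        return "unreachable"


class MatchingDetector(BaseDetectorSketch):
    def _forward(self, inputs, data_samples=None, **kwargs):
        # Signature lines up with the caller, so the call succeeds.
        return inputs


def try_forward(model, batch):
    """Call forward() and return either its result or the TypeError text."""
    try:
        return model.forward(batch)
    except TypeError as err:
        return str(err)


if __name__ == "__main__":
    print(try_forward(MismatchedDetector(), [1, 2, 3]))
    print(try_forward(MatchingDetector(), [1, 2, 3]))
```

Comparing the signature of the `_forward` the PointPillars model actually resolves (for example via `inspect.signature(model._forward)`) against this call shape is one way to check whether that reading holds.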

Additional information

This looks like a bug in the code. I get the same error with the KITTI config, configs/pointpillars/pointpillars_hv_secfpn_8xb6-160e_kitti-3d-3class.py.