open-mmlab / mmtracking

OpenMMLab Video Perception Toolbox. It supports Video Object Detection (VID), Multiple Object Tracking (MOT), Single Object Tracking (SOT), Video Instance Segmentation (VIS) with a unified framework.
https://mmtracking.readthedocs.io/en/latest/
Apache License 2.0

RuntimeError: cannot reshape tensor of 0 elements into shape [0, 16, -1] because the unspecified dimension size -1 can be any value and is ambiguous #387

Open zhangy0210 opened 2 years ago

zhangy0210 commented 2 years ago

I am training a video object detection model on ILSVRC2017. When I run this command:

python tools/train.py configs/vid/temporal_roi_align/selsa_troialign_faster_rcnn_r50_dc5_7e_imagenetvid.py

it fails with the following error:

Traceback (most recent call last):
  File "tools/train.py", line 182, in <module>
    main()
  File "tools/train.py", line 178, in main
    meta=meta)
  File "/home/featurize/work/mmtracking/mmtrack/apis/train.py", line 175, in train_model
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
  File "/environment/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/mmcv/runner/epoch_based_runner.py", line 127, in run
    epoch_runner(data_loaders[i], **kwargs)
  File "/environment/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/mmcv/runner/epoch_based_runner.py", line 50, in train
    self.run_iter(data_batch, train_mode=True, **kwargs)
  File "/environment/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/mmcv/runner/epoch_based_runner.py", line 30, in run_iter
    **kwargs)
  File "/environment/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/mmcv/parallel/data_parallel.py", line 75, in train_step
    return self.module.train_step(*inputs[0], **kwargs[0])
  File "/home/featurize/work/mmtracking/mmtrack/models/vid/base.py", line 265, in train_step
    losses = self(**data)
  File "/environment/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/environment/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/mmcv/runner/fp16_utils.py", line 98, in new_func
    return old_func(*args, **kwargs)
  File "/home/featurize/work/mmtracking/mmtrack/models/vid/base.py", line 194, in forward
    **kwargs)
  File "/home/featurize/work/mmtracking/mmtrack/models/vid/selsa.py", line 166, in forward_train
    gt_labels, gt_bboxes_ignore, gt_masks, **kwargs)
  File "/home/featurize/work/mmtracking/mmtrack/models/roi_heads/selsa_roi_head.py", line 66, in forward_train
    gt_bboxes, gt_labels)
  File "/home/featurize/work/mmtracking/mmtrack/models/roi_heads/selsa_roi_head.py", line 104, in _bbox_forward_train
    bbox_results = self._bbox_forward(x, ref_x, rois, ref_rois)
  File "/home/featurize/work/mmtracking/mmtrack/models/roi_heads/selsa_roi_head.py", line 93, in _bbox_forward
    cls_score, bbox_pred = self.bbox_head(bbox_feats, ref_bbox_feats)
  File "/environment/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/featurize/work/mmtracking/mmtrack/models/roi_heads/bbox_heads/selsa_bbox_head.py", line 57, in forward
    x = x + self.aggregator[i](x, ref_x)
  File "/environment/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/featurize/work/mmtracking/mmtrack/models/aggregators/selsa_aggregator.py", line 62, in forward
    -1).permute(1, 2, 0)
RuntimeError: cannot reshape tensor of 0 elements into shape [0, 16, -1] because the unspecified dimension size -1 can be any value and is ambiguous
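
For readers unfamiliar with this message: the shape [0, 16, -1] in the error suggests that a tensor with zero rows reached the reshape in selsa_aggregator.py. The sketch below is not PyTorch's implementation. It is a pure-Python illustration, using a hypothetical helper `infer_unknown_dim`, of why a `-1` dimension cannot be resolved when the tensor is empty: the unknown size x must satisfy 0 = 0 * 16 * x, which holds for every x. The sizes 32 and 1024 are made up for the example.

```python
def infer_unknown_dim(numel, shape):
    """Hypothetical helper: resolve the single -1 in `shape` for `numel` elements.

    Mimics the general rule a reshape has to follow; it is an illustration,
    not the code PyTorch actually runs.
    """
    known = 1
    for dim in shape:
        if dim != -1:
            known *= dim

    if known == 0:
        if numel == 0:
            # 0 == 0 * x is true for every x, so -1 has no unique value.
            raise ValueError(
                f"cannot reshape {numel} elements into {list(shape)}: "
                "the -1 dimension is ambiguous"
            )
        raise ValueError(f"cannot reshape {numel} elements into {list(shape)}")

    if numel % known != 0:
        raise ValueError(f"cannot reshape {numel} elements into {list(shape)}")
    return numel // known


# Non-empty case: 32 rows of 1024 values split into 16 groups per row.
print(infer_unknown_dim(32 * 1024, (32, 16, -1)))

# Empty case: zero rows, same target shape pattern as in the traceback.
try:
    infer_unknown_dim(0, (0, 16, -1))
except ValueError as err:
    print(err)
```

If this reading is right, the question to investigate is why the tensor passed to the aggregator has no rows for some training batch.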

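I installed the environment dependencies exactly as described in the official documentation.

# Documentation
https://mmtracking.readthedocs.io/zh_CN/latest/install.html
# Environment
ubuntu20.04 / GTX3060 / CUDA-11.1

The final verification step from the documentation also ran successfully: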

python demo/demo_mot_vis.py configs/mot/deepsort/sort_faster-rcnn_fpn_4e_mot17-private.py --input demo/demo.mp4 --output mot.mp4

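# btw, this command differs between the Chinese and English docs. The English version runs, but the Chinese one does not: the Chinese docs write demo_mot_vis.py as demo_mot.py, and no such file exists in the demo folder.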

For the dataset, I downloaded all of the ILSVRC2017 archives:

ILSVRC2017_DET.tar.gz
ILSVRC2017_DET_test_new.tar.gz
ILSVRC2017_VID_new.tar.gz
ILSVRC2017_VID.tar.gz
ILSVRC2017_VID_test.tar.gz

I also downloaded the four txt files from the Lists folder linked in the documentation:

DET_train_30classes.txt
VID_train_15frames.txt
VID_val_frames.txt
VID_val_videos.txt

I converted the data to COCOVID format with the conversion scripts given in the documentation:

# ImageNet DET
python ./tools/convert_datasets/ilsvrc/imagenet2coco_det.py -i ./data/ILSVRC -o ./data/ILSVRC/annotations

# ImageNet VID
python ./tools/convert_datasets/ilsvrc/imagenet2coco_vid.py -i ./data/ILSVRC -o ./data/ILSVRC/annotations

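I am not sure whether this is a dataset problem. To run the VID task, exactly which files need to be downloaded?

If it is not a dataset problem, please try to reproduce and fix this issue.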
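
One way to look at the dataset side of this question is to inspect the converted annotation files directly. The following is a minimal sketch, assuming the converter output follows the usual COCO-style layout with top-level `images` and `annotations` lists linked by `image_id`. The function names are hypothetical and this is not a tool shipped with mmtracking. It only reports how many images have no annotations; it does not establish whether that is what triggers the error above.

```python
import json
from collections import Counter


def summarize_coco_annotations(coco):
    """Summarize a COCO-style dict: image count, annotation count, empty images."""
    # Number of annotations attached to each image id.
    per_image = Counter(ann["image_id"] for ann in coco.get("annotations", []))
    images = coco.get("images", [])
    empty = [img["id"] for img in images if per_image[img["id"]] == 0]
    return {
        "num_images": len(images),
        "num_annotations": sum(per_image.values()),
        "images_without_annotations": len(empty),
    }


def summarize_annotation_file(path):
    """Load a JSON annotation file from `path` and summarize it."""
    with open(path, "r", encoding="utf-8") as f:
        return summarize_coco_annotations(json.load(f))


# Tiny in-memory example: image 1 has two boxes, images 2 and 3 have none.
example = {
    "images": [{"id": 1}, {"id": 2}, {"id": 3}],
    "annotations": [{"id": 10, "image_id": 1}, {"id": 11, "image_id": 1}],
}
print(summarize_coco_annotations(example))
```

To apply it to real data, call `summarize_annotation_file` on one of the JSON files written to ./data/ILSVRC/annotations by the conversion scripts.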

Finally, here is my conda environment as a yml file, created with conda env export > my-environment.yml

You can replicate my environment with conda env create -f my-environment.yml

name: open-mmlab
channels:
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=4.5=1_gnu
  - attrs=21.2.0=pyhd3eb1b0_0
  - ca-certificates=2021.10.26=h06a4308_2
  - certifi=2021.10.8=py37h06a4308_0
  - ld_impl_linux-64=2.35.1=h7274673_9
  - libffi=3.3=he6710b0_2
  - libgcc-ng=9.3.0=h5101ec6_17
  - libgomp=9.3.0=h5101ec6_17
  - libstdcxx-ng=9.3.0=hd4cf53a_17
  - ncurses=6.3=h7f8727e_2
  - openssl=1.1.1l=h7f8727e_0
  - pip=21.2.2=py37h06a4308_0
  - python=3.7.11=h12debd9_0
  - readline=8.1=h27cfd23_0
  - setuptools=58.0.4=py37h06a4308_0
  - sqlite=3.37.0=hc218d9a_0
  - tk=8.6.11=h1ccaba5_0
  - wheel=0.37.0=pyhd3eb1b0_1
  - xz=5.2.5=h7b6447c_0
  - zlib=1.2.11=h7f8727e_4
  - pip:
    - addict==2.4.0
    - attributee==0.1.3
    - bidict==0.21.4
    - cachetools==5.0.0
    - charset-normalizer==2.0.9
    - colorama==0.4.4
    - cycler==0.11.0
    - cython==0.29.26
    - dominate==2.6.0
    - dotty-dict==1.3.0
    - flake8==4.0.1
    - flake8-import-order==0.18.1
    - fonttools==4.28.5
    - idna==3.3
    - importlib-metadata==4.2.0
    - importlib-resources==5.4.0
    - iniconfig==1.1.1
    - jsonschema==4.3.3
    - kiwisolver==1.3.2
    - lap==0.4.0
    - llvmlite==0.37.0
    - matplotlib==3.5.1
    - mccabe==0.6.1
    - mmcls==0.19.0
    - mmcv-full==1.4.2
    - mmdet==2.20.0
    - motmetrics==1.2.0
    - numba==0.54.1
    - numpy==1.20.3
    - opencv-python==4.5.5.62
    - ordered-set==4.0.2
    - packaging==21.3
    - pandas==1.3.5
    - phx-class-registry==3.0.5
    - pillow==9.0.0
    - pluggy==1.0.0
    - py==1.11.0
    - py-cpuinfo==8.0.0
    - pycocotools==2.0.3
    - pycodestyle==2.8.0
    - pyflakes==2.4.0
    - pylatex==1.4.1
    - pyparsing==3.0.6
    - pyrsistent==0.18.0
    - pytest==6.2.5
    - pytest-benchmark==3.4.1
    - python-dateutil==2.8.2
    - pytz==2021.3
    - pyyaml==6.0
    - requests==2.27.0
    - scipy==1.7.3
    - seaborn==0.11.2
    - setuptools-scm==6.3.2
    - six==1.16.0
    - terminaltables==3.1.10
    - toml==0.10.2
    - tomli==2.0.0
    - torch==1.9.0+cu111
    - torchvision==0.10.0+cu111
    - tqdm==4.62.3
    - typing-extensions==4.0.1
    - urllib3==1.26.7
    - vot-toolkit==0.5.1
    - vot-trax==3.0.3
    - xmltodict==0.12.0
    - yapf==0.32.0
    - zipp==3.7.0
prefix: /environment/miniconda3/envs/open-mmlab

Thanks

GT9505 commented 2 years ago

Please refer to https://github.com/open-mmlab/mmtracking/pull/375