open-mmlab / mmdetection3d

OpenMMLab's next-generation platform for general 3D object detection.
https://mmdetection3d.readthedocs.io/en/latest/
Apache License 2.0
5.31k stars, 1.54k forks

Training centerpoint model on Kitti lidar data #190

Closed: YoushaaMurhij closed this 3 years ago

YoushaaMurhij commented 4 years ago

I am facing this error while trying to train CenterPoint on KITTI data:

python tools/train.py configs/centerpoint/centerpoint_kitti.py \
    --gpu-ids 0 --work-dir ./data/kitti/train_logs

Traceback (most recent call last):
  File "tools/train.py", line 166, in <module>
    main()
  File "tools/train.py", line 162, in main
    meta=meta)
  File "/opt/conda/lib/python3.7/site-packages/mmdet/apis/train.py", line 150, in train_detector
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
  File "/opt/conda/lib/python3.7/site-packages/mmcv/runner/epoch_based_runner.py", line 125, in run
    epoch_runner(data_loaders[i], **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/mmcv/runner/epoch_based_runner.py", line 47, in train
    for i, data_batch in enumerate(self.data_loader):
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 363, in __next__
    data = self._next_data()
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 989, in _next_data
    return self._process_data(data)
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1014, in _process_data
    data.reraise()
  File "/opt/conda/lib/python3.7/site-packages/torch/_utils.py", line 395, in reraise
    raise self.exc_type(msg)
KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 185, in _worker_loop
    data = fetcher.fetch(index)
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/opt/conda/lib/python3.7/site-packages/mmdet/datasets/dataset_wrappers.py", line 151, in __getitem__
    return self.dataset[idx % self._ori_len]
  File "/mmdetection3d/mmdet3d/datasets/custom_3d.py", line 294, in __getitem__
    data = self.prepare_train_data(idx)
  File "/mmdetection3d/mmdet3d/datasets/custom_3d.py", line 147, in prepare_train_data
    input_dict = self.get_data_info(index)
  File "/mmdetection3d/mmdet3d/datasets/kitti_dataset.py", line 111, in get_data_info
    info = self.data_infos[index]
KeyError: 7

I am using this modified config file (I only changed the dataset and the base files):

_base_ = [
    '../_base_/datasets/kitti-3d-3class.py',
    '../_base_/models/centerpoint_01voxel_second_secfpn_nus.py',
    '../_base_/schedules/cyclic_20e.py', '../_base_/default_runtime.py'
]

# If point cloud range is changed, the models should also change their point
# cloud range accordingly
class_names = ['Pedestrian', 'Cyclist', 'Car']
point_cloud_range = [0, -40, -3, 70.4, 40, 1]
input_modality = dict(use_lidar=True, use_camera=False)

model = dict(
    pts_voxel_layer=dict(point_cloud_range=point_cloud_range),
    pts_bbox_head=dict(bbox_coder=dict(pc_range=point_cloud_range[:2])))
# model training and testing settings
train_cfg = dict(pts=dict(point_cloud_range=point_cloud_range))
test_cfg = dict(pts=dict(pc_range=point_cloud_range[:2]))
dataset_type = 'KittiDataset'
data_root = 'data/kitti/'
file_client_args = dict(backend='disk')

db_sampler = dict(
    data_root=data_root,
    info_path=data_root + 'kitti_dbinfos_train.pkl',
    rate=1.0,
    prepare=dict(
        filter_by_difficulty=[-1],
        filter_by_min_points=dict(
            Car=5, Pedestrian=10, Cyclist=10)),
    classes=class_names,
    sample_groups=dict(
        Car=12, Pedestrian=6, Cyclist=6),
    points_loader=dict(
        type='LoadPointsFromFile',
        load_dim=4,
        use_dim=4,
        file_client_args=file_client_args))

train_pipeline = [
    dict(
        type='LoadPointsFromFile',
        load_dim=4,
        use_dim=4,
        file_client_args=file_client_args),
    dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
    dict(type='ObjectSample', db_sampler=db_sampler),
    dict(
        type='GlobalRotScaleTrans',
        rot_range=[-0.78539816, 0.78539816],
        scale_ratio_range=[0.95, 1.05],
        translation_std=[1.0, 1.0, 0.5]),
    dict(
        type='RandomFlip3D',
        sync_2d=False,
        flip_ratio_bev_horizontal=0.5,
        flip_ratio_bev_vertical=0.5),
    dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range),
    dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
    dict(type='ObjectNameFilter', classes=class_names),
    dict(type='PointShuffle'),
    dict(type='DefaultFormatBundle3D', class_names=class_names),
    dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
]
test_pipeline = [
    dict(
        type='LoadPointsFromFile',
        load_dim=4,
        use_dim=4,
        file_client_args=file_client_args),
    dict(
        type='MultiScaleFlipAug3D',
        img_scale=(1333, 800),
        pts_scale_ratio=1,
        flip=False,
        transforms=[
            dict(
                type='GlobalRotScaleTrans',
                rot_range=[0, 0],
                scale_ratio_range=[1., 1.],
                translation_std=[0, 0, 0]),
            dict(type='RandomFlip3D'),
            dict(
                type='PointsRangeFilter', point_cloud_range=point_cloud_range),
            dict(
                type='DefaultFormatBundle3D',
                class_names=class_names,
                with_label=False),
            dict(type='Collect3D', keys=['points'])
        ])
]

data = dict(
    train=dict(
        type='RepeatDataset',
        dataset=dict(
            type=dataset_type,
            data_root=data_root,
            ann_file=data_root + 'kitti_dbinfos_train.pkl',
            pipeline=train_pipeline,
            classes=class_names,
            test_mode=False,
            #use_valid_flag=True,
            # we use box_type_3d='LiDAR' in kitti and nuscenes dataset
            # and box_type_3d='Depth' in sunrgbd and scannet dataset.
            box_type_3d='LiDAR')),
    val=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file=data_root + 'kitti_infos_val.pkl',
        split='training',
        pts_prefix='velodyne_reduced',
        pipeline=test_pipeline,
        modality=input_modality,
        classes=class_names,
        test_mode=True,
        box_type_3d='LiDAR'),
    test=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file=data_root + 'kitti_infos_val.pkl',
        split='training',
        pts_prefix='velodyne_reduced',
        pipeline=test_pipeline,
        modality=input_modality,
        classes=class_names,
        test_mode=True,
        box_type_3d='LiDAR'))

evaluation = dict(interval=1)

Tai-Wang commented 4 years ago

Your modification is actually incomplete. There are many differences between the KITTI and nuScenes datasets, such as the object classes, the number of point features, whether lidar data from consecutive frames is used, and the valid point cloud range (front view only in KITTI). You need to look into these differences more carefully and refer to the other KITTI configs when modifying the config you intend to use.
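
One concrete mismatch is visible in the config posted above: the training ann_file points to kitti_dbinfos_train.pkl, the ground-truth database that ObjectSample uses, whereas the other KITTI configs point it to kitti_infos_train.pkl. Below is a minimal sketch of why that would produce KeyError: 7, assuming the database file holds a dict keyed by class name and the infos file holds a list of per-frame dicts. The stand-in data and the get_data_info function are made up for illustration and are not read from the real files.

```python
# Stand-ins for the two pickle files (hypothetical contents, for illustration only).
dbinfos = {  # assumed layout of kitti_dbinfos_train.pkl: class name -> list of objects
    "Car": [{"name": "Car", "num_points_in_gt": 120}],
    "Pedestrian": [{"name": "Pedestrian", "num_points_in_gt": 35}],
}
infos = [{"sample_idx": i} for i in range(10)]  # assumed layout of kitti_infos_train.pkl


def get_data_info(data_infos, index):
    """Mimics the failing line in the traceback: info = self.data_infos[index]."""
    return data_infos[index]


error_message = None
try:
    get_data_info(dbinfos, 7)
except KeyError as err:
    # An integer sample index is not a key of the class-name dict.
    error_message = f"KeyError: {err}"

frame = get_data_info(infos, 7)  # a list of per-frame infos can be indexed by position
print(error_message, frame)
```

If this is indeed the cause, pointing ann_file at kitti_infos_train.pkl and keeping kitti_dbinfos_train.pkl only in db_sampler.info_path should get past this particular error.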

tianweiy commented 4 years ago

@YoushaaMurhij Thanks for the interest. You basically need to follow the SECOND config for the KITTI-specific arguments. The network needs some tweaks, and if I remember correctly you also need to remove the velocity target in training. In my experience it reaches about the same accuracy as SECOND for R11. We are working on a major revision of the method and hope to release more results on KITTI in a few months.
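
A rough sketch of what removing the velocity target could look like in the config. The key names (common_heads, vel, bbox_coder.code_size, code_weights) and the values 9 and 10 are what I recall from the nuScenes CenterPoint base config, so treat them as assumptions and check them against your own files. drop_velocity is a hypothetical helper, not part of mmdetection3d.

```python
import copy

# Regression branches of the nuScenes-style head (names assumed, check your base config).
nus_head = dict(
    common_heads=dict(reg=(2, 2), height=(1, 2), dim=(3, 2), rot=(2, 2), vel=(2, 2)),
    bbox_coder=dict(code_size=9),
)
nus_code_weights = [1.0] * 8 + [0.2, 0.2]  # assumed: the last two entries weight vx, vy


def drop_velocity(head, code_weights):
    """Hypothetical helper: return copies of the settings without the velocity target."""
    head = copy.deepcopy(head)
    if head["common_heads"].pop("vel", None) is None:
        return head, list(code_weights)    # nothing to remove
    head["bbox_coder"]["code_size"] -= 2   # box code no longer carries vx, vy
    return head, list(code_weights[:-2])   # drop the two velocity loss weights


kitti_head, kitti_code_weights = drop_velocity(nus_head, nus_code_weights)
print(sorted(kitti_head["common_heads"]),
      kitti_head["bbox_coder"]["code_size"],
      len(kitti_code_weights))
```

The point of the sketch is that the three settings have to change together: the head branch, the size of the box code, and the length of the loss weight list.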

YoushaaMurhij commented 4 years ago

Thanks for your response! I will take that into consideration.

hzh8311 commented 3 years ago

Is there any progress?

I tried to train CenterPoint on the KITTI dataset. I followed the PointPillars KITTI configuration for the data, the train/val pipelines, the (PointPillars) voxelization, the backbone, and the neck.

[image]

To make this work with CenterPoint, the bbox_head configuration comes from the original implementation, with the minor modifications listed below:

[image]

However, the trained model gets ~40 mAP on the KITTI validation set, and the AOS is close to zero. Any suggestions?

yaodongC commented 3 years ago

Is there any progress?

I tried to train CenterPoint on the KITTI dataset. I followed the PointPillars KITTI configuration for the data, the train/val pipelines, the (PointPillars) voxelization, the backbone, and the neck.

[image]

To make this work with CenterPoint, the bbox_head configuration comes from the original implementation, with the minor modifications listed below:

[image]

However, the trained model gets ~40 mAP on the KITTI validation set, and the AOS is close to zero. Any suggestions?

Note: KITTI lidar frame is 90 degrees rotated from nuScenes lidar frame
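
To illustrate how a frame mismatch like this feeds into AOS, here is a small sketch using only the standard library. It rotates a box centre and heading by 90 degrees about the z axis and evaluates the per-detection orientation term from the KITTI AOS definition, (1 + cos(delta)) / 2. The sign of the rotation is an assumption, and rotate_z and orientation_similarity are illustrative helpers, not mmdetection3d functions.

```python
import math


def rotate_z(x, y, yaw, angle):
    """Rotate a box centre (x, y) and its heading yaw by `angle` radians about z."""
    c, s = math.cos(angle), math.sin(angle)
    return c * x - s * y, s * x + c * y, yaw + angle


def orientation_similarity(delta):
    """Per-detection term of KITTI's AOS for a heading error of `delta` radians."""
    return (1.0 + math.cos(delta)) / 2.0


# The sign of the 90 degree offset is an assumption; check it against your data.
x, y, yaw = rotate_z(10.0, 0.0, 0.0, math.pi / 2)
sims = {deg: orientation_similarity(math.radians(deg)) for deg in (0, 90, 180)}
print((round(x, 6), round(y, 6), round(yaw, 6)), sims)
```

Under this formula a constant quarter-turn heading error still scores about half, while a half-turn error scores zero. An AOS near zero would therefore suggest that most predicted headings are reversed, which is worth checking in the rotation targets and their decoding.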

Yaziwel commented 3 years ago

How can the problem of very low AOS scores be solved?