dvlab-research / VoxelNeXt

VoxelNeXt: Fully Sparse VoxelNet for 3D Object Detection and Tracking (CVPR 2023)
https://arxiv.org/abs/2303.11301
Apache License 2.0

Use VoxelNeXt to train on a custom dataset #4

Closed: luoxiaoliaolan closed this issue 1 year ago

luoxiaoliaolan commented 1 year ago

Hello Yukang Chen, thank you for your excellent work on 3D detection and tracking. I studied this repo as soon as it was released, and I am very interested in the new VoxelNeXt backbone. The network structure is concise and similar to Voxel R-CNN. I tried training VoxelNeXt on my own dataset to see how it performs. My dataset uses a KITTI-like format: each point in a frame has four dimensions (x, y, z, intensity), and the annotations cover 11 classes. I have completed the custom dataloader modules; the data loads normally for training, and preprocessing generates the corresponding pkl files.

However, I hit the following error early in training:

(screenshot of the error traceback)

`assert boxes_a.shape[0] == boxes_b.shape[0]`

I checked the training progress on my data: the error did not occur at the very beginning, but when a particular frame was processed, where `boxes_a.shape[0]` was 142 and `boxes_b.shape[0]` was 143. This mismatch triggered the assertion and broke the IoU loss calculation. I suspected a data issue, but it only happens in a very small number of frames. So I wondered whether I could skip these cases and set the loss to 0, without affecting the loss calculation for the other, normal frames, and modified the code as follows:

(screenshots of the modified code)

After these modifications, training proceeded normally, but I don't know whether the changes introduced any errors. After training on a batch of data today, I obtained a trained model. However, when I loaded it for inference, the scores of the detection boxes were extremely low, all below 0.1.
As a result, when I set the post-processing score threshold above 0.1, no detections were output at all. I'm not sure what causes this, so I would appreciate your help analyzing it. I will keep following your project and hope to have more discussions with you. Below are my model and data configuration files.

voxelnext_ioubranch.yaml:

```yaml
CLASS_NAMES: ['car', 'pedestrian', 'cyclist', 'tricyclist', 'bus', 'truck', 'special_vehicle', 'traffic_cone', 'small_obstacle', 'traffic_facilities', 'other']

DATA_CONFIG:
    _BASE_CONFIG_: cfgs/dataset_configs/at128_dataset.yaml
    OUTPUT_PATH: '/lpai/output/models'

MODEL:
    NAME: VoxelNeXt

    VFE:
        NAME: MeanVFE

    BACKBONE_3D:
        NAME: VoxelResBackBone8xVoxelNeXt

    DENSE_HEAD:
        NAME: VoxelNeXtHead
        IOU_BRANCH: True
        CLASS_AGNOSTIC: False
        INPUT_FEATURES: 128

        CLASS_NAMES_EACH_HEAD: [
            ['car', 'pedestrian', 'cyclist', 'tricyclist', 'bus', 'truck', 'special_vehicle', 'traffic_cone',
              'small_obstacle', 'traffic_facilities', 'other']
        ]

        SHARED_CONV_CHANNEL: 128
        USE_BIAS_BEFORE_NORM: True
        NUM_HM_CONV: 2
        SEPARATE_HEAD_CFG:
            HEAD_ORDER: ['center', 'center_z', 'dim', 'rot']
            HEAD_DICT: {
                'center': {'out_channels': 2, 'num_conv': 2},
                'center_z': {'out_channels': 1, 'num_conv': 2},
                'dim': {'out_channels': 3, 'num_conv': 2},
                'rot': {'out_channels': 2, 'num_conv': 2},
                'iou': {'out_channels': 1, 'num_conv': 2},
            }
        RECTIFIER: [0.68, 0.71, 0.65, 0.5, 0.6, 0.67, 0.45, 0.4, 0.46, 0.5, 0.5]
        TARGET_ASSIGNER_CONFIG:
            FEATURE_MAP_STRIDE: 8
            NUM_MAX_OBJS: 500
            GAUSSIAN_OVERLAP: 0.1
            MIN_RADIUS: 2

        LOSS_CONFIG:
            LOSS_WEIGHTS: {
                'cls_weight': 1.0,
                'loc_weight': 2.0,
                'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
            }

        POST_PROCESSING:
            SCORE_THRESH: 0.5
            POST_CENTER_LIMIT_RANGE: [0.0, -40, -2, 150.4, 40, 4]
            MAX_OBJ_PER_SAMPLE: 500
            NMS_CONFIG:
                NMS_TYPE: nms_gpu
                # NMS_THRESH: [0.8, 0.55, 0.55] #0.7
                NMS_THRESH: [0.5, 0.3, 0.3, 0.3, 0.3, 0.5, 0.5, 0.5, 0.4, 0.3, 0.3]
                # NMS_PRE_MAXSIZE: [2048, 1024, 1024] #[4096]
                NMS_PRE_MAXSIZE: [2048, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024]
                # NMS_POST_MAXSIZE: [200, 150, 150] #500
                NMS_POST_MAXSIZE: [500, 300, 400, 400, 200, 200, 200, 100, 100, 100, 100]

    POST_PROCESSING:
        RECALL_THRESH_LIST: [0.5, 0.3, 0.3, 0.3, 0.3, 0.5, 0.5, 0.5, 0.4, 0.3, 0.3]

        EVAL_METRIC: kitti

OPTIMIZATION:
    BATCH_SIZE_PER_GPU: 26
    NUM_EPOCHS: 50

    OPTIMIZER: adam_onecycle
    LR: 0.003
    WEIGHT_DECAY: 0.01
    MOMENTUM: 0.9

    MOMS: [0.95, 0.85]
    PCT_START: 0.4
    DIV_FACTOR: 10
    DECAY_STEP_LIST: [35, 45]
    LR_DECAY: 0.1
    LR_CLIP: 0.0000001

    LR_WARMUP: False
    WARMUP_EPOCH: 1

    GRAD_NORM_CLIP: 10
```
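Several keys in the DENSE_HEAD section above appear to be per-class lists (RECTIFIER and the NMS lists each have 11 entries for 11 classes, replacing commented-out 3-entry values), and `code_weights` should have one weight per regression channel in HEAD_ORDER (2 + 1 + 3 + 2 = 8). When adapting a config to a new class list, a quick consistency check can catch mismatches. The sketch below is a hypothetical helper, `check_head_config`, that works on the DENSE_HEAD section as a plain dict (for example as loaded by a YAML parser); the assumption that these lists are indexed per class is inferred from the config, not confirmed by the repository.

```python
from typing import Any, Dict, List


def check_head_config(dense_head: Dict[str, Any]) -> List[str]:
    """Return human-readable problems found in a DENSE_HEAD dict (empty list if none).

    Hypothetical helper; key names follow the config posted in this issue.
    """
    problems: List[str] = []
    num_classes = sum(len(names) for names in dense_head["CLASS_NAMES_EACH_HEAD"])

    # Assumption: these lists hold one value per class.
    nms_cfg = dense_head["POST_PROCESSING"]["NMS_CONFIG"]
    per_class = {
        "RECTIFIER": dense_head.get("RECTIFIER"),
        "NMS_THRESH": nms_cfg.get("NMS_THRESH"),
        "NMS_PRE_MAXSIZE": nms_cfg.get("NMS_PRE_MAXSIZE"),
        "NMS_POST_MAXSIZE": nms_cfg.get("NMS_POST_MAXSIZE"),
    }
    for key, value in per_class.items():
        if isinstance(value, list) and len(value) != num_classes:
            problems.append(f"{key}: {len(value)} entries for {num_classes} classes")

    # code_weights: one weight per regression channel of the heads in HEAD_ORDER.
    sep = dense_head["SEPARATE_HEAD_CFG"]
    reg_channels = sum(sep["HEAD_DICT"][name]["out_channels"] for name in sep["HEAD_ORDER"])
    code_weights = dense_head["LOSS_CONFIG"]["LOSS_WEIGHTS"]["code_weights"]
    if len(code_weights) != reg_channels:
        problems.append(f"code_weights: {len(code_weights)} entries for {reg_channels} channels")

    # The 'iou' head entry and the IOU_BRANCH flag are toggled together.
    if ("iou" in sep["HEAD_DICT"]) != bool(dense_head.get("IOU_BRANCH", False)):
        problems.append("IOU_BRANCH and the 'iou' entry in HEAD_DICT disagree")
    return problems


# Values copied from the DENSE_HEAD section above, abridged to the checked keys.
classes = ['car', 'pedestrian', 'cyclist', 'tricyclist', 'bus', 'truck', 'special_vehicle',
           'traffic_cone', 'small_obstacle', 'traffic_facilities', 'other']
head = {
    "IOU_BRANCH": True,
    "CLASS_NAMES_EACH_HEAD": [classes],
    "SEPARATE_HEAD_CFG": {
        "HEAD_ORDER": ["center", "center_z", "dim", "rot"],
        "HEAD_DICT": {
            "center": {"out_channels": 2, "num_conv": 2},
            "center_z": {"out_channels": 1, "num_conv": 2},
            "dim": {"out_channels": 3, "num_conv": 2},
            "rot": {"out_channels": 2, "num_conv": 2},
            "iou": {"out_channels": 1, "num_conv": 2},
        },
    },
    "RECTIFIER": [0.68, 0.71, 0.65, 0.5, 0.6, 0.67, 0.45, 0.4, 0.46, 0.5, 0.5],
    "LOSS_CONFIG": {"LOSS_WEIGHTS": {"code_weights": [1.0] * 8}},
    "POST_PROCESSING": {"NMS_CONFIG": {
        "NMS_THRESH": [0.5, 0.3, 0.3, 0.3, 0.3, 0.5, 0.5, 0.5, 0.4, 0.3, 0.3],
        "NMS_PRE_MAXSIZE": [2048] + [1024] * 10,
        "NMS_POST_MAXSIZE": [500, 300, 400, 400, 200, 200, 200, 100, 100, 100, 100],
    }},
}
print(check_head_config(head))  # []
```

By this check the posted config is internally consistent, so the list lengths are unlikely to explain the low scores.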

custom_dataset.yaml:

```yaml
DATASET: 'At128Dataset'
DATA_PATH: '/lpai/volumes/lpai-autopilot-root/autopilot-cloud/lidar_at128_dataset_1W/autopilot-cloud/lidar_at128_dataset_1W'

POINT_CLOUD_RANGE: [0.0, -40, -2, 150.4, 40, 4]

DATA_SPLIT: {
    'train': train,
    'test': val
}

INFO_PATH: {
    'train': [at128_infos_train.pkl],
    'test': [at128_infos_val.pkl],
}

BALANCED_RESAMPLING: True

GET_ITEM_LIST: ["points"]

DATA_AUGMENTOR:
    DISABLE_AUG_LIST: ['placeholder']
    AUG_CONFIG_LIST:

POINT_FEATURE_ENCODING: {
    encoding_type: absolute_coordinates_encoding,
    used_feature_list: ['x', 'y', 'z', 'intensity'],
    src_feature_list: ['x', 'y', 'z', 'intensity'],
}

DATA_PROCESSOR:
```
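The workaround described in the issue (skipping the IoU loss and using 0 whenever `boxes_a` and `boxes_b` differ in length) can be expressed as a small guard. The sketch below is pure Python and is not the repository's loss code: `guarded_pair_loss` is a hypothetical helper, and `center_l1` is a toy center-distance loss standing in for the real 3D IoU loss. It also returns a flag so skipped frames can be counted, because a guard like this hides the underlying mismatch rather than fixing it.

```python
from typing import Callable, Sequence, Tuple

# A 3D box as (x, y, z, dx, dy, dz, heading); this layout is an assumption.
Box = Sequence[float]


def guarded_pair_loss(
    pred_boxes: Sequence[Box],
    gt_boxes: Sequence[Box],
    pair_loss: Callable[[Box, Box], float],
) -> Tuple[float, bool]:
    """Mean pair_loss over aligned (pred, gt) pairs; 0.0 when the counts differ.

    Hypothetical helper: replaces an `assert len(a) == len(b)` with a skip.
    Returns (loss, skipped) so the caller can count skipped frames.
    """
    if len(pred_boxes) != len(gt_boxes):
        # In PyTorch, prefer `pred.sum() * 0.0` over a constant 0 so the IoU
        # head stays in the autograd graph; DistributedDataParallel can
        # otherwise complain about parameters that received no gradient.
        return 0.0, True
    if not pred_boxes:
        return 0.0, False
    total = sum(pair_loss(p, g) for p, g in zip(pred_boxes, gt_boxes))
    return total / len(pred_boxes), False


def center_l1(p: Box, g: Box) -> float:
    """Toy stand-in for the real IoU loss: L1 distance between box centers."""
    return sum(abs(a - b) for a, b in zip(p[:3], g[:3]))


preds = [(0.0, 0.0, 0.0, 4.0, 2.0, 1.5, 0.0), (10.0, 0.0, 0.0, 4.0, 2.0, 1.5, 0.0)]
gts = [(1.0, 0.0, 0.0, 4.0, 2.0, 1.5, 0.0), (10.0, 2.0, 0.0, 4.0, 2.0, 1.5, 0.0)]
print(guarded_pair_loss(preds, gts, center_l1))      # (1.5, False)
print(guarded_pair_loss(preds, gts[:1], center_l1))  # (0.0, True)
```

Logging the skipped count per epoch shows how often the mismatch actually occurs; if it is more than a handful of frames, the frames that trigger it are worth inspecting directly.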

yukang2017 commented 1 year ago

Hi Yibo,

Thanks for your interest in our work. After carefully reading your issue, I suggest two options to try.

1. Disable the IoU branch. The IoU branch is intended only for the Waymo dataset, and I think it may not help on other datasets. Set `IOU_BRANCH: False` in the config file and remove `'iou': {'out_channels': 1, 'num_conv': 2},` from `HEAD_DICT`.
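One plausible link between the IoU branch and the very low scores reported in the issue: CenterPoint-style heads with an IoU branch commonly rectify the heatmap score with the predicted IoU as `score**(1 - r) * iou**r`, using a per-class `r` (here presumably the `RECTIFIER` list). The sketch below assumes VoxelNeXtHead uses this form; `rectify_score` is a hypothetical name. If the IoU head predicts low IoU values (for example because its loss was often zeroed or it is undertrained), even a confident heatmap score gets pulled far down, while disabling the branch leaves the raw heatmap score.

```python
def rectify_score(heatmap_score: float, pred_iou: float, rectifier: float) -> float:
    """IoU-aware score rectification: score**(1 - r) * iou**r.

    Assumed to match the rectification applied when IOU_BRANCH is True, with r
    taken per class from RECTIFIER. r = 0 keeps the raw heatmap score and
    r = 1 uses the predicted IoU alone.
    """
    iou = min(max(pred_iou, 0.0), 1.0)  # clamp the predicted IoU to [0, 1]
    return heatmap_score ** (1.0 - rectifier) * iou ** rectifier


# 'car' uses r = 0.68 (first RECTIFIER entry in the config in this issue).
print(round(rectify_score(0.6, 0.05, 0.68), 2))  # 0.11  (poor IoU prediction)
print(round(rectify_score(0.6, 0.70, 0.68), 2))  # 0.67  (reasonable IoU prediction)
```

Under this assumption, a heatmap score of 0.6 paired with a predicted IoU of 0.05 ends up around 0.11, well below `SCORE_THRESH: 0.5`.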

2. Try the VoxelNeXt backbone with another head. From your comment, I think what you really like is the VoxelResBackBone8xVoxelNeXt backbone. You can pair it with another detection head, such as the Voxel R-CNN head. The only additional step is to convert the sparse output of VoxelResBackBone8xVoxelNeXt to a dense tensor, following the usual VoxelResBackBone8x setup.
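A minimal sketch of the sparse-to-dense step, assuming the backbone's final output is a 2D sparse BEV tensor described by `(batch, y, x)` indices plus one feature vector per active cell. In an spconv pipeline this is the job of `SparseConvTensor.dense()` (which OpenPCDet's HeightCompression calls for VoxelResBackBone8x); the pure-Python helper `sparse_bev_to_dense` below is hypothetical and only illustrates the scatter into the `[B][C][H][W]` layout a dense 2D head expects.

```python
from typing import List, Sequence, Tuple

# One active BEV cell: (batch_index, y, x). Layout assumed for illustration.
Index = Tuple[int, int, int]


def sparse_bev_to_dense(
    indices: Sequence[Index],
    features: Sequence[Sequence[float]],
    batch_size: int,
    height: int,
    width: int,
) -> List[List[List[List[float]]]]:
    """Scatter sparse BEV features into a dense [B][C][H][W] nested list.

    Hypothetical helper. Inactive cells stay 0.0; duplicate indices are summed.
    """
    channels = len(features[0]) if features else 0
    dense = [
        [[[0.0] * width for _ in range(height)] for _ in range(channels)]
        for _ in range(batch_size)
    ]
    for (b, y, x), feat in zip(indices, features):
        for c, value in enumerate(feat):
            dense[b][c][y][x] += value
    return dense


# Two active cells with 2-channel features, batch of 2, a 2 x 3 BEV grid.
dense = sparse_bev_to_dense(
    indices=[(0, 1, 2), (1, 0, 0)],
    features=[[1.0, 2.0], [3.0, 4.0]],
    batch_size=2, height=2, width=3,
)
print(dense[0][1][1][2], dense[1][0][0][0], dense[0][0][0][0])  # 2.0 3.0 0.0
```

Every cell the sparse backbone did not activate becomes an explicit zero, which is the memory cost that the fully sparse VoxelNeXt head avoids.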

In addition, I would be happy to discuss further via WeChat. I will contact you later.

Regards, Yukang Chen

hoangduyloc commented 1 year ago

Hi @luoxiaoliaolan, @yukang2017

I have a problem training on a custom dataset. My dataset has only LiDAR coordinates (x, y, z). It works normally with the SECOND, PointPillars, and PV-RCNN configs, but not with VoxelNeXt: I get the same errors as you. Could you guide me a bit on how to fix it?

Thank you both!