XuyangBai / TransFusion

[PyTorch] Official implementation of CVPR2022 paper "TransFusion: Robust LiDAR-Camera Fusion for 3D Object Detection with Transformers". https://arxiv.org/abs/2203.11496
Apache License 2.0

Object Query Initialization Methods #42

Open CeZh opened 2 years ago

CeZh commented 2 years ago

For both the VoxelNet and the PointPillars LiDAR-only configs, the object query initialization method is "initialization by heatmap", but the code also leaves the option of random initialization with learnable query positions. I tried training with that option, but the loss is stuck at around 4 after several epochs. Have you tried training with random initialization and learnable query positions, and if so, how many epochs did it take to converge? Thank you so much!
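For context on the two options: "initialization by heatmap" places the queries at peaks of a predicted class heatmap, while random initialization uses a fixed set of learnable query features and positions that do not depend on the input. The NumPy sketch below illustrates the heatmap variant. It assumes a `[num_classes, H, W]` heatmap of sigmoid scores, a 3×3 local-maximum filter (mirroring `nms_kernel_size=3` in the config further down), and top-K selection over all classes and cells. The function names are hypothetical and this is not the repository's code; it only shows where a per-query class label (the role of `self.query_labels` discussed later) comes from.

```python
import numpy as np


def local_max_mask(heatmap, kernel=3):
    """Boolean mask of cells equal to the max of their kernel x kernel neighborhood."""
    heatmap = np.asarray(heatmap, dtype=np.float64)
    pad = kernel // 2
    _, h, w = heatmap.shape
    padded = np.pad(heatmap, ((0, 0), (pad, pad), (pad, pad)), constant_values=-np.inf)
    neigh_max = np.full_like(heatmap, -np.inf)
    for dy in range(kernel):
        for dx in range(kernel):
            neigh_max = np.maximum(neigh_max, padded[:, dy:dy + h, dx:dx + w])
    return heatmap == neigh_max


def init_queries_from_heatmap(heatmap, num_proposals, kernel=3):
    """Pick the top-K local maxima; return (labels, (x, y) grid positions, scores)."""
    heatmap = np.asarray(heatmap, dtype=np.float64)
    _, h, w = heatmap.shape
    # Zero out non-peak cells so each object seeds roughly one query.
    peaks = np.where(local_max_mask(heatmap, kernel), heatmap, 0.0)
    flat = peaks.reshape(-1)
    top = np.argsort(-flat, kind="stable")[:num_proposals]
    labels = top // (h * w)          # class channel each query came from
    cell = top % (h * w)
    ys, xs = cell // w, cell % w     # BEV grid coordinates of the peak
    positions = np.stack([xs, ys], axis=1).astype(np.float32)
    return labels, positions, flat[top]
```

With random initialization, `labels` has no counterpart: the queries are not tied to any class channel of the heatmap.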

XuyangBai commented 2 years ago

Hi @CeZh, I did an ablation study on this in Table 6 of our paper. The model trained without the input-dependent strategy is the one with randomly initialized queries. When randomly initializing the object queries, you have to increase both the number of decoder layers with auxiliary supervision (so that the query positions are updated gradually) and the number of training epochs to get reasonable results. And if I remember correctly, I still added the heatmap loss supervision even though the heatmap was not used to initialize the queries.
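Conceptually, the objective described above supervises every decoder layer's predictions (auxiliary supervision) and adds the heatmap loss once, even when the heatmap is not used to pick queries. A minimal sketch, assuming the per-layer matched classification and box losses are already computed; the default weights are the `loss_weight` values from the config posted later in this thread (1.0 for `loss_cls`, 0.25 for `loss_bbox`, 1.0 for `loss_heatmap`), and `total_loss` is a hypothetical helper, not the repository's loss code.

```python
from typing import Dict, List


def total_loss(
    layer_losses: List[Dict[str, float]],
    heatmap_loss: float,
    w_cls: float = 1.0,
    w_bbox: float = 0.25,
    w_heatmap: float = 1.0,
    auxiliary: bool = True,
) -> float:
    """Combine per-decoder-layer losses with the dense heatmap loss (conceptual sketch)."""
    # With auxiliary supervision every layer is supervised; otherwise only the last one.
    supervised = layer_losses if auxiliary else layer_losses[-1:]
    loss = sum(w_cls * layer["cls"] + w_bbox * layer["bbox"] for layer in supervised)
    # The heatmap term is kept even when queries are not initialized from the heatmap.
    return loss + w_heatmap * heatmap_loss
```

With a single decoder layer the auxiliary and non-auxiliary sums coincide, so auxiliary supervision only has an effect once there are several layers refining the query positions.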

CeZh commented 2 years ago

Hello Xuyang, thank you so much for your quick reply. May I keep this discussion open for a while, so that I can follow up with questions if I run into issues training with random query initialization? Thank you!

XuyangBai commented 2 years ago

Sure, discussion is always welcome :)

CeZh commented 2 years ago

Hello Xuyang, sorry to bother you again. When training transfusion_nusc_pillar_L with randomly initialized object queries, I found that the variable self.query_labels used in get_bboxes (transfusion_head line 1300) is undefined. I understand that self.query_labels is designed for the heatmap initialization method. However, with random initialization the query features and query positions carry no class dimension (Batch × 128 × 200 and Batch × 2 × 200), so I think this variable should no longer be used. Is that right? If I remove it, should I use the batch_score obtained from the predicted heatmap, or something else? To describe it better, I have attached part of your code here. Thank you so much!

for layer_id, preds_dict in enumerate(preds_dicts):
    batch_size = preds_dict[0]['heatmap'].shape[0]
    batch_score = preds_dict[0]['heatmap'][..., -self.num_proposals:].sigmoid()
    # if self.loss_iou.loss_weight != 0:
    #    batch_score = torch.sqrt(batch_score * preds_dict[0]['iou'][..., -self.num_proposals:].sigmoid())
    one_hot = F.one_hot(self.query_labels, num_classes=self.num_classes).permute(0, 2, 1)
    batch_score = batch_score * preds_dict[0]['query_heatmap_score'] * one_hot

In the line one_hot = F.one_hot(self.query_labels, num_classes=self.num_classes).permute(0, 2, 1), self.query_labels is not defined under random initialization. Is it correct to comment out this line and also the next one (the batch_score update), so that batch_score is just the sigmoid of the preds_dict heatmap output? Thank you for your time!

XuyangBai commented 2 years ago

Yes, if you strictly use the current implementation with random query initialization, the final batch_score should just be preds_dict[0]['heatmap'][..., -self.num_proposals:].sigmoid(), so you need to comment out lines 1300 and 1301.
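The two scoring paths can be sketched in NumPy as follows. This assumes `heatmap_logits` of shape `[B, num_classes, N]` for the query class predictions and, in heatmap-initialization mode only, `query_labels` of shape `[B, num_proposals]` and `query_heatmap_score` of shape `[B, num_classes, num_proposals]`. `one_hot_bcp` and `query_scores` are hypothetical helpers that mirror `F.one_hot(...).permute(0, 2, 1)` and the two lines in question; they are not the repository's code.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def one_hot_bcp(labels, num_classes):
    """[B, P] integer labels -> [B, C, P] one-hot (like F.one_hot(...).permute(0, 2, 1))."""
    return np.transpose(np.eye(num_classes)[labels], (0, 2, 1))


def query_scores(heatmap_logits, num_proposals, query_labels=None, query_heatmap_score=None):
    """Per-class scores for the last `num_proposals` queries."""
    score = sigmoid(heatmap_logits[..., -num_proposals:])
    if query_labels is None:
        # Random / learnable query init: no per-query class from a heatmap peak,
        # so the decoder's classification score is used as-is.
        return score
    # Heatmap init: keep only the class each query was seeded from and
    # weight it by the heatmap score of the peak that produced it.
    num_classes = heatmap_logits.shape[1]
    return score * query_heatmap_score * one_hot_bcp(query_labels, num_classes)
```

Passing `query_labels=None` corresponds to commenting out the two lines, as suggested above.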

CeZh commented 2 years ago

Hello Xuyang, I have trained a TransFusion model with random query initialization, and the mAP stays at 0 throughout training. The model uses PointPillars as the backbone. Following your training procedure, I started with the LiDAR-only stage, but since its mAP was 0, I did not move on to the next stage. The mAP and training loss curves are shown below:

[Figure: mAP results]

[Figure: Training loss]

Here are the details of the model and training setup. I believe the main changes I made were setting the learnable query position to True and the heatmap initialization to False. Could you please take a look and tell me what I did wrong? Thank you so much!


point_cloud_range = [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0]
class_names = [
    'car', 'truck', 'construction_vehicle', 'bus', 'trailer', 'barrier',
    'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone'
]
voxel_size = [0.2, 0.2, 8]
out_size_factor = 4
evaluation = dict(interval=1)
dataset_type = 'NuScenesDataset'
data_root = 'some data root'
input_modality = dict(
    use_lidar=True,
    use_camera=False,
    use_radar=False,
    use_map=False,
    use_external=False)
train_pipeline = [
    dict(
        type='LoadPointsFromFile',
        coord_type='LIDAR',
        load_dim=5,
        use_dim=[0, 1, 2, 3, 4],
    ),
    dict(
        type='LoadPointsFromMultiSweeps',
        sweeps_num=10,
        use_dim=[0, 1, 2, 3, 4],
    ),
    dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
    dict(
        type='ObjectSample',
        db_sampler=dict(
            data_root=None,
            info_path=data_root + 'nuscenes_dbinfos_train.pkl',
            rate=1.0,
            prepare=dict(
                filter_by_difficulty=[-1],
                filter_by_min_points=dict(
                    car=5,
                    truck=5,
                    bus=5,
                    trailer=5,
                    construction_vehicle=5,
                    traffic_cone=5,
                    barrier=5,
                    motorcycle=5,
                    bicycle=5,
                    pedestrian=5)),
            classes=class_names,
            sample_groups=dict(
                car=2,
                truck=3,
                construction_vehicle=7,
                bus=4,
                trailer=6,
                barrier=2,
                motorcycle=6,
                bicycle=6,
                pedestrian=2,
                traffic_cone=2),
            points_loader=dict(
                type='LoadPointsFromFile',
                coord_type='LIDAR',
                load_dim=5,
                use_dim=[0, 1, 2, 3, 4],
            ))),
    dict(
        type='GlobalRotScaleTrans',
        rot_range=[-0.3925 * 2, 0.3925 * 2],
        scale_ratio_range=[0.9, 1.1],
        translation_std=[0.5, 0.5, 0.5]),
    dict(
        type='RandomFlip3D',
        sync_2d=False,
        flip_ratio_bev_horizontal=0.5,
        flip_ratio_bev_vertical=0.5),
    dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range),
    dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
    dict(type='ObjectNameFilter', classes=class_names),
    dict(type='PointShuffle'),
    dict(type='DefaultFormatBundle3D', class_names=class_names),
    dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
]
test_pipeline = [
    dict(
        type='LoadPointsFromFile',
        coord_type='LIDAR',
        load_dim=5,
        use_dim=[0, 1, 2, 3, 4],
    ),
    dict(
        type='LoadPointsFromMultiSweeps',
        sweeps_num=10,
        use_dim=[0, 1, 2, 3, 4],
    ),
    dict(
        type='MultiScaleFlipAug3D',
        img_scale=(1333, 800),
        pts_scale_ratio=1,
        flip=False,
        transforms=[
            dict(
                type='GlobalRotScaleTrans',
                rot_range=[0, 0],
                scale_ratio_range=[1.0, 1.0],
                translation_std=[0, 0, 0]),
            dict(type='RandomFlip3D'),
            dict(
                type='DefaultFormatBundle3D',
                class_names=class_names,
                with_label=False),
            dict(type='Collect3D', keys=['points'])
        ])
]
data = dict(
    samples_per_gpu=2,
    workers_per_gpu=0,
    train=dict(
        type='CBGSDataset',
        dataset=dict(
            type=dataset_type,
            data_root=data_root,
            ann_file=data_root + '/nuscenes_infos_train.pkl',
            load_interval=1,
            pipeline=train_pipeline,
            classes=class_names,
            modality=input_modality,
            test_mode=False,
            box_type_3d='LiDAR')),
    val=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file=data_root + '/nuscenes_infos_val.pkl',
        load_interval=1,
        pipeline=test_pipeline,
        classes=class_names,
        modality=input_modality,
        test_mode=True,
        box_type_3d='LiDAR'),
    test=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file=data_root + '/nuscenes_infos_val.pkl',
        load_interval=1,
        pipeline=test_pipeline,
        classes=class_names,
        modality=input_modality,
        test_mode=True,
        box_type_3d='LiDAR'))
model = dict(
    type='TransFusionDetector',
    pts_voxel_layer=dict(
        max_num_points=20,
        voxel_size=voxel_size,
        max_voxels=(30000, 60000),
        point_cloud_range=point_cloud_range),
    pts_voxel_encoder=dict(
        type='PillarFeatureNet',
        in_channels=5,
        feat_channels=[64],
        with_distance=False,
        voxel_size=voxel_size,
        norm_cfg=dict(type='BN1d', eps=0.001, momentum=0.01),
        point_cloud_range=point_cloud_range,
    ),
    pts_middle_encoder=dict(
        type='PointPillarsScatter', in_channels=64, output_shape=(512, 512)
    ),
    pts_backbone=dict(
        type='SECOND',
        in_channels=64,
        out_channels=[64, 128, 256],
        layer_nums=[3, 5, 5],
        layer_strides=[2, 2, 2],
        norm_cfg=dict(type='BN', eps=0.001, momentum=0.01),
        conv_cfg=dict(type='Conv2d', bias=False)),
    pts_neck=dict(
        type='SECONDFPN',
        in_channels=[64, 128, 256],
        out_channels=[128, 128, 128],
        upsample_strides=[0.5, 1, 2],
        norm_cfg=dict(type='BN', eps=0.001, momentum=0.01),
        upsample_cfg=dict(type='deconv', bias=False),
        use_conv_for_no_stride=True),
    pts_bbox_head=dict(
        type='TransFusionHead',
        num_proposals=200,
        auxiliary=True,
        in_channels=128 * 3,
        hidden_channel=128,
        num_classes=len(class_names),
        num_decoder_layers=1,
        num_heads=8,
        learnable_query_pos=True,
        initialize_by_heatmap=False,
        nms_kernel_size=3,
        ffn_channel=256,
        dropout=0.1,
        bn_momentum=0.1,
        activation='relu',
        common_heads=dict(center=(2, 2), height=(1, 2), dim=(3, 2), rot=(2, 2), vel=(2, 2)),
        bbox_coder=dict(
            type='TransFusionBBoxCoder',
            pc_range=point_cloud_range[:2],
            voxel_size=voxel_size[:2],
            out_size_factor=out_size_factor,
            post_center_range=[-61.2, -61.2, -10.0, 61.2, 61.2, 10.0],
            score_threshold=0.0,
            code_size=10,
        ),
        loss_cls=dict(type='FocalLoss', use_sigmoid=True, gamma=2, alpha=0.25, reduction='mean', loss_weight=1.0),
        # loss_iou=dict(type='CrossEntropyLoss', use_sigmoid=True, reduction='mean', loss_weight=0.0),
        loss_bbox=dict(type='L1Loss', reduction='mean', loss_weight=0.25),
        loss_heatmap=dict(type='GaussianFocalLoss', reduction='mean', loss_weight=1.0),
    ),
    train_cfg=dict(
        pts=dict(
            dataset='nuScenes',
            assigner=dict(
                type='HungarianAssigner3D',
                iou_calculator=dict(type='BboxOverlaps3D', coordinate='lidar'),
                cls_cost=dict(type='FocalLossCost', gamma=2, alpha=0.25, weight=0.15),
                reg_cost=dict(type='BBoxBEVL1Cost', weight=0.25),
                iou_cost=dict(type='IoU3DCost', weight=0.25)
            ),
            pos_weight=-1,
            gaussian_overlap=0.1,
            min_radius=2,
            grid_size=[512, 512, 1],  # [x_len, y_len, 1]
            voxel_size=voxel_size,
            out_size_factor=out_size_factor,
            code_weights=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2],
            point_cloud_range=point_cloud_range)),
    test_cfg=dict(
        pts=dict(
            dataset='nuScenes',
            grid_size=[512, 512, 1],
            out_size_factor=out_size_factor,
            pc_range=point_cloud_range[0:2],
            voxel_size=voxel_size[:2],
            nms_type=None,
        )))
optimizer = dict(type='AdamW', lr=0.0001, weight_decay=0.01)  # for 8gpu * 2sample_per_gpu
optimizer_config = dict(grad_clip=dict(max_norm=0.1, norm_type=2))
lr_config = dict(
    policy='cyclic',
    target_ratio=(10, 0.0001),
    cyclic_times=1,
    step_ratio_up=0.4)
momentum_config = dict(
    policy='cyclic',
    target_ratio=(0.8947368421052632, 1),
    cyclic_times=1,
    step_ratio_up=0.4)
total_epochs = 15
checkpoint_config = dict(interval=1)
log_config = dict(
    interval=50,
    hooks=[dict(type='TextLoggerHook'),
           dict(type='TensorboardLoggerHook')])
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = None
load_from = None
resume_from = None
workflow = [('train', 1)]
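
As a sanity check on the geometry in this config: `point_cloud_range` spans 102.4 m in x and y, so with 0.2 m pillars the grid is 512 × 512 (matching `grid_size` and the scatter `output_shape`), and with `out_size_factor=4` the head works on a 128 × 128 BEV map. The sketch below reproduces that arithmetic and maps a query position given in feature-map cells to LiDAR meters via `cell * out_size_factor * voxel_size + pc_range_min`. That mapping is assumed to mirror how `TransFusionBBoxCoder` decodes centers; the helper names are hypothetical, so check the coder before relying on it.

```python
point_cloud_range = [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0]
voxel_size = [0.2, 0.2, 8]
out_size_factor = 4


def grid_shape(pc_range, voxel):
    """Number of voxels/pillars along x, y, z."""
    return [round((pc_range[i + 3] - pc_range[i]) / voxel[i]) for i in range(3)]


def cell_to_metric(x_cell, y_cell, pc_range=point_cloud_range, voxel=voxel_size, stride=out_size_factor):
    """Map a BEV feature-map position (in cells) to LiDAR x, y in meters (assumed decoding)."""
    x = x_cell * stride * voxel[0] + pc_range[0]
    y = y_cell * stride * voxel[1] + pc_range[1]
    return x, y
```

Under this mapping, a learnable query position is only meaningful if it lives in feature-map cells (0 to 128 here), not in meters or normalized [0, 1] units.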

XuyangBai commented 2 years ago

@CeZh Could you provide all the curves or training logs? I need the heatmap loss, cls loss and bbox loss to figure out what is going on. Also, did you use heatmap loss supervision? You need to change the code a little, because by default the heatmap loss is not used when initialize_by_heatmap is set to False (I didn't verify the logic of the ablation experiments when releasing the code, sorry about that). As I remember, I still used heatmap supervision even when the heatmap was not used to initialize the queries.
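For reference, the heatmap supervision in this config is `GaussianFocalLoss`, computed against a target heatmap with Gaussian peaks at ground-truth centers. Below is a minimal NumPy sketch of the penalty-reduced focal loss in the form used by mmdet's `gaussian_focal_loss` (defaults `alpha=2.0`, `gamma=4.0`), normalized by the number of peak cells. The `draw_gaussian` helper with a fixed `sigma` is a simplification: the real target radius is derived from box size, `gaussian_overlap` and `min_radius`.

```python
import numpy as np


def draw_gaussian(shape, center, sigma):
    """Target heatmap with a Gaussian peak of value 1 at `center` = (x, y)."""
    h, w = shape
    ys, xs = np.mgrid[0:h, 0:w]
    return np.exp(-((xs - center[0]) ** 2 + (ys - center[1]) ** 2) / (2 * sigma ** 2))


def gaussian_focal_loss(pred, target, alpha=2.0, gamma=4.0, eps=1e-12):
    """Penalty-reduced focal loss, summed and divided by the number of peak cells."""
    pos = target == 1
    # Negatives near a peak are down-weighted by (1 - target) ** gamma.
    neg_weights = (1 - target) ** gamma
    pos_loss = -np.log(pred + eps) * (1 - pred) ** alpha * pos
    neg_loss = -np.log(1 - pred + eps) * pred ** alpha * neg_weights
    num_pos = max(pos.sum(), 1)
    return float((pos_loss + neg_loss).sum() / num_pos)
```

Logging this term separately from the cls and bbox losses is what makes it possible to tell whether the dense heatmap branch is learning at all.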

A quick (but not elegant) solution is to keep initialize_by_heatmap set to True and learnable_query_pos set to False, but create a learnable query embedding and query position here: https://github.com/XuyangBai/TransFusion/blob/399bda09a3b6449313ccc302df40651f77ec78bf/mmdet3d/models/dense_heads/transfusion_head.py#L692-L693

and initialize the queries using self.query_feat and self.query_pos instead of the predicted heatmap (comment out the following part and replace it with L875-L876; it seems you also need to rename base_xyz to query_pos). https://github.com/XuyangBai/TransFusion/blob/399bda09a3b6449313ccc302df40651f77ec78bf/mmdet3d/models/dense_heads/transfusion_head.py#L861-L871
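The quick fix above effectively decouples two decisions that the released head ties to `initialize_by_heatmap`: whether the dense heatmap is predicted and supervised, and where the initial queries come from. A framework-free sketch of that control flow follows; `HeadFlags`, `select_queries` and `heatmap_init_fn` are hypothetical names, and plain NumPy arrays stand in for the learnable parameters, assumed shaped `[1, C, P]` for features and `[1, P, 2]` for positions.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class HeadFlags:
    # Hypothetical flags: the released head uses `initialize_by_heatmap` for both roles.
    supervise_heatmap: bool = True
    init_from_heatmap: bool = False


def select_queries(flags, batch_size, query_feat, query_pos, heatmap_init_fn=None):
    """Return (features [B, C, P], positions [B, P, 2], whether to add the heatmap loss)."""
    if flags.init_from_heatmap:
        # Input-dependent queries, e.g. top-K heatmap peaks.
        feat, pos = heatmap_init_fn()
    else:
        # Input-independent queries: the same learnable embedding for every sample,
        # analogous to repeating the parameters along the batch dimension.
        feat = np.repeat(query_feat, batch_size, axis=0)
        pos = np.repeat(query_pos, batch_size, axis=0)
    return feat, pos, flags.supervise_heatmap
```

With the default flags, queries come from the learnable parameters while the heatmap loss is still applied, which is the combination described in this comment.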

You can try with adding the heatmap supervision first :)

CeZh commented 2 years ago

Hello Xuyang, thank you so much for your reply and explanations. I have been tied up with other work these days and haven't had a chance to retrain the model as you suggested. I will do it next week and let you know. Thanks again for your comments and suggestions.

HappyLuffe commented 1 year ago

Hello @CeZh, thank you for raising this issue; I learned a lot from it. I would like to know whether training with random initialization and learnable query positions eventually worked for you, because I have the same problem. Thanks again for all your work.

HappyLuffe commented 1 year ago

Hello Xuyang, I tried using heatmap loss supervision with random initialization with learnable query positions, but it doesn't look like it's working. Should I use other loss functions instead of heatmap loss?