XuyangBai / TransFusion

[PyTorch] Official implementation of CVPR2022 paper "TransFusion: Robust LiDAR-Camera Fusion for 3D Object Detection with Transformers". https://arxiv.org/abs/2203.11496
Apache License 2.0

Loss abnormal with waymo TransFusion_L config #20

Open idontlikelongname opened 2 years ago

idontlikelongname commented 2 years ago

Thanks for your great work.

I'm training TransFusion_L (LiDAR-only) on the Waymo dataset, but the loss is abnormal, in particular the box size loss, and the box sizes in the test results are also wrong. Why is loss_bbox weight 0.25 for nuScenes but 2.0 for Waymo? It makes the Waymo box loss very large, staying above 10.0 at the end of training.

Below are my code, config, and training logs. I simplified TransFusionHead to make it easier for me to understand.

```python
voxel_size = [0.32, 0.32, 6]
model = dict(
    type='Transfusion',
    pts_voxel_layer=dict(
        max_num_points=20,
        point_cloud_range=[-76.8, -76.8, -2, 76.8, 76.8, 4],
        voxel_size=voxel_size,
        max_voxels=(150000, 150000),
    ),
    pts_voxel_encoder=dict(
        type='HardVFE',
        in_channels=5,
        feat_channels=[64, 64],
        voxel_size=voxel_size,
        with_cluster_center=True,
        with_voxel_center=True,
        with_distance=False,
        point_cloud_range=[-76.8, -76.8, -2, 76.8, 76.8, 4],
        norm_cfg=dict(type='naiveSyncBN1d', eps=1e-3, momentum=0.01)),
    pts_middle_encoder=dict(
        type='PointPillarsScatter', in_channels=64, output_shape=(480, 480)),
    pts_backbone=dict(
        type='SECOND',
        in_channels=64,
        norm_cfg=dict(type='naiveSyncBN2d', eps=1e-3, momentum=0.01),
        layer_nums=[3, 5, 5],
        layer_strides=[1, 2, 2],
        out_channels=[64, 128, 256]),
    pts_neck=dict(
        type='SECONDFPN',
        norm_cfg=dict(type='naiveSyncBN2d', eps=1e-3, momentum=0.01),
        in_channels=[64, 128, 256],
        upsample_strides=[1, 2, 4],
        out_channels=[128, 128, 128]),
    pts_bbox_head=dict(
        type='TransFusionHead',
        num_proposals=100,
        auxiliary=True,
        in_channels=sum([128, 128, 128]),
        hidden_channel=128,
        num_classes=3,
        num_decoder_layers=1,
        num_heads=8,
        nms_kernel_size=[3, 1, 1],
        ffn_channel=256,
        dropout=0.1,
        bn_momentum=0.1,
        activation='relu',
        common_heads=dict(
            center=(2, 2), height=(1, 2), dim=(3, 2), rot=(2, 2)),
        bbox_coder=dict(
            type='TransFusionBBoxCoder',
            pc_range=[-76.8, -76.8],
            voxel_size=voxel_size[:2],
            out_size_factor=1,
            post_center_range=[-80, -80, -10.0, 80, 80, 10.0],
            score_threshold=0.0,
            code_size=8,
        ),
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2,
            alpha=0.25,
            reduction='mean',
            loss_weight=1.0),
        # loss_iou=dict(type='CrossEntropyLoss', use_sigmoid=True, reduction='mean', loss_weight=0.0),
        loss_bbox=dict(type='L1Loss', reduction='mean', loss_weight=2.0),
        loss_heatmap=dict(
            type='GaussianFocalLoss', reduction='mean', loss_weight=1.0),
    ),
    # model training and testing settings
    train_cfg=dict(
        pts=dict(
            grid_size=[480, 480, 1],
            out_size_factor=1,
            voxel_size=voxel_size,
            point_cloud_range=[-76.8, -76.8, -2, 76.8, 76.8, 4],
            dense_reg=1,
            gaussian_overlap=0.1,
            max_objs=500,
            min_radius=[2, 2, 2],
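            # 8 code weights match code_size=8: center(2) + height(1) + dim(3) + rot sin/cos(2); no velocity targets on Waymo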
            code_weights=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
            pos_weight=-1,
            assigner=dict(
                type='HungarianAssigner3D',
                iou_calculator=dict(type='BboxOverlaps3D', coordinate='lidar'),
                cls_cost=dict(
                    type='FocalLossCost', gamma=2, alpha=0.25, weight=0.6),
                reg_cost=dict(type='BBoxBEVL1Cost', weight=2.0),
                iou_cost=dict(type='IoU3DCost', weight=2.0)),
        )),
    test_cfg=dict(
        pts=dict(
            grid_size=[480, 480, 1],
            out_size_factor=1,
            nms_type='circle',
            min_radius=[[1, 1, 1]],
            post_max_size=200,
            nms_iou_threshold=0.75,
        )))
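
# Sanity check on the BEV grid: (76.8 - (-76.8)) / 0.32 = 480, so grid_size
# [480, 480, 1] is consistent with point_cloud_range and voxel_size above
# (out_size_factor=1).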

@HEADS.register_module()
class TransFusionHead(nn.Module):

    def __init__(
        self,
        num_classes=3,
        in_channels=128 * 3,
        hidden_channel=128,
        auxiliary=True,
        # config for Transformer
        num_proposals=300,
        num_decoder_layers=1,
        num_heads=8,
        learnable_query_pos=False,
        initialize_by_heatmap=True,
        nms_kernel_size=[3, 1, 1],
        ffn_channel=256,
        dropout=0.1,
        bn_momentum=0.1,
        activation='relu',
        # config for FFN
        common_heads=dict(
            center=(2, 2), height=(1, 2), dim=(3, 2), rot=(2, 2)),
        num_heatmap_convs=2,
        conv_cfg=dict(type='Conv1d'),
        norm_cfg=dict(type='BN1d'),
        bias='auto',
        # loss
        loss_cls=dict(type='GaussianFocalLoss', reduction='mean'),
        # loss_iou=dict(
        #     type='VarifocalLoss',
        #     use_sigmoid=True,
        #     iou_weighted=True,
        #     reduction='mean'),
        loss_bbox=dict(type='L1Loss', reduction='mean'),
        loss_heatmap=dict(type='GaussianFocalLoss', reduction='mean'),
        # others
        train_cfg=None,
        test_cfg=None,
        bbox_coder=None,
    ):
        super(TransFusionHead, self).__init__()

        self.num_classes = num_classes
        self.in_channels = in_channels
        # transformer
        self.learnable_query_pos = learnable_query_pos
        # if return all results after transformer decoder
        self.auxiliary = auxiliary
        self.num_proposals = num_proposals
        self.num_heads = num_heads
        self.num_decoder_layers = num_decoder_layers
        self.bn_momentum = bn_momentum
        self.nms_kernel_size = nms_kernel_size

        self.initialize_by_heatmap = initialize_by_heatmap
        if self.initialize_by_heatmap:
            assert self.learnable_query_pos is False, "initializing by heatmap conflicts with learnable query positions"
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
        if not self.use_sigmoid_cls:
            self.num_classes += 1

        self.loss_cls = build_loss(loss_cls)
        self.loss_bbox = build_loss(loss_bbox)
        # self.loss_iou = build_loss(loss_iou)
        self.loss_heatmap = build_loss(loss_heatmap)

        self.bbox_coder = build_bbox_coder(bbox_coder)
        self.sampling = False

        # a shared convolution
        self.shared_conv = build_conv_layer(
            dict(type='Conv2d'),
            in_channels,
            hidden_channel,
            kernel_size=3,
            padding=1,
            bias=bias,
        )

        if self.initialize_by_heatmap:
            layers = []
            layers.append(
                ConvModule(
                    hidden_channel,
                    hidden_channel,
                    kernel_size=3,
                    padding=1,
                    bias=bias,
                    conv_cfg=dict(type='Conv2d'),
                    norm_cfg=dict(type='BN2d'),
                ))
            layers.append(
                build_conv_layer(
                    dict(type='Conv2d'),
                    hidden_channel,
                    num_classes,
                    kernel_size=3,
                    padding=1,
                    bias=bias,
                ))
            self.heatmap_head = nn.Sequential(*layers)
            self.class_encoding = nn.Conv1d(num_classes, hidden_channel, 1)
        else:
            # query feature
            self.query_feat = nn.Parameter(
                torch.randn(1, hidden_channel, self.num_proposals))
            self.query_pos = nn.Parameter(
                torch.rand([1, self.num_proposals, 2]),
                requires_grad=learnable_query_pos)

        # transformer decoder layers for object query with LiDAR feature
        self.decoder = nn.ModuleList()
        for i in range(self.num_decoder_layers):
            self.decoder.append(
                TransformerDecoderLayer(
                    hidden_channel,
                    num_heads,
                    ffn_channel,
                    dropout,
                    activation,
                    self_posembed=PositionEmbeddingLearned(2, hidden_channel),
                    cross_posembed=PositionEmbeddingLearned(2, hidden_channel),
                ))

        # Prediction Head
        self.prediction_heads = nn.ModuleList()
        for i in range(self.num_decoder_layers):
            heads = copy.deepcopy(common_heads)
            heads.update(dict(heatmap=(self.num_classes, num_heatmap_convs)))
            self.prediction_heads.append(
                FFN(hidden_channel,
                    heads,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    bias=bias))

        self.init_weights()
        self._init_assigner_sampler()

        # Position Embedding for Cross-Attention, which is re-used during training
        x_size = self.test_cfg['grid_size'][0] // self.test_cfg[
            'out_size_factor']
        y_size = self.test_cfg['grid_size'][1] // self.test_cfg[
            'out_size_factor']
        # [1, H*W, 2]
        self.bev_pos = self.create_2D_grid(x_size, y_size)

        self.img_feat_pos = None
        self.img_feat_collapsed_pos = None

    def create_2D_grid(self, x_size, y_size):
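        # Build cell-center coordinates of the BEV grid (+0.5 offset), returned as [1, x_size*y_size, 2].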
        meshgrid = [[0, x_size - 1, x_size], [0, y_size - 1, y_size]]
        batch_y, batch_x = torch.meshgrid(
            *[torch.linspace(it[0], it[1], it[2]) for it in meshgrid])
        batch_x = batch_x + 0.5
        batch_y = batch_y + 0.5
        coord_base = torch.cat([batch_x[None], batch_y[None]], dim=0)[None]
        coord_base = coord_base.view(1, 2, -1).permute(0, 2, 1)
        return coord_base

    def init_weights(self):
        # initialize transformer
        for m in self.decoder.parameters():
            if m.dim() > 1:
                nn.init.xavier_uniform_(m)
        if hasattr(self, 'query'):
            nn.init.xavier_normal_(self.query)
        self.init_bn_momentum()

    def init_bn_momentum(self):
        for m in self.modules():
            if isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                m.momentum = self.bn_momentum

    def _init_assigner_sampler(self):
        """Initialize the target assigner and sampler of the head."""
        if self.train_cfg is None:
            return

        if self.sampling:
            self.bbox_sampler = build_sampler(self.train_cfg.sampler)
        else:
            self.bbox_sampler = PseudoSampler()
        if isinstance(self.train_cfg.assigner, dict):
            self.bbox_assigner = build_assigner(self.train_cfg.assigner)
        elif isinstance(self.train_cfg.assigner, list):
            self.bbox_assigner = [
                build_assigner(res) for res in self.train_cfg.assigner
            ]

    def forward_single(self, inputs):
        """Forward function for CenterPoint.

        Args:
            inputs (torch.Tensor): input BEV features for multi-level from FPN, in shape of
                [B, C, H, W].

        Returns:
            list[dict]: Output results for tasks.
        """
        batch_size = inputs.shape[0]

        lidar_feat = self.shared_conv(inputs)

        #################################
        # flatten LiDAR BEV features
        #################################
        # [BS, hidden_C, H*W]
        lidar_feat_flatten = lidar_feat.view(batch_size, lidar_feat.shape[1],
                                             -1)
        # [1, H*W, 2] -> [B, H*W, 2]
        bev_pos = self.bev_pos.repeat(batch_size, 1, 1).to(lidar_feat.device)

        #################################
        # heatmap-guided query initialization
        #################################
        if self.initialize_by_heatmap:
            dense_heatmap = self.heatmap_head(lidar_feat)
            heatmap = dense_heatmap.detach().sigmoid()
            _, cls_nums, _, _ = heatmap.shape
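            # per-class local-maximum suppression: with nms_kernel_size=[3, 1, 1],
            # Car uses a 3x3 max-pool while kernel size 1 keeps every peak for
            # Pedestrian and Cyclist (replacing the dataset-specific branch below)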
            local_max = []
            for cls_id in range(cls_nums):
                cls_heatmap_neighbor_max = F.max_pool2d(
                    heatmap[:, cls_id:(cls_id + 1), :, :],
                    self.nms_kernel_size[cls_id],
                    stride=1,
                    padding=self.nms_kernel_size[cls_id] // 2)
                local_max.append(cls_heatmap_neighbor_max)
            local_max = torch.cat(local_max, dim=1)

            # padding = self.nms_kernel_size // 2
            # local_max = torch.zeros_like(heatmap)
            # # equals to nms radius = voxel_size * out_size_factor * kenel_size
            # local_max_inner = F.max_pool2d(
            #     heatmap, kernel_size=self.nms_kernel_size, stride=1, padding=0)
            # local_max[:, :, padding:(-padding),
            #           padding:(-padding)] = local_max_inner
            # if self.test_cfg[
            #         'dataset'] == 'Waymo':  # for Pedestrian & Cyclist in Waymo
            #     local_max[:, 1, ] = F.max_pool2d(
            #         heatmap[:, 1], kernel_size=1, stride=1, padding=0)
            #     local_max[:, 2, ] = F.max_pool2d(
            #         heatmap[:, 2], kernel_size=1, stride=1, padding=0)
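            # keep only local peaks: positions that are not the local maximum are zeroed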
            heatmap = heatmap * (heatmap == local_max)

            # top_K proposals from heatmap for all classes
            heatmap = heatmap.view(batch_size, heatmap.shape[1], -1)
            top_proposals = heatmap.view(batch_size, -1).argsort(
                dim=-1, descending=True)[..., :self.num_proposals]
            top_proposals_class = top_proposals // heatmap.shape[-1]
            top_proposals_index = top_proposals % heatmap.shape[-1]
            # top_K with lidar BEV features, [B, F_C, H*W], ex. [B, 128, 300]
            query_feat = lidar_feat_flatten.gather(
                index=top_proposals_index[:, None, :].expand(
                    -1, lidar_feat_flatten.shape[1], -1),
                dim=-1)
            # [B, N], ex. [B, 300]
            self.query_labels = top_proposals_class

            # add category embedding
            # [B, N] -> [B, N, Cls_C] -> [B, Cls_C, N]
            one_hot = F.one_hot(
                top_proposals_class,
                num_classes=self.num_classes).permute(0, 2, 1)
            # [B, Cls_C, N] -> [B, F_C, N]
            query_cat_encoding = self.class_encoding(one_hot.float())
            query_feat += query_cat_encoding

            # [B, H*W, 2] -> [B, N, 2]
            query_pos = bev_pos.gather(
                index=top_proposals_index[:, None, :].permute(0, 2, 1).expand(
                    -1, -1, bev_pos.shape[-1]),
                dim=1)
        else:
            # [BS, C, num_proposals]
            query_feat = self.query_feat.repeat(batch_size, 1, 1)
            # [BS, num_proposals, 2]
            base_xyz = self.query_pos.repeat(batch_size, 1,
                                             1).to(lidar_feat.device)

        #################################
        # transformer decoder layer (LiDAR feature as K,V)
        #################################

        # multi transformer decoder output
        ret_dicts = []
        for i in range(self.num_decoder_layers):

            # Transformer Decoder Layer, [B, Q_c, N]
            query_feat = self.decoder[i](query_feat, lidar_feat_flatten,
                                         query_pos, bev_pos)

            # Prediction
            res_layer = self.prediction_heads[i](query_feat)
            # the regressed center is an offset relative to the query position
            res_layer['center'] = res_layer['center'] + query_pos.permute(
                0, 2, 1)

            ret_dicts.append(res_layer)

            # for next level positional embedding
            query_pos = res_layer['center'].detach().clone().permute(0, 2, 1)

        if self.initialize_by_heatmap:
            # [B, Cls_C, N]
            ret_dicts[0]['query_heatmap_score'] = heatmap.gather(
                index=top_proposals_index[:, None, :].expand(
                    -1, self.num_classes, -1),
                dim=-1)
            # [B, Cls_C, H, W]
            ret_dicts[0]['dense_heatmap'] = dense_heatmap

        if self.auxiliary is False:
            # only return the results of last decoder layer
            return [ret_dicts[-1]]

        # return all layers' results for auxiliary supervision
        # 'dense_heatmap', 'dense_heatmap_old', 'query_heatmap_score' come only from the first decoder layer
        # 'heatmap', 'center', 'height', 'dim', 'rot' are returned for all layers, concatenated along dim=-1
        new_res = {}
        for key in ret_dicts[0].keys():
            if key not in [
                    'dense_heatmap', 'dense_heatmap_old', 'query_heatmap_score'
            ]:
                # [B, C, N] -> [B, C, N*L]
                new_res[key] = torch.cat(
                    [ret_dict[key] for ret_dict in ret_dicts], dim=-1)
            else:
                new_res[key] = ret_dicts[0][key]
        return [new_res]

    def forward(self, feats):
        """Forward pass.

        Args:
            feats (list[torch.Tensor]): Multi-level features, e.g.,
                features produced by FPN.

        Returns:
            tuple[list[dict]]: Output results, indexed first by level and then by decoder layer.
        """
        assert len(feats) == 1, "only support one level features."
        res = multi_apply(self.forward_single, feats)
        return res

    def get_targets(self, gt_bboxes_3d, gt_labels_3d, preds_dict):
        """Generate training targets.

        Args:
            gt_bboxes_3d (:obj:`LiDARInstance3DBoxes`): Ground truth boxes.
            gt_labels_3d (torch.Tensor): Labels of boxes.
            preds_dict (tuple of dict): Prediction results, indexed by layer (default: 1 layer).
        Returns:
            tuple[torch.Tensor]: Tuple of target including \
                the following results in order.

                - torch.Tensor: classification target.  [BS, num_proposals]
                - torch.Tensor: classification weights (mask)  [BS, num_proposals]
                - torch.Tensor: regression target. [BS, num_proposals, 8]
                - torch.Tensor: regression weights. [BS, num_proposals, 8]
        """
        # preds_dict[0]['center'] has shape [bs, 3, num_proposals]
        # split the level-based predictions into a per-sample list of dicts
        # (indexed by batch_id) so that multi_apply can process each sample
        list_of_pred_dict = []
        for batch_idx in range(len(gt_bboxes_3d)):
            pred_dict = {}
            for key in preds_dict[0].keys():
                pred_dict[key] = preds_dict[0][key][batch_idx:batch_idx + 1]
            list_of_pred_dict.append(pred_dict)

        assert len(gt_bboxes_3d) == len(list_of_pred_dict)

        res_tuple = multi_apply(self.get_targets_single, gt_bboxes_3d,
                                gt_labels_3d, list_of_pred_dict,
                                np.arange(len(gt_labels_3d)))
        labels = torch.cat(res_tuple[0], dim=0)
        label_weights = torch.cat(res_tuple[1], dim=0)
        bbox_targets = torch.cat(res_tuple[2], dim=0)
        bbox_weights = torch.cat(res_tuple[3], dim=0)
        ious = torch.cat(res_tuple[4], dim=0)
        num_pos = np.sum(res_tuple[5])
        matched_ious = np.mean(res_tuple[6])
        if self.initialize_by_heatmap:
            heatmap = torch.cat(res_tuple[7], dim=0)
            return labels, label_weights, bbox_targets, bbox_weights, ious, num_pos, matched_ious, heatmap
        else:
            return labels, label_weights, bbox_targets, bbox_weights, ious, num_pos, matched_ious

    def get_targets_single(self, gt_bboxes_3d, gt_labels_3d, preds_dict,
                           batch_idx):
        """Generate training targets for a single sample.

        Args:
            gt_bboxes_3d (:obj:`LiDARInstance3DBoxes`): Ground truth boxes.
            gt_labels_3d (torch.Tensor): Labels of boxes.
            preds_dict (dict): dict of prediction result for a single sample
        Returns:
            tuple[torch.Tensor]: Tuple of target including \
                the following results in order.

                - torch.Tensor: classification target.  [1, num_proposals]
                - torch.Tensor: classification weights (mask)  [1, num_proposals]
                - torch.Tensor: regression target. [1, num_proposals, 8]
                - torch.Tensor: regression weights. [1, num_proposals, 8]
                - torch.Tensor: iou target. [1, num_proposals]
                - int: number of positive proposals
        """
        num_proposals = preds_dict['center'].shape[-1]

        # get predicted boxes; be careful not to modify the network outputs
        score = copy.deepcopy(preds_dict['heatmap'].detach())
        center = copy.deepcopy(preds_dict['center'].detach())
        height = copy.deepcopy(preds_dict['height'].detach())
        dim = copy.deepcopy(preds_dict['dim'].detach())
        rot = copy.deepcopy(preds_dict['rot'].detach())

        # decode the prediction to real world metric bbox
        boxes_dict = self.bbox_coder.decode(score, rot, dim, center, height,
                                            None)
        # [L*N, C]
        bboxes_tensor = boxes_dict[0]['bboxes']
        gt_bboxes_tensor = gt_bboxes_3d.tensor.to(score.device)
        # each decoder layer does label assignment separately.
        if self.auxiliary:
            num_layer = self.num_decoder_layers
        else:
            num_layer = 1

        assign_result_list = []
        for idx_layer in range(num_layer):
            bboxes_tensor_layer = bboxes_tensor[self.num_proposals *
                                                idx_layer:self.num_proposals *
                                                (idx_layer + 1), :]
            score_layer = score[..., self.num_proposals *
                                idx_layer:self.num_proposals * (idx_layer + 1)]

            if self.train_cfg.assigner.type == 'HungarianAssigner3D':
                assign_result = self.bbox_assigner.assign(
                    bboxes_tensor_layer, gt_bboxes_tensor, gt_labels_3d,
                    score_layer, self.train_cfg)
            elif self.train_cfg.assigner.type == 'HeuristicAssigner':
                assign_result = self.bbox_assigner.assign(
                    bboxes_tensor_layer, gt_bboxes_tensor, None, gt_labels_3d,
                    self.query_labels[batch_idx])
            else:
                raise NotImplementedError
            assign_result_list.append(assign_result)

        # combine assign result of each layer
        assign_result_ensemble = AssignResult(
            num_gts=sum([res.num_gts for res in assign_result_list]),
            gt_inds=torch.cat([res.gt_inds for res in assign_result_list]),
            max_overlaps=torch.cat(
                [res.max_overlaps for res in assign_result_list]),
            labels=torch.cat([res.labels for res in assign_result_list]),
        )
        sampling_result = self.bbox_sampler.sample(assign_result_ensemble,
                                                   bboxes_tensor,
                                                   gt_bboxes_tensor)
        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds
        assert len(pos_inds) + len(neg_inds) == num_proposals

        # create target for loss computation
        bbox_targets = torch.zeros([num_proposals, self.bbox_coder.code_size
                                    ]).to(center.device)
        bbox_weights = torch.zeros([num_proposals, self.bbox_coder.code_size
                                    ]).to(center.device)
        ious = assign_result_ensemble.max_overlaps
        ious = torch.clamp(ious, min=0.0, max=1.0)
        labels = bboxes_tensor.new_zeros(num_proposals, dtype=torch.long)
        label_weights = bboxes_tensor.new_zeros(
            num_proposals, dtype=torch.long)

        if gt_labels_3d is not None:  # default (background) label is num_classes
            labels += self.num_classes

        # both pos and neg have classification loss, only pos has regression and iou loss
        if len(pos_inds) > 0:
            pos_bbox_targets = self.bbox_coder.encode(
                sampling_result.pos_gt_bboxes)

            bbox_targets[pos_inds, :] = pos_bbox_targets
            bbox_weights[pos_inds, :] = 1.0

            if gt_labels_3d is None:
                labels[pos_inds] = 1
            else:
                labels[pos_inds] = gt_labels_3d[
                    sampling_result.pos_assigned_gt_inds]
            if self.train_cfg.get('pos_weight', -1) <= 0:
                label_weights[pos_inds] = 1.0
            else:
                label_weights[pos_inds] = self.train_cfg.get('pos_weight')

        if len(neg_inds) > 0:
            label_weights[neg_inds] = 1.0

        # compute dense heatmap targets
        if self.initialize_by_heatmap:
            device = labels.device
            grid_size = torch.tensor(self.train_cfg['grid_size'])
            pc_range = torch.tensor(self.train_cfg['point_cloud_range'])
            voxel_size = torch.tensor(self.train_cfg['voxel_size'])
            feature_map_size = grid_size[:2] // self.train_cfg[
                'out_size_factor']  # [x_len, y_len]
            heatmap = gt_bboxes_tensor.new_zeros(self.num_classes,
                                                 feature_map_size[1],
                                                 feature_map_size[0])
            for idx in range(len(gt_bboxes_tensor)):
                cls_id = gt_labels_3d[idx]
                width = gt_bboxes_tensor[idx][3]
                length = gt_bboxes_tensor[idx][4]
                width = width / voxel_size[0] / self.train_cfg[
                    'out_size_factor']
                length = length / voxel_size[1] / self.train_cfg[
                    'out_size_factor']
                if width > 0 and length > 0:
                    radius = gaussian_radius(
                        (length, width),
                        min_overlap=self.train_cfg['gaussian_overlap'])
                    radius = max(self.train_cfg['min_radius'][cls_id],
                                 int(radius))
                    x, y = gt_bboxes_tensor[idx][0], gt_bboxes_tensor[idx][1]

                    coor_x = (
                        x - pc_range[0]
                    ) / voxel_size[0] / self.train_cfg['out_size_factor']
                    coor_y = (
                        y - pc_range[1]
                    ) / voxel_size[1] / self.train_cfg['out_size_factor']

                    center = torch.tensor([coor_x, coor_y],
                                          dtype=torch.float32,
                                          device=device)
                    center_int = center.to(torch.int32)
                    draw_heatmap_gaussian(heatmap[cls_id], center_int, radius)

            mean_iou = ious[pos_inds].sum() / max(len(pos_inds), 1)
            return labels[None], label_weights[None], bbox_targets[
                None], bbox_weights[None], ious[None], int(
                    pos_inds.shape[0]), float(mean_iou), heatmap[None]

        else:
            mean_iou = ious[pos_inds].sum() / max(len(pos_inds), 1)
            return labels[None], label_weights[None], bbox_targets[
                None], bbox_weights[None], ious[None], int(
                    pos_inds.shape[0]), float(mean_iou)

    @force_fp32(apply_to=('preds_dicts'))
    def loss(self, gt_bboxes_3d, gt_labels_3d, preds_dicts, **kwargs):
        """Loss function for CenterHead.

        Args:
            gt_bboxes_3d (list[:obj:`LiDARInstance3DBoxes`]): Ground
                truth gt boxes.
            gt_labels_3d (list[torch.Tensor]): Labels of boxes.
            preds_dicts (list[list[dict]]): Output of forward function.

        Returns:
            dict[str:torch.Tensor]: Loss of heatmap and bbox of each task.
        """
        if self.initialize_by_heatmap:
            labels, label_weights, bbox_targets, bbox_weights, ious, num_pos, matched_ious, heatmap = self.get_targets(
                gt_bboxes_3d, gt_labels_3d, preds_dicts[0])
        else:
            labels, label_weights, bbox_targets, bbox_weights, ious, num_pos, matched_ious = self.get_targets(
                gt_bboxes_3d, gt_labels_3d, preds_dicts[0])
        if hasattr(self, 'on_the_image_mask'):
            label_weights = label_weights * self.on_the_image_mask
            bbox_weights = bbox_weights * self.on_the_image_mask[:, :, None]
            num_pos = bbox_weights.max(-1).values.sum()
        preds_dict = preds_dicts[0][0]
        loss_dict = dict()

        if self.initialize_by_heatmap:
            # compute heatmap loss
            loss_heatmap = self.loss_heatmap(
                clip_sigmoid(preds_dict['dense_heatmap']),
                heatmap,
                avg_factor=max(heatmap.eq(1).float().sum().item(), 1))
            loss_dict['loss_heatmap'] = loss_heatmap

        # compute loss for each layer
        for idx_layer in range(
                self.num_decoder_layers if self.auxiliary else 1):
            prefix = f'layer_{idx_layer}'

            layer_labels = labels[..., idx_layer *
                                  self.num_proposals:(idx_layer + 1) *
                                  self.num_proposals].reshape(-1)
            layer_label_weights = label_weights[..., idx_layer *
                                                self.num_proposals:(idx_layer +
                                                                    1) *
                                                self.num_proposals].reshape(-1)
            layer_score = preds_dict['heatmap'][..., idx_layer *
                                                self.num_proposals:(idx_layer +
                                                                    1) *
                                                self.num_proposals]
            layer_cls_score = layer_score.permute(0, 2, 1).reshape(
                -1, self.num_classes)
            # TODO: num_pos is not correct
            layer_loss_cls = self.loss_cls(
                layer_cls_score,
                layer_labels,
                layer_label_weights,
                avg_factor=max(num_pos, 1))

            layer_center = preds_dict['center'][..., idx_layer *
                                                self.num_proposals:(idx_layer +
                                                                    1) *
                                                self.num_proposals]
            layer_height = preds_dict['height'][..., idx_layer *
                                                self.num_proposals:(idx_layer +
                                                                    1) *
                                                self.num_proposals]
            layer_rot = preds_dict['rot'][..., idx_layer *
                                          self.num_proposals:(idx_layer + 1) *
                                          self.num_proposals]
            layer_dim = preds_dict['dim'][..., idx_layer *
                                          self.num_proposals:(idx_layer + 1) *
                                          self.num_proposals]
            # [BS, num_proposals, code_size]
            preds = torch.cat(
                [layer_center, layer_height, layer_dim, layer_rot],
                dim=1).permute(0, 2, 1)

            code_weights = self.train_cfg.get('code_weights', None)
            layer_bbox_weights = bbox_weights[:, idx_layer *
                                              self.num_proposals:(idx_layer +
                                                                  1) *
                                              self.num_proposals, :]
            layer_reg_weights = layer_bbox_weights * layer_bbox_weights.new_tensor(
                code_weights)
            layer_bbox_targets = bbox_targets[:, idx_layer *
                                              self.num_proposals:(idx_layer +
                                                                  1) *
                                              self.num_proposals, :]
            layer_loss_bbox = self.loss_bbox(
                preds,
                layer_bbox_targets,
                layer_reg_weights,
                avg_factor=max(num_pos, 1))

            # layer_iou = preds_dict['iou'][..., idx_layer*self.num_proposals:(idx_layer+1)*self.num_proposals].squeeze(1)
            # layer_iou_target = ious[..., idx_layer*self.num_proposals:(idx_layer+1)*self.num_proposals]
            # layer_loss_iou = self.loss_iou(layer_iou, layer_iou_target, layer_bbox_weights.max(-1).values, avg_factor=max(num_pos, 1))

            loss_dict[f'loss_{prefix}_cls'] = layer_loss_cls
            loss_dict[f'loss_{prefix}_bbox'] = layer_loss_bbox
            # loss_dict[f'{prefix}_loss_iou'] = layer_loss_iou

        loss_dict['matched_ious'] = layer_loss_cls.new_tensor(matched_ious)

        return loss_dict

    def get_bboxes(self,
                   preds_dicts,
                   img_metas,
                   img=None,
                   rescale=False,
                   for_roi=False):
        """Generate bboxes from bbox head predictions.

        Args:
            preds_dicts (tuple[list[dict]]): Prediction results.

        Returns:
            list[list[dict]]: Decoded bbox, scores and labels for each layer & each batch
        """
        rets = []
        for layer_id, preds_dict in enumerate(preds_dicts):
            batch_size = preds_dict[0]['heatmap'].shape[0]
            batch_score = preds_dict[0]['heatmap'][
                ..., -self.num_proposals:].sigmoid()
            # if self.loss_iou.loss_weight != 0:
            #    batch_score = torch.sqrt(batch_score * preds_dict[0]['iou'][..., -self.num_proposals:].sigmoid())
            one_hot = F.one_hot(
                self.query_labels,
                num_classes=self.num_classes).permute(0, 2, 1)
            batch_score = batch_score * preds_dict[0][
                'query_heatmap_score'] * one_hot

            batch_center = preds_dict[0]['center'][..., -self.num_proposals:]
            batch_height = preds_dict[0]['height'][..., -self.num_proposals:]
            batch_dim = preds_dict[0]['dim'][..., -self.num_proposals:]
            batch_rot = preds_dict[0]['rot'][..., -self.num_proposals:]
            batch_vel = None

            temp = self.bbox_coder.decode(
                batch_score,
                batch_rot,
                batch_dim,
                batch_center,
                batch_height,
                batch_vel,
                filter=True)

            self.tasks = [
                dict(
                    num_class=1, class_names=['Car'], indices=[0], radius=0.7),
                dict(
                    num_class=1,
                    class_names=['Pedestrian'],
                    indices=[1],
                    radius=0.7),
                dict(
                    num_class=1,
                    class_names=['Cyclist'],
                    indices=[2],
                    radius=0.7),
            ]

            ret_layer = []
            for i in range(batch_size):
                boxes3d = temp[i]['bboxes']
                scores = temp[i]['scores']
                labels = temp[i]['labels']
                ## adopt circle nms for different categories
                if self.test_cfg['nms_type'] is not None:
                    keep_mask = torch.zeros_like(scores)
                    for task in self.tasks:
                        task_mask = torch.zeros_like(scores)
                        for cls_idx in task['indices']:
                            task_mask += labels == cls_idx
                        task_mask = task_mask.bool()
                        if task['radius'] > 0:
                            if self.test_cfg['nms_type'] == 'circle':
                                boxes_for_nms = torch.cat([
                                    boxes3d[task_mask][:, :2],
                                    scores[:, None][task_mask]
                                ],
                                                          dim=1)
                                task_keep_indices = torch.tensor(
                                    circle_nms(
                                        boxes_for_nms.detach().cpu().numpy(),
                                        task['radius'],
                                    ))
                            else:
                                boxes_for_nms = xywhr2xyxyr(
                                    img_metas[i]['box_type_3d'](
                                        boxes3d[task_mask][:, :7], 7).bev)
                                top_scores = scores[task_mask]
                                task_keep_indices = nms_gpu(
                                    boxes_for_nms,
                                    top_scores,
                                    thresh=task['radius'],
                                    pre_maxsize=self.test_cfg['pre_maxsize'],
                                    post_max_size=self.test_cfg['post_maxsize'],
                                )
                        else:
                            task_keep_indices = torch.arange(task_mask.sum())
                        if task_keep_indices.shape[0] != 0:
                            keep_indices = torch.where(
                                task_mask != 0)[0][task_keep_indices]
                            keep_mask[keep_indices] = 1
                    keep_mask = keep_mask.bool()
                    ret = dict(
                        bboxes=boxes3d[keep_mask],
                        scores=scores[keep_mask],
                        labels=labels[keep_mask])
                else:  # no nms
                    ret = dict(bboxes=boxes3d, scores=scores, labels=labels)
                ret_layer.append(ret)
            rets.append(ret_layer)
        assert len(rets) == 1
        assert len(rets[0]) == 1
        res = [[
            img_metas[0]['box_type_3d'](
                rets[0][0]['bboxes'], box_dim=rets[0][0]['bboxes'].shape[-1]),
            rets[0][0]['scores'], rets[0][0]['labels'].int()
        ]]
        return res
```
[20220513_150959.log](https://github.com/XuyangBai/TransFusion/files/8704778/20220513_150959.log)
![10203656353524179475_7625_000_7645_000_1000000_pred](https://user-images.githubusercontent.com/7677343/168720001-0bbc4578-2108-4b2f-b8e7-a12b5bd96486.jpg)
XuyangBai commented 2 years ago

Could you briefly summarize what you changed, so that I can help you find the potential problems? Also, the config you provided and the log are inconsistent: you set num_proposals=100 in the config, but the log shows num_proposals=300. This is a critical hyperparameter for TransFusion because the number of proposals needs to be larger than the number of ground-truth boxes; otherwise the label assignment step will be quite noisy. Please use 300 or more proposals for the Waymo dataset.
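
A quick way to sanity-check this is to count the ground-truth boxes per frame in your info file and make sure num_proposals clears the maximum. A minimal sketch, assuming an mmdet3d-style Waymo info .pkl with KITTI-format `annos` (the path and field names here are assumptions, adapt them to your setup):

```python
# Hypothetical sanity check (not part of this repo): num_proposals should
# exceed the largest per-frame GT count, or some boxes can never be matched.
import pickle
import numpy as np

def max_gt_per_frame(info_path):
    with open(info_path, 'rb') as f:
        infos = pickle.load(f)  # list of per-frame info dicts
    counts = [len(info['annos']['name']) for info in infos]
    return int(np.max(counts)), float(np.mean(counts))

max_gt, mean_gt = max_gt_per_frame('data/waymo/kitti_format/waymo_infos_train.pkl')
print(f'max GT per frame: {max_gt}, mean: {mean_gt:.1f}')
```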

> Why is loss_bbox weight 0.25 for nuScenes but 2.0 for Waymo? It makes the Waymo box loss very large, over 10.0 at the end of training.

We set loss_bbox to 2.0 for Waymo following CenterPoint: on Waymo the network should care more about box regression, since mAP is computed using 3D IoU (and classification is relatively easy there). The absolute value of the bbox loss does not matter.
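
For reference, the setting in question is just the loss_weight on the same L1 regression loss; side by side (values as quoted in this thread), the raw bbox-loss magnitudes are therefore not comparable across the two datasets:

```python
# nuScenes vs. Waymo regression loss weight (same L1 loss, different scale)
loss_bbox_nuscenes = dict(type='L1Loss', reduction='mean', loss_weight=0.25)
loss_bbox_waymo = dict(type='L1Loss', reduction='mean', loss_weight=2.0)
```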