zhulf0804 / PointPillars

A Simple PointPillars PyTorch Implementation for 3D LiDAR(KITTI) Detection.
MIT License

shape mismatch of the forward function of the Loss class #59

Open Afreshbird opened 10 months ago

Afreshbird commented 10 months ago

According to your comments, the first dimension (index 0) of every argument passed to the forward function of the Loss class should have the same size. However, my debugging shows that the arguments actually passed have different sizes along that dimension.
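One quick way to reproduce this check is to compare the dim-0 sizes of the loss inputs right before the call. The sketch below is a hypothetical debugging helper, not part of the repository. Plain Python lists stand in for tensors (for a torch tensor, len(t) also returns the size of dimension 0), and the toy sizes are made up for illustration.

```python
def check_dim0(**arrays):
    """Hypothetical debugging helper (not part of PointPillars).

    Maps each argument name to its size along dimension 0 and reports
    whether all sizes agree. len() returns the dim-0 size for lists and
    for torch tensors alike (scalars such as num_cls_pos are excluded).
    """
    sizes = {name: len(arr) for name, arr in arrays.items()}
    consistent = len(set(sizes.values())) <= 1
    return sizes, consistent


# Toy sizes: 5 anchors kept for classification, 2 kept for regression.
sizes, ok = check_dim0(
    bbox_cls_pred=[[0.1, 0.2, 0.7]] * 5,
    bbox_pred=[[0.0] * 7] * 2,
    batched_bbox_labels=[0, 3, 2, 3, 3],
)
print(sizes, ok)
```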

According to your notes and my understanding, I made the following two changes:

  1. In model/anchors.py def anchor_target(batched_anchors, batched_gt_bboxes, batched_gt_labels, assigners, nclasses):

    179    assigned_gt_labels_weights[pos_flag] = 1
    180    assigned_gt_labels_weights[neg_flag] = 1

    Changed to keep only assigned_gt_labels_weights[pos_flag] = 1 (line 180 removed). After this change, the first dimension of bbox_cls_pred and batched_bbox_labels (lines 109/180 and 111/182 in train.py) matches that of the other variables.

  2. In train.py

    108    num_cls_pos = (batched_bbox_labels < args.nclasses).sum()
    109    bbox_cls_pred = bbox_cls_pred[batched_label_weights > 0]
    110    batched_bbox_labels[batched_bbox_labels < 0] = args.nclasses
    111    batched_bbox_labels = batched_bbox_labels[batched_label_weights > 0]

    Changed to:

          bbox_cls_pred = bbox_cls_pred[batched_label_weights > 0]
          batched_bbox_labels[batched_bbox_labels < 0] = args.nclasses
          batched_bbox_labels = batched_bbox_labels[batched_label_weights > 0]
          num_cls_pos = (batched_bbox_labels < args.nclasses).sum()

    I only swapped the order. Because batched_bbox_labels may contain -1, and -1 < args.nclasses is true, computing num_cls_pos first also counts those entries as positives. This can also lead to mismatched shapes.
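The effect of the first change can be illustrated with a toy model of the label weights and the two masks. This is a minimal sketch under stated assumptions: masked_sizes is a hypothetical helper, plain Python lists replace tensors, and the regression branch is assumed to keep anchors with 0 <= label < nclasses (a pos_idx-style mask in train.py that is not shown above).

```python
def masked_sizes(labels, pos_flag, neg_flag, weight_negatives, nclasses):
    """Hypothetical toy model of anchor_target's label weights and the
    masks applied in train.py.

    labels: per-anchor class ids (-1 = no class assigned)
    pos_flag / neg_flag: per-anchor booleans from the assigner
    weight_negatives: True mimics the original line 180,
                      False mimics the proposed change
    Returns (rows kept for classification, rows kept for regression).
    """
    weights = [0.0] * len(labels)
    for i in range(len(labels)):
        if pos_flag[i]:
            weights[i] = 1.0  # assigned_gt_labels_weights[pos_flag] = 1
        if weight_negatives and neg_flag[i]:
            weights[i] = 1.0  # assigned_gt_labels_weights[neg_flag] = 1
    # bbox_cls_pred[batched_label_weights > 0]
    n_cls = sum(1 for w in weights if w > 0)
    # Assumed regression mask: 0 <= label < nclasses
    n_reg = sum(1 for lab in labels if 0 <= lab < nclasses)
    return n_cls, n_reg


labels = [0, -1, 2, -1, -1]
pos_flag = [True, False, True, False, False]
neg_flag = [False, True, False, False, True]
print(masked_sizes(labels, pos_flag, neg_flag, True, 3))   # (4, 2)
print(masked_sizes(labels, pos_flag, neg_flag, False, 3))  # (2, 2)
```

With negatives weighted, the classification inputs keep 4 rows against 2 for regression; with only positives weighted, both keep 2, which is the alignment the first change aims for.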
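The second change only affects the order of operations. The toy function below is hypothetical and uses plain Python lists instead of tensors; it mirrors the four train.py lines and returns num_cls_pos for either ordering.

```python
def count_positives(labels, weights, nclasses, count_first):
    """Hypothetical toy version of the num_cls_pos computation.

    count_first=True mirrors the original order (count before the -1
    labels are remapped); False mirrors the reordered code.
    """
    # Original order: -1 < nclasses, so every -1 entry is counted too.
    before = sum(1 for lab in labels if lab < nclasses)
    # batched_bbox_labels[batched_bbox_labels < 0] = args.nclasses
    remapped = [nclasses if lab < 0 else lab for lab in labels]
    # batched_bbox_labels = batched_bbox_labels[batched_label_weights > 0]
    kept = [lab for lab, w in zip(remapped, weights) if w > 0]
    # Reordered: only labels below nclasses (true positives) remain.
    after = sum(1 for lab in kept if lab < nclasses)
    return before if count_first else after


labels = [0, -1, 2, -1, -1]
weights = [1, 1, 1, 0, 1]
print(count_positives(labels, weights, 3, True))   # 5
print(count_positives(labels, weights, 3, False))  # 2
```

In the original order, num_cls_pos counts all 5 anchors, including those labeled -1, instead of the 2 positives.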

Following issue #38, I also made this change in train.py:

102     # sin(a - b) = sin(a)*cos(b) - cos(a)*sin(b)
103     bbox_pred[:, -1] = torch.sin(bbox_pred[:, -1].clone()) * torch.cos(batched_bbox_reg[:, -1].clone())
104     batched_bbox_reg[:, -1] = torch.cos(bbox_pred[:, -1].clone()) * torch.sin(batched_bbox_reg[:, -1].clone())

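Changed to the following, which saves both angles before either column is overwritten (in the original, line 104 reads bbox_pred[:, -1] after line 103 has already overwritten it):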

# sin(a - b) = sin(a)*cos(b) - cos(a)*sin(b)
a = bbox_pred[:, -1].clone()
b = batched_bbox_reg[:, -1].clone()
bbox_pred[:, -1] = torch.sin(a) * torch.cos(b)
batched_bbox_reg[:, -1] = torch.cos(a) * torch.sin(b)

The same change applies to lines 173-175.
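The ordering bug can be reproduced with plain scalars standing in for the angle columns bbox_pred[:, -1] (a) and batched_bbox_reg[:, -1] (b). The function names below are hypothetical. The sketch assumes, as the comment in train.py suggests, that the loss effectively penalizes the difference between the two columns, which should equal sin(a - b).

```python
import math


def residual_in_place(a, b):
    """Mirrors the original lines 103-104: the first assignment
    overwrites the predicted angle before the second line reads it."""
    pred, reg = a, b
    pred = math.sin(pred) * math.cos(reg)  # pred is now sin(a)*cos(b)
    reg = math.cos(pred) * math.sin(reg)   # bug: cos of the NEW pred, not cos(a)
    return pred - reg


def residual_saved(a, b):
    """Mirrors the fixed version: both angles are read before any write."""
    pred = math.sin(a) * math.cos(b)
    reg = math.cos(a) * math.sin(b)
    return pred - reg  # sin(a)*cos(b) - cos(a)*sin(b) == sin(a - b)


a, b = 1.0, 0.5
print(math.sin(a - b), residual_saved(a, b), residual_in_place(a, b))
```

Only the saved-angle version matches sin(a - b); the in-place version feeds sin(a)*cos(b) into the cosine instead of a.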