hailanyi / VirConv

Virtual Sparse Convolution for Multimodal 3D Object Detection
https://arxiv.org/abs/2303.02314
Apache License 2.0

Code for StVD (Stochastic Voxel Discard) #28

xuwentai closed this issue 1 year ago

xuwentai commented 1 year ago

Thank you for your wonderful work. I would like to understand in detail how the StVD module works. I have read your code, but I am not very familiar with it. Could you explain two things? First, why is sampled_sum computed as sampled_sum = points_num_acc + i * this_points.shape[0]? Second, what does this code mean:

if sampled_sum / all_points_num < rate:
    position = i
    distant_points_num_acc = points_num_acc

VirConv/tree/master/pcdet/datasets/dataset.py
    def partition(self, points, num=10, max_dis=60, rate=0.2):
        """
        partition the points into several bins.
        """

        points_list = []
        inter = max_dis / num

        all_points_num = points.shape[0]

        points_num_acc = 0

        position = num - 1

        distant_points_num_acc = 0

        for i in range(num):
            i = num - i - 1  # visit bins from far (index num - 1) to near (index 0)
            if i == num - 1:
                min_mask = points[:, 0] >= inter * i
                this_points = points[min_mask]

                points_num_acc += this_points.shape[0]

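                # points kept if bins i..num-1 stay intact and each of the
                # i nearer bins is reduced to this bin's point count
                sampled_sum = points_num_acc + i * this_points.shape[0]

                if sampled_sum / all_points_num < rate:
                    # overwritten on every hit: ends at the nearest qualifying bin
                    position = i
                    distant_points_num_acc = points_num_acc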

                points_list.append(this_points)
            else:
                min_mask = points[:, 0] >= inter * i
                max_mask = points[:, 0] < inter * (i + 1)
                mask = min_mask * max_mask
                this_points = points[mask]

                points_num_acc += this_points.shape[0]

                sampled_sum = points_num_acc + i * this_points.shape[0]

                if sampled_sum / all_points_num < rate:
                    position = i
                    distant_points_num_acc = points_num_acc

                points_list.append(this_points)

        if position <= 0:
            position = 0

        return points_list, position, distant_points_num_acc

    def input_point_discard(self, points, bin_num=2, rate=0.8):
        """
        discard points by a bin-based sampling.
        """
        retain_rate = 1 - rate
        parts, pos, distant_points_num_acc = self.partition(points, num=bin_num, rate=retain_rate)

        output_pts_num = int(points.shape[0] * retain_rate)

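        # split the budget left after the distant_points_num_acc kept points
        # equally among the pos nearby bins (+0.0001 avoids division by zero)
        pts_per_bin_num = int((output_pts_num-distant_points_num_acc)/(pos+0.0001))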

        # parts is ordered far to near, so its last pos entries are bins pos-1, ..., 0
        for i in range(len(parts) - pos, len(parts)):

            if parts[i].shape[0] > pts_per_bin_num:
                rands = np.random.permutation(parts[i].shape[0])
                parts[i] = parts[i][rands[:pts_per_bin_num]]

        return np.concatenate(parts)
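One detail of input_point_discard() that is easy to misread is the loop range range(len(parts) - pos, len(parts)). Because partition() visits bins from far to near and appends each one, parts is ordered far to near. The minimal sketch below maps list positions back to bin indices, assuming bin 0 is the nearest bin; the helper name bins_sampled is hypothetical and not part of VirConv.

```python
def bins_sampled(num, pos):
    """Return the bin indices that input_point_discard() may downsample.

    Hypothetical helper for illustration. partition() appends bins in the
    order num-1, num-2, ..., 0, so parts[k] holds bin num - 1 - k.
    """
    bin_of_part = [num - 1 - k for k in range(num)]
    # input_point_discard() loops over k in range(len(parts) - pos, len(parts))
    return [bin_of_part[k] for k in range(num - pos, num)]


print(bins_sampled(num=10, pos=3))  # [2, 1, 0]
print(bins_sampled(num=10, pos=0))  # []
```

So only the pos nearest bins (bins 0 to pos - 1) can be capped; every bin from pos outward is kept in full.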
hailanyi commented 1 year ago

sampled_sum = points_num_acc + i * this_points.shape[0]

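Here, points_num_acc is the number of distant points, i.e. all points in bins i to num - 1, which are kept in full. i * this_points.shape[0] is the number of points left in the i nearer bins (bins 0 to i - 1) if each of them is sampled down to the same number of points as the i-th bin. sampled_sum is therefore the number of points that remain after applying bin-based sampling with bin i as the boundary.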
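To make the formula concrete, the minimal sketch below replays the same accumulation on made-up per-bin counts (hypothetical numbers, not from a real scan; the function name sampled_sums is also hypothetical). Index 0 is the nearest bin.

```python
def sampled_sums(bin_counts):
    """Return (i, points_num_acc, sampled_sum) for each bin, far to near,
    mirroring the accumulation loop in partition()."""
    num = len(bin_counts)
    points_num_acc = 0
    rows = []
    for i in reversed(range(num)):  # i = num-1, ..., 0 (far to near)
        this_count = bin_counts[i]
        points_num_acc += this_count  # points in bins i..num-1, all kept
        # plus the i nearer bins (0..i-1), each reduced to this bin's count
        sampled_sum = points_num_acc + i * this_count
        rows.append((i, points_num_acc, sampled_sum))
    return rows


print(sampled_sums([500, 300, 150, 50]))
# [(3, 50, 200), (2, 200, 500), (1, 500, 800), (0, 1000, 1000)]
```

Reading the first row: if only the farthest bin (3) is kept intact and each of the three nearer bins is cut to its 50 points, 200 of the 1000 points remain. The ratio sampled_sum / all_points_num is what partition() compares against rate.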

if sampled_sum / all_points_num < rate:
    position = i

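The distant bins keep all of their points; only the nearby bins are sampled. position is the bin-index threshold that decides which bins are sampled: input_point_discard() caps bins 0 to position - 1 at pts_per_bin_num points each and keeps the other bins intact. Because the loop runs from the farthest bin to the nearest and overwrites position each time the condition holds, position ends up at the nearest bin whose estimated retained fraction (sampled_sum / all_points_num) is still below rate, and distant_points_num_acc records how many points that choice keeps in full.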
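The whole procedure can be sketched end to end as follows. This is a minimal, hedged reimplementation for illustration only: it works on pre-binned Python lists of x coordinates instead of the NumPy masks in partition(), and it uses random.sample in place of np.random.permutation. The names choose_position and discard are hypothetical, and the per-bin counts are made up.

```python
import random


def choose_position(bin_counts, rate):
    """Replay partition()'s threshold search on per-bin counts.

    bin_counts[i] is the number of points in bin i (0 = nearest).
    Returns (position, distant_points_num_acc).
    """
    num = len(bin_counts)
    all_points_num = sum(bin_counts)
    points_num_acc = 0
    position = num - 1
    distant_points_num_acc = 0
    for i in reversed(range(num)):  # far to near
        points_num_acc += bin_counts[i]
        sampled_sum = points_num_acc + i * bin_counts[i]
        if sampled_sum / all_points_num < rate:
            # overwritten on every hit, so the nearest qualifying bin wins
            position = i
            distant_points_num_acc = points_num_acc
    return max(position, 0), distant_points_num_acc


def discard(bins, retain_rate, rng):
    """Cap bins 0..pos-1 to an equal share of the remaining budget; keep the rest."""
    counts = [len(b) for b in bins]
    pos, distant = choose_position(counts, retain_rate)
    output_pts_num = int(sum(counts) * retain_rate)
    pts_per_bin_num = int((output_pts_num - distant) / (pos + 0.0001))
    kept = [list(b) for b in bins]
    for i in range(pos):  # nearby bins only
        if len(kept[i]) > pts_per_bin_num:
            kept[i] = rng.sample(kept[i], pts_per_bin_num)
    return kept, pos, pts_per_bin_num


rng = random.Random(0)
counts = [500, 300, 150, 50]
# toy "points": x coordinates inside 15 m wide bins (num=4, max_dis=60)
bins = [[15 * i + 15 * rng.random() for _ in range(c)] for i, c in enumerate(counts)]
kept, pos, per_bin = discard(bins, retain_rate=0.6, rng=rng)
print(pos, per_bin, [len(b) for b in kept])  # 2 199 [199, 199, 150, 50]
```

With a retain rate of 0.6, bins 2 and 3 (200 points) are kept in full and the remaining budget of 400 points is split over the two nearby bins. The + 0.0001 guard makes the per-bin share 199 rather than 200, so 598 points survive instead of exactly 600.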