Pointcept / PointTransformerV2

[NeurIPS'22] An official PyTorch implementation of PTv2.

How does your method handle empty voxels? #5

Closed xiaobaitu123344 closed 1 year ago

Gofinge commented 1 year ago

PTv2 is a point-based method. What do you mean by "empty voxels" in your question?

xiaobaitu123344 commented 1 year ago

Isn't the partition-based pooling in your paper based on voxels?

Gofinge commented 1 year ago

No. Partition-based pooling is defined as separating a point cloud into non-overlapping partitions and fusing the points that share the same partition. In our implementation of grid pooling, we compute which grid cell each point belongs to and then fuse the points in each cell. The implementation code is attached below.

From my perspective, there is not much difference between voxel-based and point-based methods. A voxel can be seen as just a kind of point obtained by grid sampling. You may also have noticed that current point-based methods apply voxelization during data augmentation to downsample points.

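import torch
import torch.nn as nn
from torch_geometric.nn.pool import voxel_grid
from torch_scatter import segment_csr

# PointBatchNorm, offset2batch and batch2offset are not shown here;
# they are helper definitions from the PTv2 codebase.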

class GridPool(nn.Module):
    """
    Partition-based Pooling (Grid Pooling)
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 grid_size,
                 bias=False):
        super(GridPool, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.grid_size = grid_size

        self.fc = nn.Linear(in_channels, out_channels, bias=bias)
        self.norm = PointBatchNorm(out_channels)
        self.act = nn.ReLU(inplace=True)

    def forward(self, points, start=None):
        # offset holds the cumulative point count of each sample in the batch
        coord, feat, offset = points
        batch = offset2batch(offset)
        feat = self.act(self.norm(self.fc(feat)))
        # per-sample minimum coordinate, used as the grid origin
        start = segment_csr(coord, torch.cat([batch.new_zeros(1), torch.cumsum(batch.bincount(), dim=0)]),
                            reduce="min") if start is None else start
        # grid cell id of every point
        cluster = voxel_grid(pos=coord - start[batch], size=self.grid_size, batch=batch, start=0)
        # relabel to consecutive ids; only cells that contain points receive an id
        unique, cluster, counts = torch.unique(cluster, sorted=True, return_inverse=True, return_counts=True)
        _, sorted_cluster_indices = torch.sort(cluster)
        idx_ptr = torch.cat([counts.new_zeros(1), torch.cumsum(counts, dim=0)])
        # fuse the points of each cell: mean coordinate, max feature
        coord = segment_csr(coord[sorted_cluster_indices], idx_ptr, reduce="mean")
        feat = segment_csr(feat[sorted_cluster_indices], idx_ptr, reduce="max")
        # batch id of each pooled point, converted back to offset form
        batch = batch[idx_ptr[:-1]]
        offset = batch2offset(batch)
        return [coord, feat, offset], cluster
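
The point about empty voxels can be made concrete with a minimal sketch. It is not the repository's code: it assumes plain PyTorch in place of torch_geometric and torch_scatter, uses one global grid origin instead of the per-sample `start`, and the names `grid_pool_sketch` and the local `offset2batch` are hypothetical stand-ins written for illustration. Because partitions are derived from the points themselves, a cell without points never receives a cluster id.

```python
import torch


def offset2batch(offset):
    """Illustrative helper: cumulative counts [3, 4] -> per-point batch ids [0, 0, 0, 1]."""
    counts = torch.diff(offset, prepend=offset.new_zeros(1))
    return torch.arange(len(counts), device=offset.device).repeat_interleave(counts)


def grid_pool_sketch(coord, feat, offset, grid_size):
    """Hypothetical grid pooling: mean coordinate and max feature per occupied cell."""
    batch = offset2batch(offset)
    # Integer cell index of every point, measured from a single global origin
    # (a simplification; the code above uses a per-sample origin).
    cell = torch.floor((coord - coord.min(dim=0).values) / grid_size).long()
    # Prepend the batch id so points of different samples never share a partition.
    key = torch.cat([batch[:, None], cell], dim=1)
    # torch.unique returns only keys that occur, so empty cells get no cluster id.
    unique_key, cluster = torch.unique(key, dim=0, return_inverse=True)
    num_cluster = unique_key.shape[0]

    # Mean coordinate per cluster: sum the members, then divide by member count.
    counts = torch.bincount(cluster, minlength=num_cluster).to(coord.dtype)
    coord_sum = torch.zeros(num_cluster, coord.shape[1], dtype=coord.dtype, device=coord.device)
    pooled_coord = coord_sum.index_add_(0, cluster, coord) / counts[:, None]

    # Max feature per cluster, starting from -inf so any real value wins.
    index = cluster[:, None].expand(-1, feat.shape[1])
    init = torch.full((num_cluster, feat.shape[1]), float("-inf"), dtype=feat.dtype, device=feat.device)
    pooled_feat = init.scatter_reduce(0, index, feat, reduce="amax")

    # Clusters per sample, accumulated back into offset form.
    pooled_offset = torch.cumsum(torch.bincount(unique_key[:, 0]), dim=0)
    return pooled_coord, pooled_feat, pooled_offset, cluster


# Two samples (3 points and 1 point). The bounding box spans 6 cells per axis,
# yet only the occupied cells become clusters.
coord = torch.tensor([[0.0, 0.0, 0.0], [0.1, 0.1, 0.1], [5.0, 5.0, 5.0], [0.0, 0.0, 0.0]])
feat = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
offset = torch.tensor([3, 4])
pooled_coord, pooled_feat, pooled_offset, cluster = grid_pool_sketch(coord, feat, offset, grid_size=1.0)
```

In this toy input a dense voxel grid would hold 216 cells per sample, while the sketch produces three pooled points in total, one for each occupied cell.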

xiaobaitu123344 commented 1 year ago

Thank you for your patience. I apologize for not replying sooner; the epidemic delayed my response.