open-mmlab / mmdetection3d

OpenMMLab's next-generation platform for general 3D object detection.
https://mmdetection3d.readthedocs.io/en/latest/
Apache License 2.0

[Potential Performance Improvement] Hard voxelization could have a much faster implementation (at the cost of determinism). #894

Open zhanggefan opened 2 years ago

zhanggefan commented 2 years ago

Hi,

I noticed that the CUDA code for hard voxelization wraps some kernel code that is computationally inefficient:

  1. single-thread kernel launch https://github.com/open-mmlab/mmdetection3d/blob/d1123084973dd3a910760e77c499afad76d16cea/mmdet3d/ops/voxel/src/voxelization_cuda.cu#L276

  2. O(n^2) part: an inner loop inside the kernel. https://github.com/open-mmlab/mmdetection3d/blob/d1123084973dd3a910760e77c499afad76d16cea/mmdet3d/ops/voxel/src/voxelization_cuda.cu#L156
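
To make the quadratic cost concrete, here is a simplified pure-Python emulation of the per-point pattern as we read it from the second link: each point (one GPU thread in the kernel) scans every earlier point to find the first one in the same voxel and to count how many precede it. The function name and return layout are our own, it does not model the single-thread launch, and it is an illustration rather than a port of the kernel.

```python
def point_to_voxelidx_naive(coors, max_points):
    """Quadratic reference for per-point voxel bookkeeping (illustrative only).

    coors: list of (x, y, z) integer voxel coordinates; x == -1 marks a point
    outside the range. Returns two lists:
      first_idx[i]: index of the earliest point in the same voxel (-1 if invalid)
      slot[i]: position of point i inside its voxel, or -1 if it is discarded
    """
    n = len(coors)
    first_idx = [-1] * n
    slot = [-1] * n
    for i in range(n):              # in CUDA: one thread per point
        if coors[i][0] == -1:
            continue
        num = 0                     # earlier points already in this voxel
        for j in range(i):          # O(i) scan per point -> O(n^2) overall
            if coors[j] == coors[i]:
                num += 1
                if num == 1:
                    first_idx[i] = j
                elif num >= max_points:
                    break           # voxel already full, stop scanning
        if num == 0:
            first_idx[i] = i        # point i opens a new voxel
        if num < max_points:
            slot[i] = num
    return first_idx, slot
```

With N points the inner scan runs about N^2 / 2 comparisons in total, which is what dominates on dense, multi-sweep point clouds.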

From my experience (see PR #318), voxelization can be viewed as a sparse hashing problem. The sparse index can be obtained with torch.unique in O(n log n), while the hash-reduce part can be implemented with atomicMax in O(n), just like the implementation in Apollo: https://github.com/ApolloAuto/apollo/blob/master/modules/perception/lidar/lib/detector/point_pillars_detection/preprocess_points_cuda.cu
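
A minimal sketch of the indexing half in PyTorch, assuming points are an (N, >=3) float tensor and pc_range is given as [x_min, y_min, z_min, x_max, y_max, z_max]. `voxel_index` is a hypothetical helper and keeps coordinates in (x, y, z) order for simplicity. torch.unique sorts internally, hence the O(n log n); the Apollo code linked above relies on CUDA atomics instead of sorting, so this is a different technique arriving at the same point-to-voxel mapping.

```python
import torch

def voxel_index(points, voxel_size, pc_range):
    """Sort-based sparse indexing of points into voxels (sketch).

    points: (N, >=3) float tensor; voxel_size: [vx, vy, vz];
    pc_range: [x_min, y_min, z_min, x_max, y_max, z_max].
    Returns:
      valid:   (N,) bool mask of points inside pc_range
      uniq:    (M, 3) unique voxel coordinates, sorted
      inverse: (N_valid,) voxel id of every valid point
      counts:  (M,) number of points per voxel
    """
    vs = torch.as_tensor(voxel_size, dtype=points.dtype, device=points.device)
    rng = torch.as_tensor(pc_range, dtype=points.dtype, device=points.device)
    coors = torch.floor((points[:, :3] - rng[:3]) / vs).long()
    grid = torch.round((rng[3:] - rng[:3]) / vs).long()
    valid = ((coors >= 0) & (coors < grid)).all(dim=1)
    # torch.unique sorts the rows internally, hence O(n log n)
    uniq, inverse, counts = torch.unique(
        coors[valid], dim=0, return_inverse=True, return_counts=True)
    return valid, uniq, inverse, counts
```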

From my observation, this part is the performance bottleneck for many lightweight detectors on many datasets. For example, CenterPoint with the pillar encoder on nuScenes (https://github.com/open-mmlab/mmdetection3d/blob/d1123084973dd3a910760e77c499afad76d16cea/configs/centerpoint/centerpoint_02pillar_second_secfpn_4x8_cyclic_20e_nus.py). I managed to reimplement this part, and it reduces training iteration time by more than half, a huge improvement.

I also understand that your implementation is designed for determinism. Because hard voxelization involves discarding points/voxels, the ordering between points matters: it guarantees that the same portion of input points is always discarded. In your implementation, earlier points/voxels take precedence over later ones when discarding, so your implementation is deterministic.

My reimplementation is non-deterministic because the hash-reduce kernel may be executed in any order, with no ordering between input points at runtime. Different runs for the same data can result in different portions of points being discarded.
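
The ordering argument in the last two paragraphs can be emulated on the CPU. `assign_slots` (a hypothetical name) gives each point a slot inside its voxel following a processing order and discards slots beyond `max_points`. Passing `torch.arange(N)` reproduces the first-come rule where earlier points take precedence; passing any other permutation stands in for the arbitrary arrival order of atomic operations in a hash-reduce kernel. It sketches the semantics only, not either CUDA implementation.

```python
import torch

def assign_slots(inverse, order, max_points):
    """Slot of every point inside its voxel, following a processing order.

    inverse: (N,) long voxel id per point.
    order:   (N,) permutation of range(N); earlier entries are processed first.
    Returns (N,) long: slot in [0, max_points), or -1 if the point is discarded.
    """
    n = inverse.numel()
    dev = inverse.device
    # Voxel ids in processing order; a stable sort groups equal voxels while
    # preserving the processing order inside each group.
    sorted_vox, perm = torch.sort(inverse[order], stable=True)
    pos = torch.arange(n, device=dev)
    is_start = torch.ones(n, dtype=torch.bool, device=dev)
    is_start[1:] = sorted_vox[1:] != sorted_vox[:-1]
    # Position of the first element of each group, propagated forward.
    group_start = torch.cummax(
        torch.where(is_start, pos, torch.zeros_like(pos)), dim=0).values
    slot = torch.empty(n, dtype=torch.long, device=dev)
    slot[order[perm]] = pos - group_start   # rank within the voxel
    slot[slot >= max_points] = -1           # voxel full: discard the point
    return slot
```

With `inverse = [0, 1, 0, 0, 1]` and `max_points = 2`, the input order drops point 3, while the reversed order drops point 0 instead: exactly the run-to-run variation described above. Since a stable sort yields the deterministic rule at O(n log n), it may be one way to keep determinism with sort-based indexing, though we have not measured it.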

If you think the performance gain is worth more than keeping determinism, then my reimplementation is ready for a PR.

Looking forward to your reply!

Tai-Wang commented 2 years ago

Yes, it really makes sense. I guess the deterministic property is not very important in the overall procedure.

Have you ever compared the performance of these two ways of voxelization? If they are comparable, I think we can first add the new one as an optional approach, then run the regression benchmark several times to make sure it has no problems and indeed performs better. Finally, if necessary, we can replace the old one. Do you agree, or do you have any suggestions?

zhanggefan commented 2 years ago

I fully agree with you. But honestly, regression benchmark tests are really expensive (at least for me). I can provide you some statistics of the experiments I've done (faster training with no score drop of course), but it is hardly a "regression benchmark". I do hope the community can support us on it. If you don't mind, I would like to first clean up my code, make it an optional feature, and submit the PR.

zhanggefan commented 2 years ago

#904

tianweiy commented 2 years ago

Hi all @zhanggefan @Tai-Wang, out of curiosity, couldn't we use something like this (https://github.com/tianweiy/CenterPoint/blob/44ebd63a81053e6fe85cd1228b644bab9d726507/det3d/models/readers/dynamic_voxel_encoder.py#L8) to do the voxelization? It is relatively fast and doesn't need custom CUDA code.

zhanggefan commented 2 years ago

Hi @tianweiy, as for the dynamic voxelization part, I couldn't agree more. Using torch.unique to get the scatter index and something like torch_scatter.scatter to do the scatter reduction is the best solution in my opinion. The only drawback is that it adds an extra dependency unless we implement the "scatter-reduce" ourselves with numba or a torch extension.

But for the hard voxelization part, things are different. Extra care must be taken with the discarding process when there are more points inside a voxel than the per-voxel threshold, or more voxels than the global threshold. The point/voxel selection and discarding is full of dirty logic (forgive me, but it really gives me a headache), and moving that logic to Python code would make it much less efficient.

tianweiy commented 2 years ago

I see. Additionally, scatter-reduce can easily be implemented in PyTorch (at least for mean / sum):

# The following code is copied from pytorch_scatter https://github.com/rusty1s/pytorch_scatter
# Copyright (c) 2020 Matthias Fey <matthias.fey@tu-dortmund.de>
# MIT License
from typing import Optional
import torch

@torch.jit.script
def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
    if dim < 0:
        dim = other.dim() + dim
    if src.dim() == 1:
        for _ in range(dim):
            src = src.unsqueeze(0)
    for _ in range(other.dim()-src.dim()):
        src = src.unsqueeze(-1)
    src = src.expand_as(other)
    return src

@torch.jit.script
def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None) -> torch.Tensor:
    index = broadcast(index, src, dim)
    if out is None:
        size = list(src.size())
        if dim_size is not None:
            size[dim] = dim_size
        elif index.numel() == 0:
            size[dim] = 0
        else:
            size[dim] = int(index.max()) + 1
        out = torch.zeros(size, dtype=src.dtype, device=src.device)
        return out.scatter_add_(dim, index, src)
    else:
        return out.scatter_add_(dim, index, src)

@torch.jit.script
def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                 out: Optional[torch.Tensor] = None,
                 dim_size: Optional[int] = None) -> torch.Tensor:

    out = scatter_sum(src, index, dim, out, dim_size)
    dim_size = out.size(dim)

    index_dim = dim
    if index_dim < 0:
        index_dim = index_dim + src.dim()
    if index.dim() <= index_dim:
        index_dim = index.dim() - 1

    ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
    count = scatter_sum(ones, index, index_dim, None, dim_size)
    count.clamp_(1)
    count = broadcast(count, out, dim)
    if torch.is_floating_point(out):
        out.div_(count)
    else:
        # integer features: floor division (rounding_mode needs PyTorch >= 1.8)
        out.div_(count, rounding_mode="floor")
    return out

I am actually not sure we really need hard voxelization. I feel dynamic voxelization is both faster and simpler (we can probably discard excess voxels later).

Tai-Wang commented 2 years ago

@zhanggefan Oh don't worry. What I originally meant is that we will help you do the regression benchmark to further check it. Actually, we usually conduct it once before each release.

@tianweiy Great discussion. Just to add one thing: as @zhanggefan points out, the current approach should be the better one, especially for hard voxelization. We consider these basic ops and would like to implement them in a consistent way that can easily be migrated (maybe to mmcv in the future). In addition, I guess dynamic voxelization is not very stable in some cases (e.g. when points are extremely dense), which is why it has not replaced the hard version. There is also room for further experiments.

tianweiy commented 2 years ago

"I guess the dynamic voxelization is not much stable in some cases (such as points are extremely dense) so it has not replaced the hard version." I think that, at least on Waymo, dynamic voxelization performs about the same as hard voxelization?

The two differences I know of are:

  1. max number of points per voxel
  2. max number of voxels

For the first one, if you want to compute an average, then using more points (e.g. all the points, as in dynamic voxelization) shouldn't make the result worse.

For the second one, if we still want to limit the max number of voxels (for speed, etc.), we can do the indexing at the end.
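
A minimal sketch of "doing the indexing at the end", assuming `inverse` is the per-point voxel id from torch.unique(return_inverse=True) and that we keep the voxels whose first point appears earliest, which mirrors the first-come rule of hard voxelization. Another policy (e.g. keeping the densest voxels) would only change the ranking line. The helper name is hypothetical, and `Tensor.scatter_reduce_` requires PyTorch 1.12 or later.

```python
import torch

def keep_first_voxels(inverse, num_voxels, max_voxels):
    """Apply a global voxel cap after dynamic voxelization (sketch).

    inverse: (N,) long voxel id per point.
    Keeps the max_voxels voxels whose first point comes earliest in the input.
    Returns (mask, new_inverse): points to keep, and their compacted voxel ids
    (-1 for dropped points).
    """
    n = inverse.numel()
    dev = inverse.device
    pos = torch.arange(n, device=dev)
    # Earliest point index per voxel (initialised to n, i.e. "not seen").
    first = torch.full((num_voxels,), n, dtype=torch.long, device=dev)
    first.scatter_reduce_(0, inverse, pos, reduce="amin")
    kept = torch.argsort(first)[:max_voxels]          # earliest voxels first
    remap = torch.full((num_voxels,), -1, dtype=torch.long, device=dev)
    remap[kept] = torch.arange(kept.numel(), device=dev)
    new_inverse = remap[inverse]
    return new_inverse >= 0, new_inverse
```

The features can then be reduced over all points of the kept voxels, so the per-voxel point cap disappears while the voxel count stays bounded.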

Are there any other differences that I miss?

Tai-Wang commented 2 years ago

I remember the previous experiments were on nuScenes (10-frame point clouds concatenated), and the problem appeared when a voxel contained very dense points, which hurt voxelization efficiency. Besides, using more points is fine (both in terms of efficiency and implementation complexity) when computing a mean, but it can cause trouble when we want to extract features in other ways (a more complicated voxel encoder).

tianweiy commented 2 years ago

Thanks. Efficiency doesn't seem to be a problem in my dynamic voxelization implementation above (3 ms for a 3-frame Waymo point cloud with 400k+ points, if I remember correctly).

I agree it will be more complex to extract other features. I do think something like PointPillars is doable using just torch_scatter / unique (not proposing a change, just an example).
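
As a sketch of what a PointPillars-style encoder could look like with only unique/scatter: a shared per-point linear layer followed by a per-pillar max. `ScatterPFN` is a hypothetical name; unlike the real PFN it omits BatchNorm and the point-feature augmentation (offsets to the pillar mean and center), and it uses the built-in `Tensor.scatter_reduce` (PyTorch 1.12+) rather than torch_scatter.

```python
import torch
from torch import nn

class ScatterPFN(nn.Module):
    """Hypothetical minimal pillar encoder: per-point MLP + per-pillar max."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.linear = nn.Linear(in_channels, out_channels)

    def forward(self, feats, inverse, num_pillars):
        # feats: (N, C_in) point features; inverse: (N,) pillar id per point
        x = torch.relu(self.linear(feats))                 # (N, C_out)
        idx = inverse.unsqueeze(1).expand_as(x)            # same shape as src
        out = x.new_zeros((num_pillars, x.shape[1]))
        # Max over the points of each pillar; include_self=False ignores the
        # zero initialisation for pillars that received points.
        return out.scatter_reduce(0, idx, x, reduce="amax", include_self=False)
```

The `inverse` index would come from torch.unique over the pillar coordinates, so no padded (pillars, points, C) buffer is ever built.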

zhanggefan commented 2 years ago

@tianweiy

Honestly, I've never used hard voxelization for PointPillars in production code. According to my experiments, it is indeed slow, and the score is even worse than with dynamic voxelization.

zhanggefan commented 2 years ago

@Tai-Wang

I have some more suggestions for dynamic voxelization. I suggest dividing the dynamic scatter op into 2 parts: one for indexing and the other for scatter-reduction.

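Dynamic voxelization as a whole, as @tianweiy just pointed out, can be divided into 3 parts:

  1. computing the voxel coordinate of every point;
  2. "Indexing": mapping those voxel coordinates to compact voxel indices (e.g. with torch.unique);
  3. "Scatter-Reduction": reducing the point features into their voxels.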

The key point is that, according to my observation, the "Scatter-Reduction" itself does not cost much time, but "Indexing" does: torch.unique is the most computationally expensive op of the three.

Currently, in MMDet3D's implementation, the first part is dynamic_voxelization, while the latter 2 parts are packed into dynamic_scatter. In use cases where we repeatedly do the scatter-reduction for features that belong to the same point cloud, we have to call dynamic_scatter multiple times, which means doing the "Indexing" part multiple times, which is wasteful.

All voxel encoders that use dynamic_scatter to get cluster-center features waste time on redundant Indexing. An extreme case is pillar-od, where we need both cluster centers (1st scatter-reduce) and covariance matrices (2nd scatter-reduce) for both the BEV and cylindrical views, plus one final scatter-reduce for DynamicPillarFeatureNet before the backbone: 4 redundant indexing passes in total if we use dynamic_scatter.
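
To illustrate why a cached index helps, the sketch below computes a per-voxel centroid and covariance (the two reductions mentioned for pillar-od) from a single `inverse` index, so torch.unique would run only once per view. The function name is hypothetical; it returns the population covariance (divided by n, not n - 1), and a voxel with a single point gets a zero matrix.

```python
import torch

def voxel_mean_and_cov(xyz, inverse, num_voxels):
    """Per-voxel centroid (M, 3) and covariance (M, 3, 3) from one cached index.

    xyz: (N, 3) float point coordinates; inverse: (N,) long voxel id per point.
    """
    counts = torch.bincount(inverse, minlength=num_voxels).clamp(min=1).to(xyz.dtype)
    # 1st scatter-reduce: cluster centers
    mean = xyz.new_zeros((num_voxels, 3)).index_add_(0, inverse, xyz) / counts[:, None]
    centered = xyz - mean[inverse]                         # (N, 3)
    outer = centered[:, :, None] * centered[:, None, :]    # (N, 3, 3)
    # 2nd scatter-reduce reuses the same index: no extra torch.unique call
    cov = xyz.new_zeros((num_voxels, 3, 3)).index_add_(0, inverse, outer)
    return mean, cov / counts[:, None, None]
```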

Tai-Wang commented 2 years ago

@zhanggefan Yes, I agree. I think a more compatible way to improve it is to split this function into two sub-functions while keeping the original one, so that if we only need one of the sub-functions, we do not have to run both. Is there any example similar to pillar-od among the currently supported models? (Then I will know how much it affects our current code.)
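
A minimal sketch of that split, with hypothetical names: `build_voxel_index` does the expensive Indexing once, `voxel_reduce_mean` is the cheap Scatter-Reduction that can be called repeatedly with the cached index, and `dynamic_scatter_mean` composes the two so existing callers keep a one-call API. This is not MMDet3D's dynamic_scatter signature; it only shows the structure.

```python
import torch

def build_voxel_index(coors: torch.Tensor):
    """Indexing: map (N, 3) integer voxel coords to compact ids (run once)."""
    voxel_coors, inverse = torch.unique(coors, dim=0, return_inverse=True)
    counts = torch.bincount(inverse, minlength=voxel_coors.shape[0])
    return voxel_coors, inverse, counts

def voxel_reduce_mean(feats, inverse, counts):
    """Scatter-Reduction: average (N, C) point features into (M, C) voxels."""
    sums = feats.new_zeros((counts.numel(), feats.shape[1])).index_add_(0, inverse, feats)
    return sums / counts.clamp(min=1).unsqueeze(1).to(feats.dtype)

def dynamic_scatter_mean(feats, coors):
    """Backward-compatible wrapper: Indexing + Scatter-Reduction in one call."""
    voxel_coors, inverse, counts = build_voxel_index(coors)
    return voxel_reduce_mean(feats, inverse, counts), voxel_coors
```

An encoder that needs several reductions would call `build_voxel_index` once and then `voxel_reduce_mean` (or other reductions) as many times as needed.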

If you have preliminarily tested your improved code, could you please create a PR and we can help review or further enhance it? If not, maybe we will take it as a TODO item to improve in the future.

ZwwWayne commented 2 years ago

The original implementation of dynamic voxelization follows a similar strategy to hard voxelization: it still builds a big feature map of shape (num_voxels, max_num_points_per_voxel, C) for feature averaging and summation. On nuScenes there are sometimes more than 1k points in a pillar, so creating this feature map causes OOM on a 32 GB V100. Honestly, I am not very satisfied with the implementation in the initial release, but we have not found a suitable time to optimize it. Thanks to @zhanggefan, we now have a better implementation.
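
A back-of-the-envelope comparison shows why the padded buffer runs out of memory: its size is fixed by the densest voxel, so a single 1k-point pillar forces every voxel's row to 1k slots, while a scatter-based reduction scales with the number of points. The numbers below (30k voxels, a 1k-point worst-case pillar, 64 float32 channels, 300k points) are purely illustrative assumptions, not measurements from nuScenes.

```python
def dense_buffer_bytes(num_voxels, max_points_per_voxel, channels, bytes_per_elem=4):
    """Size of a padded (num_voxels, max_points_per_voxel, C) feature buffer."""
    return num_voxels * max_points_per_voxel * channels * bytes_per_elem

def scatter_buffer_bytes(num_points, num_voxels, channels, bytes_per_elem=4):
    """Rough size with scatter: (N, C) point features plus (M, C) voxel output."""
    return (num_points + num_voxels) * channels * bytes_per_elem

# Illustrative assumptions only:
dense = dense_buffer_bytes(30_000, 1_000, 64)        # 7_680_000_000 bytes, ~7.7 GB
scatter = scatter_buffer_bytes(300_000, 30_000, 64)  # 84_480_000 bytes, ~84 MB
```

Intermediate activations of a voxel encoder would multiply the dense figure further, which is consistent with the OOM described above.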

Therefore, any optimization suggestions/PRs for the voxelization part are welcome. For example, if we find that a pure-torch version of dynamic voxelization has low memory cost and is faster yet simpler, we can consider it.

muzi2045 commented 2 years ago

When using dynamic voxelization, voxels or pillars around the vehicle will inevitably contain more points. Maybe we could try voxel downsampling or FPS (farthest point sampling) to extract a better point set to represent the features in those voxels or pillars. Of course, downsampling costs time depending on the leaf size, so it would be better to downsample only the nearby voxels.
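
For reference, a minimal greedy farthest point sampling routine in PyTorch. `farthest_point_sample` is a hypothetical helper; applying it only to voxels whose point count exceeds some threshold would be one way to follow the "only downsample nearby voxels" idea. A Python loop over many voxels would be slow at scale, so this only shows the sampling rule itself, at O(N * k) cost per call.

```python
import torch

def farthest_point_sample(xyz: torch.Tensor, k: int, start: int = 0) -> torch.Tensor:
    """Greedy FPS on one point set.

    xyz: (N, 3) float tensor. Returns (min(k, N),) long indices of the
    selected points, starting from index `start`.
    """
    n = xyz.shape[0]
    k = min(k, n)
    selected = torch.empty(k, dtype=torch.long, device=xyz.device)
    # Squared distance from every point to its nearest selected point so far.
    dist = torch.full((n,), float("inf"), dtype=xyz.dtype, device=xyz.device)
    idx = start
    for i in range(k):
        selected[i] = idx
        d = ((xyz - xyz[idx]) ** 2).sum(dim=1)
        dist = torch.minimum(dist, d)
        idx = int(torch.argmax(dist))   # farthest point from the selected set
    return selected
```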