mit-han-lab / torchsparse

[MICRO'23, MLSys'22] TorchSparse: Efficient Training and Inference Framework for Sparse Convolution on GPUs.
https://torchsparse.mit.edu
MIT License
1.16k stars 132 forks

[Performance] <Question about inference performance> #256

Closed galiyu closed 6 months ago

galiyu commented 8 months ago

Is there an existing issue for this?

Current Behavior

I tried to test a network and compared it with spconv. The (in_channel, out_channel) values of the network layers are (4,16), (16,32), (32,64), (64,128), (128,128), (128,256), (256,256). The test data has about 30,000 points, the voxel size is 0.1, and the batch size is 1.

The inference phase of the entire network takes about 15 ms. However, spconv runs in 2.6 ms even with denser data. This confuses me.
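For reference, a rough sketch of how a layer stack with those channel pairs might be written in TorchSparse (the kernel size of 3 and stride of 1 are my assumptions; they are not stated above and may differ from the benchmarked configuration):

import torch.nn as nn
import torchsparse.nn as spnn

# Channel pairs from the description above; kernel size 3 and stride 1
# (submanifold) are assumptions, not the exact benchmarked configuration.
channels = [(4, 16), (16, 32), (32, 64), (64, 128), (128, 128), (128, 256), (256, 256)]
layers = []
for cin, cout in channels:
    layers += [spnn.Conv3d(cin, cout, kernel_size=3, stride=1), spnn.ReLU(True)]
net = nn.Sequential(*layers).cuda().eval()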

Expected Behavior

No response

Environment

- GCC: 9.4.0
- NVCC: 11.8
- PyTorch: 2.0.0+cu118
- PyTorch CUDA: 11.8
- TorchSparse: 2.1.0

Anything else?

No response

jiweibo commented 8 months ago

I encountered a similar problem.

The network is as follows:

subm_sparse_conv3d -> subm_sparse_conv3d(block) -> nonsubm_sparse_conv3d -> subm_sparse_conv3d(block) -> nonsubm_sparse_conv3d -> subm_sparse_conv3d(block) -> nonsubm_sparse_conv3d -> subm_sparse_conv3d(block) -> nonsubm_sparse_conv3d.

The network code:

# Imports are not shown in the original snippet; the ones below are assumed.
# F is taken to be torchsparse.nn.functional, whose relu operates on SparseTensor.
import torch.nn as nn
from torchsparse import SparseTensor
import torchsparse.nn as spnn
import torchsparse.nn.functional as F

class SubmSpConvBlock(nn.Module):
    def __init__(
        self, in_channels, out_channels, kernel_size, stride, padding, dilation, bias
    ):
        super().__init__()
        self.sp1 = spnn.Conv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
            transposed=False,
            generative=False,
            config=None,
        )
        self.sp2 = spnn.Conv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
            transposed=False,
            generative=False,
            config=None,
        )
        self.sp3 = spnn.Conv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
            transposed=False,
            generative=False,
            config=None,
        )
        self.sp4 = spnn.Conv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
            transposed=False,
            generative=False,
            config=None,
        )

    def forward(self, x: SparseTensor):
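        # Two residual sub-blocks: y = x + conv(relu(conv(x))). As used below
        # (stride 1), the convolutions are submanifold, so the output shares
        # x's coordinates and the SparseTensor addition is well defined.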
        y = x + self.sp2(F.relu(self.sp1(x)))
        y = F.relu(y)
        y = y + self.sp4(F.relu(self.sp3(y)))
        return F.relu(y)

class SparseModelTest(nn.Module):
    def __init__(self):
        super().__init__()
        # TODO(wilber): tune config.
        self.sparse_head = nn.Sequential(
            spnn.Conv3d(
                in_channels=5,
                out_channels=16,
                kernel_size=(3, 3, 3),
                stride=(1, 1, 1),
                padding=(1, 1, 1),
                dilation=1,
                bias=True,
                transposed=False,
                generative=False,
                config=None,
            ),
            spnn.ReLU(),
        )
        self.block1 = SubmSpConvBlock(16, 16, 3, 1, 1, 1, True)
        self.sp1 = spnn.Conv3d(
            in_channels=16,
            out_channels=32,
            kernel_size=(3, 3, 3),
            stride=(2, 2, 2),
            padding=(1, 1, 1),
            dilation=(1, 1, 1),
            bias=True,
            transposed=False,
            generative=False,
            config=None,
        )
        self.block2 = SubmSpConvBlock(32, 32, 3, 1, 1, 1, True)
        self.sp2 = spnn.Conv3d(
            in_channels=32,
            out_channels=64,
            kernel_size=(3, 3, 3),
            stride=(2, 2, 2),
            padding=(1, 1, 1),
            dilation=(1, 1, 1),
            bias=True,
            transposed=False,
            generative=False,
            config=None,
        )
        self.block3 = SubmSpConvBlock(64, 64, 3, 1, 1, 1, True)
        self.sp3 = spnn.Conv3d(
            in_channels=64,
            out_channels=128,
            kernel_size=(3, 3, 3),
            stride=(2, 2, 2),
            padding=(1, 1, 1),
            dilation=(1, 1, 1),
            bias=True,
            transposed=False,
            generative=False,
            config=None,
        )
        self.block4 = SubmSpConvBlock(128, 128, 3, 1, 1, 1, True)
        self.sp4 = spnn.Conv3d(
            in_channels=128,
            out_channels=128,
            kernel_size=(3, 1, 1),
            stride=(2, 1, 1),
            padding=(0, 0, 0),
            dilation=(1, 1, 1),
            bias=True,
            transposed=False,
            generative=False,
            config=None,
        )

    def forward(self, x: SparseTensor):
        x = self.sparse_head(x)
        x = self.block1(x)
        x = F.relu(self.sp1(x))
        x = self.block2(x)
        x = F.relu(self.sp2(x))
        x = self.block3(x)
        x = F.relu(self.sp3(x))
        x = self.block4(x)
        x = F.relu(self.sp4(x))
        return x
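A minimal sketch of how this model might be fed for a standalone test. The random coordinates and the batch-index-first layout are assumptions; meaningful timings need real point-cloud coordinates, since the sparsity pattern strongly affects sparse convolution speed.

import torch
from torchsparse import SparseTensor

model = SparseModelTest().cuda().eval()

# Input shapes taken from the benchmark below: feats (260000, 5), coords (260000, 4).
# Note: random coords may contain duplicate voxels; real data should be voxelized first.
n = 260000
coords = torch.randint(0, 1000, (n, 4), dtype=torch.int32, device="cuda")
coords[:, 0] = 0              # assumed batch-index-first layout, batch size 1
feats = torch.randn(n, 5, device="cuda")

x = SparseTensor(feats, coords)
with torch.no_grad():
    out = model(x)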

perf (us)

The input feats shape is (260000, 5) and the coords shape is (260000, 4).

| layer                 | libspconv (us) | torchsparse++ (us) |
|-----------------------|----------------|--------------------|
| subm sparse_conv3d    | 597            | 897                |
| subm block1           | 1020.68        | 1056.772           |
| nonsubm sparse_conv3d | 1618.38        | 1627.844           |
| subm block2           | 2294.10        | 4908.721           |
| nonsubm sparse_conv3d | 1842.90        | 1790.951           |
| subm block3           | 2285.87        | 5470.211           |
| nonsubm sparse_conv3d | 1160.93        | 1401.474           |
| subm block4           | 3441.92        | 5821.803           |
| nonsubm sparse_conv3d | 757.03         | 515.016            |

The end-to-end time is 15 ms with libspconv but 24 ms with torchsparse++ (I tried tuning, but it had no effect). There is a big gap between this and the performance described in the paper.

Environment

- NVCC: 11.8
- PyTorch: 2.1.0+cu118
- TorchSparse commit id: afa2e3ba4be657a51f9112c8744592758ac06935

ys-2020 commented 8 months ago

Hi @jiweibo, could you please provide more information about your benchmark setup? For example, which GPU did you use? What is the coordinate distribution of your input data? And how did you measure the speed? Your observations are not consistent with the experimental results on our machines.

Therefore, we kindly ask you to provide more details about your benchmarking; a code snippet would be appreciated, so that we can investigate the problem together. Thanks!
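For reference, a timing loop along these lines avoids the common pitfalls (a sketch only: `model` and `x` stand for your network and a prepared SparseTensor input, and CUDA work is asynchronous, so synchronize before reading the timer):

import time
import torch

# Warm-up so lazy initialization and any autotuning do not pollute the timing.
for _ in range(10):
    with torch.no_grad():
        model(x)
torch.cuda.synchronize()

iters = 100
start = time.perf_counter()
for _ in range(iters):
    with torch.no_grad():
        model(x)
torch.cuda.synchronize()
print(f"mean latency: {(time.perf_counter() - start) / iters * 1e3:.3f} ms")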

jiweibo commented 8 months ago

> Hi @jiweibo, could you please provide more information about your benchmark setup? For example, which GPU did you use? What is the coordinate distribution of your input data? And how did you measure the speed? Your observations are not consistent with the experimental results on our machines.
>
> Therefore, we kindly ask you to provide more details about your benchmarking; a code snippet would be appreciated, so that we can investigate the problem together. Thanks!

Okay. The existing system is written in C++; I will write a Python test and provide it. Thanks.