traveller59 / spconv

Spatial Sparse Convolution Library
Apache License 2.0
1.89k stars 366 forks source link

Downsample error when building resnet with SubMConv3d #683

Open ZzTodd22 opened 9 months ago

ZzTodd22 commented 9 months ago

Here is my code:

def BasicConv3d(in_planes, out_planes, kernel_size, stride=(1, 1, 1), padding=(0, 0, 0),
                bias=False, convstyle='SubMConv3d'):
    """Build a sparse 3D convolution layer of the requested flavour.

    ``convstyle == 'SparseConv3d'`` selects ``spconv.SparseConv3d``; every
    other value (including the default ``'SubMConv3d'``) selects
    ``spconv.SubMConv3d``. ``in_planes``, ``out_planes``, ``kernel_size``,
    ``stride``, ``padding`` and ``bias`` are forwarded to the chosen
    constructor unchanged.
    """
    # Only the selected class attribute is looked up on ``spconv``.
    conv_cls = spconv.SparseConv3d if convstyle == 'SparseConv3d' else spconv.SubMConv3d
    return conv_cls(in_planes, out_planes, kernel_size=kernel_size,
                    stride=stride, padding=padding, bias=bias)
class Bottleneck(nn.Module):
    """Sparse residual bottleneck: 1x1x1 -> 3x3x3 -> 1x1x1 convolutions.

    ``conv1`` and ``conv3`` use ``BasicConv3d``'s default conv style, while
    ``conv2`` is built with ``convstyle='SparseConv3d'`` and receives the
    block's ``stride`` and ``padding``. Each convolution is followed by a
    ``BatchNorm1d`` applied to the sparse tensor's ``features``; the block
    output has ``planes * expansion`` channels.

    NOTE(review): the residual sum in ``forward`` adds ``out.features`` and
    the shortcut's ``features`` directly, so both branches presumably need the
    same active sites in the same order -- confirm for the chosen
    ``stride``/``padding``/``downsample`` combination.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, padding=0, downsample=None):
        super().__init__()
        unit = (1, 1, 1)
        zero = (0, 0, 0)
        widened = planes * self.expansion

        # Submodules are created in a fixed order: conv/bn/relu for stages 1-3.
        self.conv1 = BasicConv3d(inplanes, planes, kernel_size=unit,
                                 stride=unit, padding=zero, bias=False)
        self.bn1 = nn.BatchNorm1d(planes)
        self.relu1 = nn.ReLU(inplace=True)

        # The middle stage is the only one that takes the caller's stride/padding.
        self.conv2 = BasicConv3d(planes, planes, kernel_size=(3, 3, 3),
                                 stride=stride, padding=padding, bias=False,
                                 convstyle='SparseConv3d')
        self.bn2 = nn.BatchNorm1d(planes)
        self.relu2 = nn.ReLU(inplace=True)

        self.conv3 = BasicConv3d(planes, widened, kernel_size=unit,
                                 stride=unit, padding=zero, bias=False)
        self.bn3 = nn.BatchNorm1d(widened)
        self.relu3 = nn.ReLU(inplace=True)

        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Run the three conv stages, add the shortcut, and apply the final ReLU.

        The shortcut is ``x`` itself, or ``self.downsample(x)`` when a
        ``downsample`` module was supplied.
        """

        def on_features(sp_tensor, layer):
            # Apply a dense layer to the feature matrix of a sparse tensor.
            return sp_tensor.replace_feature(layer(sp_tensor.features))

        out = on_features(self.conv1(x), self.bn1)
        out = on_features(out, self.relu1)

        out = on_features(self.conv2(out), self.bn2)
        out = on_features(out, self.relu2)

        out = on_features(self.conv3(out), self.bn3)

        shortcut = x if self.downsample is None else self.downsample(x)
        out = out.replace_feature(out.features + shortcut.features)
        return on_features(out, self.relu3)

where downsample is defined as follows:

            # Shortcut branch used for the residual add: a 1x1x1 convolution
            # built with convstyle='SparseConv3d', followed by BatchNorm1d, with
            # planes * block.expansion output channels.
            # The stride is (self.t_s if stride == 2 else 1, stride, stride);
            # padding and bias are left at BasicConv3d's defaults.
            # NOTE(review): this uses a different kernel size than the block's
            # 3x3x3 conv2, so the two branches may not produce the same output
            # indices -- confirm before adding their features.
            downsample = spconv.SparseSequential(
                BasicConv3d(self.inplanes, planes * block.expansion, kernel_size=(1, 1, 1),
                            stride=(self.t_s if stride == 2 else 1, stride, stride), convstyle='SparseConv3d'),
                nn.BatchNorm1d(planes * block.expansion),
            )

This leads to a dilemma: if SubMConv3d is used, downsampling is not possible; but if SparseConv3d is used for downsampling (i.e., conv2 uses SparseConv3d), there is no guarantee that the output indices of the main branch stay the same as those of the downsample branch.