CVMI-Lab / SPS-Conv

(NeurIPS 2022) Spatial Pruned Sparse Convolution for Efficient 3D Object Detection
Apache License 2.0

How to calculate the FLOPs #2

Closed. LinyeLi60 closed this issue 1 year ago

LinyeLi60 commented 1 year ago

Very nice work. I want to know how to get the FLOPs of the model, so I wrote some code based on calculate_gemm_flops. Can you tell me if my code is correct? Thank you!

 def forward(self, batch_dict):
        """
        Args:
            batch_dict:
                batch_size: int
                vfe_features: (num_voxels, C)
                voxel_coords: (num_voxels, 4), [batch_idx, z_idx, y_idx, x_idx]
        Returns:
            batch_dict:
                encoded_spconv_tensor: sparse tensor
        """
        voxel_features, voxel_coords = batch_dict['voxel_features'], batch_dict['voxel_coords']
        batch_size = batch_dict['batch_size']
        input_sp_tensor = spconv.SparseConvTensor(
            features=voxel_features,
            indices=voxel_coords.int(),
            spatial_shape=self.sparse_shape,
            batch_size=batch_size
        )

        x = self.conv_input(input_sp_tensor)

        x_conv1 = self.conv1(x)
        x_conv2 = self.conv2(x_conv1)
        x_conv3 = self.conv3(x_conv2)
        x_conv4 = self.conv4(x_conv3)

        # for detection head
        # [200, 176, 5] -> [200, 176, 2]
        out = self.conv_out(x_conv4)
        in_out_channels = [[(4, 16), ], [(16, 16), ], [(16, 32), (32, 32), (32, 32)], [(32, 64), (64, 64), (64, 64)],
                           [(64, 64), (64, 64), (64, 64)], [(64, 128), ]]
        indice_keys = [['subm1'], ['subm1'], ['spconv2', 'subm2', 'subm2'], ['spconv3', 'subm3', 'subm3'],
                       ['spconv4', 'subm4', 'subm4'], ['spconv_down2']]
        backbone3d_flops = 0
        for out_tensor, in_out_channel_list, indice_keys_list in zip([x, x_conv1, x_conv2, x_conv3, x_conv4, out],
                                                                     in_out_channels, indice_keys):
            for (inchannel, outchannel), indice_key in zip(in_out_channel_list, indice_keys_list):
                backbone3d_flops += calculate_gemm_flops(out_tensor, indice_key, inchannel, outchannel)

        batch_dict['backbone3d_flops'] = backbone3d_flops
        batch_dict.update({
            'encoded_spconv_tensor': out,
            'encoded_spconv_tensor_stride': 8
        })
        batch_dict.update({
            'multi_scale_3d_features': {
                'x_conv1': x_conv1,
                'x_conv2': x_conv2,
                'x_conv3': x_conv3,
                'x_conv4': x_conv4,
            }
        })
        batch_dict.update({
            'multi_scale_3d_strides': {
                'x_conv1': 1,
                'x_conv2': 2,
                'x_conv3': 4,
                'x_conv4': 8,
            }
        })

        return batch_dict
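
For reference, a minimal sketch of what I assume calculate_gemm_flops does, based on the 2 * (active pairs) * C_in * C_out counting used elsewhere in this thread (the actual implementation in the repo may differ):

def calculate_gemm_flops(x, indice_key, in_channels, out_channels):
    # sketch only: count the valid gather/scatter pairs recorded for this layer
    # indice_pairs has shape [2, k**3, N]; an entry of -1 marks an empty pair
    pair_indices_out = x.indice_dict[indice_key].indice_pairs[1]
    num_valid_pairs = (pair_indices_out > -1).sum()
    # one multiply and one add per valid pair per (in_channel, out_channel) combination
    # (the helpers in the reply below also subtract pair_indices_out.shape[1]; omitted here)
    return 2 * num_valid_pairs * in_channels * out_channels
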
Eddieeee-Liu commented 1 year ago

Sorry for the late reply. The function "calculate_gemm_flops" is used to calculate the FLOPs of the vanilla version of spconv. The code for calculating the FLOPs of SPRS is:

import copy
import torch

def calculated_sprs_flops(x, indice_key, value_mask, in_channels, out_channels, stride, padding, dilation, kernel_size):
    indices = x.indices
    pair_indices = copy.deepcopy(x.indice_dict[indice_key].indice_pairs)
    conv_valid_mask = ((indices[:,1:] % 2).sum(1)==0)  # voxels whose z/y/x coordinates are all even

    pre_spatial_shape = x.spatial_shape
    new_spatial_shape = []
    for i in range(3):
        size = (pre_spatial_shape[i] + 2 * padding[i] - dilation *
                (kernel_size - 1) - 1) // stride + 1
        if kernel_size == -1:
            new_spatial_shape.append(1)
        else:
            new_spatial_shape.append(size)

    coords = indices[:,1:][conv_valid_mask]
    spatial_indices = (coords[:, 0] >0) * (coords[:, 1] >0) * (coords[:, 2] >0)  * \
        (coords[:, 0] < new_spatial_shape[0]) * (coords[:, 1] < new_spatial_shape[1]) * (coords[:, 2] < new_spatial_shape[2])

    pair_indices_in = pair_indices[0] # [k**3, N]
    pair_indices_out = pair_indices[1] # [k**3, N]
    if conv_valid_mask.dtype == torch.bool:
        valid_position = torch.nonzero(conv_valid_mask).view(-1,)
    value_mask = torch.isin(pair_indices_in, value_mask)
    valid_position = valid_position[spatial_indices]
    mask = torch.isin(pair_indices_out, valid_position)
    mask = mask * value_mask
    mask = ~mask
    pair_indices_out[mask] = -1
    cur_flops = 2 * (pair_indices_out > -1).sum() * in_channels * out_channels - pair_indices_out.shape[1]
    return cur_flops

How to get the value_mask (SPRS):

out = self.combine_feature(x_im, x_nim, remove_repeat=True)
value_mask = None
if not self.training:
    value_mask = (out.features.sum(-1)!=0)
    value_mask = torch.nonzero(value_mask).view(-1,)

The code for calculating the FLOPs of SPSS:

def calculate_spss_flops(x, indice_key, in_channels, out_channels, mask_position):
    if mask_position.dtype == torch.bool:
        mask_position = torch.nonzero(mask_position).view(-1,)
    pair_indices = copy.deepcopy(x.indice_dict[indice_key].indice_pairs)

    pair_indices_out = pair_indices[1] # [k**3, N]
    mask = torch.isin(pair_indices_out, mask_position)  # pairs whose output index is in mask_position
    pair_indices_out[mask] = -1                         # exclude those pairs from the count
    cur_flops = 2 * (pair_indices_out > -1).sum() * in_channels * out_channels - pair_indices_out.shape[1]
    return cur_flops
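
Note: both helpers rely on torch.isin, which was added in PyTorch 1.10. On an older PyTorch, a broadcast comparison can serve as a drop-in replacement for the pair-index tensors used here (more memory-hungry, since it materializes a [k**3, N, M] boolean tensor):

def isin_fallback(elements, test_elements):
    # elements: e.g. pair_indices_out with shape [k**3, N]; test_elements: 1-D index tensor of length M
    return (elements.unsqueeze(-1) == test_elements).any(-1)
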
LinyeLi60 commented 1 year ago

Thank you for your detailed reply! I will close this issue.

LinyeLi60 commented 1 year ago

For the model trained with second_spss_ratio0.5_sprs_ratio0.5.yaml, I get a different GFLOPs number. Here is the code I used. For SpatialPrunedConvDownsample:

x_im, x_nim = self.gemerate_sparse_tensor(x, voxel_importance)
out = self.combine_feature(x_im, x_nim, remove_repeat=True)
value_mask = None
out = self.conv_block(out)
if not self.training:
    value_mask = (out.features.sum(-1) != 0)
    value_mask = torch.nonzero(value_mask).view(-1, )
    batch_dict['backbone3d_flops'] += calculated_sprs_flops(out, self.indice_key, value_mask,
                                                            self.in_channels,
                                                            self.out_channels,
                                                            self.stride,
                                                            self.padding,
                                                            self.dilation,
                                                            self.kernel_size)
pair_indices = None
out = self.reset_spatial_shape(out, batch_dict, pair_indices, value_mask)

For SpatialPrunedSubmConvBlock:

mask_position = self.get_importance_mask(x, voxel_importance)

x = x.replace_feature(x.features * voxel_importance)
x_nim = x
x_im = self.conv_block(x)
if not self.training:
    batch_dict['backbone3d_flops'] += calculate_spss_flops(x_im,
                                                           self.indice_key,
                                                           self.in_channels,
                                                           self.out_channels,
                                                           mask_position)

out = self._combine_feature(x_im, x_nim, mask_position)

And I add the FLOPs of conv_input and conv_out:

batch_dict['backbone3d_flops'] += calculate_gemm_flops(x, 'subm1', 4, 16)
batch_dict['backbone3d_flops'] += calculate_gemm_flops(out, 'spconv_down2', 64, 128)

Finally, I get 2.74 GFLOPs on the KITTI dataset.
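
A minimal sketch of the conversion I assume from the accumulated count to GFLOPs (backbone3d_flops ends up as a 0-d torch tensor after the sums above):

gflops = float(batch_dict['backbone3d_flops']) / 1e9  # FLOPs -> GFLOPs
print(f'{gflops:.2f} GFLOPs')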

Eddieeee-Liu commented 1 year ago

Actually, we set the downsample pruning rates to {0.7, 0.5, 0.3} for KITTI when we submitted our paper (as mentioned in the paper), so this would cause the difference in GFLOPs.

LinyeLi60 commented 1 year ago

All right. Thank you!