Closed LinyeLi60 closed 1 year ago
Sorry for the late reply. The function "calculate_gemm_flops" calculates the FLOPs of the vanilla version of spconv. The code for calculating the FLOPs of SPRS:
def calculated_sprs_flops(x, indice_key, value_mask, in_channels, out_channels, stride, padding, dilation, kernel_size):
    indices = x.indices
    pair_indices = copy.deepcopy(x.indice_dict[indice_key].indice_pairs)
    conv_valid_mask = ((indices[:,1:] % 2).sum(1)==0)
    pre_spatial_shape = x.spatial_shape
    new_spatial_shape = []
    for i in range(3):
        size = (pre_spatial_shape[i] + 2 * padding[i] - dilation *
                (kernel_size - 1) - 1) // stride + 1
        if kernel_size == -1:
            new_spatial_shape.append(1)
        else:
            new_spatial_shape.append(size)
    coords = indices[:,1:][conv_valid_mask]
    spatial_indices = (coords[:, 0] >0) * (coords[:, 1] >0) * (coords[:, 2] >0) * \
        (coords[:, 0] < new_spatial_shape[0]) * (coords[:, 1] < new_spatial_shape[1]) * (coords[:, 2] < new_spatial_shape[2])
    pair_indices_in = pair_indices[0] # [k**3, N]
    pair_indices_out = pair_indices[1] # [k**3, N]
    if conv_valid_mask.dtype == torch.bool:
        valid_position = torch.nonzero(conv_valid_mask).view(-1,)
    value_mask = torch.isin(pair_indices_in, value_mask)
    valid_position = valid_position[spatial_indices]
    mask = torch.isin(pair_indices_out, valid_position)
    mask = mask * value_mask
    mask = ~mask
    pair_indices_out[mask] = -1
    cur_flops = 2 * (pair_indices_out > -1).sum() * in_channels * out_channels - pair_indices_out.shape[1]
    return cur_flops
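The loop that builds new_spatial_shape applies the usual convolution output-size formula along each axis. A minimal standalone sketch of that arithmetic, assuming scalar kernel size, stride, padding and dilation; the helper name conv_output_size and the example grid are illustrative and do not come from the repository:

```python
def conv_output_size(in_size, kernel_size, stride, padding, dilation=1):
    """Output length along one axis of a strided convolution."""
    return (in_size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# Illustrative input grid (an assumption, not taken from the thread).
spatial_shape = [41, 1600, 1408]

# A 3x3x3 kernel with stride 2 and padding 1 roughly halves every axis.
new_shape = [
    conv_output_size(s, kernel_size=3, stride=2, padding=1)
    for s in spatial_shape
]
print(new_shape)  # [21, 800, 704]
```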
The way of getting the value_mask (SPRS):
out = self.combine_feature(x_im, x_nim, remove_repeat=True)
value_mask = None
if not self.training:
    value_mask = (out.features.sum(-1)!=0)
    value_mask = torch.nonzero(value_mask).view(-1,)
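To make the value_mask step concrete, here is a small pure-Python sketch of the same idea: keep the indices of voxels whose feature row sums to a non-zero value. The helper name nonzero_row_indices and the toy features are assumptions for illustration only. Note that a row whose entries cancel exactly is also treated as empty under this test.

```python
def nonzero_row_indices(features):
    """Indices of rows whose channel sum is non-zero."""
    return [i for i, row in enumerate(features) if sum(row) != 0]


# Toy features for four voxels with two channels each.
features = [
    [0.0, 0.0],   # pruned voxel: all zeros
    [0.5, 1.0],   # kept
    [1.0, -1.0],  # non-zero entries, but the sum cancels to zero
    [0.0, 2.0],   # kept
]

value_mask = nonzero_row_indices(features)
print(value_mask)  # [1, 3]
```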
The code for calculating the flops of SPSS:
def calculate_spss_flops(x, indice_key, in_channels, out_channels, mask_position):
    if mask_position.dtype == torch.bool:
        mask_position = torch.nonzero(mask_position).view(-1,)
    pair_indices = copy.deepcopy(x.indice_dict[indice_key].indice_pairs)
    pair_indices_out = pair_indices[1] # [k**3, N]
    mask = torch.isin(pair_indices_out, mask_position)
    pair_indices_out[mask] = -1
    cur_flops = 2 * (pair_indices_out > -1).sum() * in_channels * out_channels - pair_indices_out.shape[1]
    return cur_flops
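Both functions end with the same counting step: every surviving entry in the output pair table costs one multiply-accumulate per input/output channel pair. The sketch below reproduces that expression on plain Python lists, assuming a table with one row per kernel offset and -1 marking "no pair"; the name count_pair_flops and the toy table are hypothetical.

```python
def count_pair_flops(pair_indices_out, in_channels, out_channels, masked_outputs=()):
    """Mirror of the thread's expression: 2 * valid_pairs * C_in * C_out - N.

    pair_indices_out: K rows (one per kernel offset) of N output indices,
    with -1 meaning "no pair". Outputs in masked_outputs are skipped,
    which is what setting them to -1 does in calculate_spss_flops.
    """
    masked = set(masked_outputs)
    valid_pairs = sum(
        1
        for row in pair_indices_out
        for out_idx in row
        if out_idx > -1 and out_idx not in masked
    )
    num_columns = len(pair_indices_out[0])
    return 2 * valid_pairs * in_channels * out_channels - num_columns


# Toy table: 3 kernel offsets, 4 columns.
pairs = [
    [0, 1, 2, -1],
    [1, 2, -1, -1],
    [0, -1, -1, -1],
]

full = count_pair_flops(pairs, in_channels=4, out_channels=16)
pruned = count_pair_flops(pairs, in_channels=4, out_channels=16, masked_outputs=[2])
print(full, pruned)  # 764 508
```

Masking output 2 removes two of the six pairs, so the count drops by 2 * 2 * 4 * 16 = 256.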
Thank you for your detailed reply! I will close this issue.
For the model trained with second_spss_ratio0.5_sprs_ratio0.5.yaml, I get different GFLOPs. Here is the code I used. For SpatialPrunedConvDownsample:
x_im, x_nim = self.gemerate_sparse_tensor(x, voxel_importance)
out = self.combine_feature(x_im, x_nim, remove_repeat=True)
value_mask = None
out = self.conv_block(out)
if not self.training:
    value_mask = (out.features.sum(-1) != 0)
    value_mask = torch.nonzero(value_mask).view(-1, )
    batch_dict['backbone3d_flops'] += calculated_sprs_flops(out, self.indice_key, value_mask,
                                                            self.in_channels,
                                                            self.out_channels,
                                                            self.stride,
                                                            self.padding,
                                                            self.dilation,
                                                            self.kernel_size)
pair_indices = None
out = self.reset_spatial_shape(out, batch_dict, pair_indices, value_mask)
For SpatialPrunedSubmConvBlock:
mask_position = self.get_importance_mask(x, voxel_importance)
x = x.replace_feature(x.features * voxel_importance)
x_nim = x
x_im = self.conv_block(x)
if not self.training:
    batch_dict['backbone3d_flops'] += calculate_spss_flops(x_im,
                                                           self.indice_key,
                                                           self.in_channels,
                                                           self.out_channels,
                                                           mask_position)
out = self._combine_feature(x_im, x_nim, mask_position)
I also add the FLOPs of conv_input and conv_out:
batch_dict['backbone3d_flops'] += calculate_gemm_flops(x, 'subm1', 4, 16)
batch_dict['backbone3d_flops'] += calculate_gemm_flops(out, 'spconv_down2', 64, 128)
Finally, I get 2.74 GFLOPs on the KITTI dataset.
Actually, we set the downsample pruning rates to {0.7, 0.5, 0.3} for KITTI when we submitted the paper, as mentioned in the paper, so this would cause the difference in GFLOPs.
All right. Thank you!
Very nice work. I want to know how to get the FLOPs of the model, so I wrote some code based on calculate_gemm_flops. Can you tell me if my code is correct? Thank you!