facebookresearch / SparseConvNet

Submanifold sparse convolutional networks
https://github.com/facebookresearch/SparseConvNet
Other
2.04k stars 332 forks source link

Rewrite for convolution operation #241

Open CheungBH opened 1 year ago

CheungBH commented 1 year ago

Thanks for your great work. I want to run the sparse convolution operation in inference-only mode, but the code doesn't provide such a function, so I am rewriting it as shown below. When only_forward=True, the input is processed directly, without storing anything on ctx. However, I found that the output contains NaN values when I use this approach. Do you have any ideas on how to solve this?

class ConvolutionFunction(Function):
    """Autograd function wrapping the ``sparseconvnet.SCN`` convolution kernels.

    ``forward`` calls ``SCN.Convolution_updateOutput`` and ``backward`` calls
    ``SCN.Convolution_backward``.  Passing ``only_forward=True`` runs the
    forward kernel alone: nothing is stored on ``ctx`` and the module-level
    ``forward_pass_*`` counters are left untouched.
    """

    @staticmethod
    def forward(
            ctx,
            input_features,
            weight,
            bias,
            input_metadata,
            input_spatial_size,
            output_spatial_size,
            dimension,
            filter_size,
            filter_stride,
            only_forward=False):
        """Run the convolution kernel and return the output features.

        Args:
            ctx: autograd context; receives ``input_metadata``, ``dimension``
                and the tensors needed by ``backward`` (training path only).
            input_features: features handed to the kernel as its input.
            weight: convolution weights.
            bias: convolution bias.
            input_metadata: metadata object forwarded unchanged to the kernel.
            input_spatial_size: forwarded to the kernel.
            output_spatial_size: forwarded to the kernel.
            dimension: stored on ``ctx``; not passed to the kernel here.
            filter_size: forwarded to the kernel.
            filter_stride: forwarded to the kernel.
            only_forward: when true, skip all ``ctx`` bookkeeping and the
                global statistics counters.

        Returns:
            The ``output_features`` tensor that was passed to the kernel.
        """
        # Empty tensor created from input_features; it is handed to the kernel
        # as its output argument and returned afterwards, so the kernel is
        # expected to fill it in place.
        output_features = input_features.new()

        if only_forward:
            sparseconvnet.SCN.Convolution_updateOutput(
                input_spatial_size,
                output_spatial_size,
                filter_size,
                filter_stride,
                input_metadata,
                input_features,
                output_features,
                weight,
                bias)
            # Nothing was saved for backward(), so a gradient pass through
            # this output could not succeed.  Tell autograd the output is
            # non-differentiable.  The guard keeps a direct call such as
            # ``ConvolutionFunction.forward(None, ...)`` working.
            if ctx is not None:
                ctx.mark_non_differentiable(output_features)
            return output_features

        # Training path: remember everything backward() reads.
        ctx.input_metadata = input_metadata
        ctx.dimension = dimension
        ctx.save_for_backward(
            input_features,
            input_spatial_size,
            weight,
            bias,
            output_spatial_size,
            filter_size,
            filter_stride)
        # The kernel's return value is accumulated into the module-level
        # multiply-add counter.
        sparseconvnet.forward_pass_multiplyAdd_count += \
            sparseconvnet.SCN.Convolution_updateOutput(
                input_spatial_size,
                output_spatial_size,
                filter_size,
                filter_stride,
                input_metadata,
                input_features,
                output_features,
                weight,
                bias)
        sparseconvnet.forward_pass_hidden_states += output_features.nelement()
        return output_features

    @staticmethod
    def backward(ctx, grad_output):
        """Compute gradients for ``input_features``, ``weight`` and ``bias``.

        Args:
            ctx: autograd context populated by the training path of
                ``forward``.
            grad_output: gradient with respect to the forward output.

        Returns:
            A 10-tuple with one entry per ``forward`` input after ``ctx``:
            gradients for ``input_features``, ``weight`` and ``bias``,
            followed by ``None`` for the seven remaining inputs.
        """
        (input_features,
         input_spatial_size,
         weight,
         bias,
         output_spatial_size,
         filter_size,
         filter_stride) = ctx.saved_tensors
        # grad_input starts empty and is handed to the kernel together with
        # the zero-initialised weight/bias gradient buffers.
        grad_input = grad_output.new()
        grad_weight = torch.zeros_like(weight)
        grad_bias = torch.zeros_like(bias)
        sparseconvnet.SCN.Convolution_backward(
            input_spatial_size,
            output_spatial_size,
            filter_size,
            filter_stride,
            ctx.input_metadata,
            input_features,
            grad_input,
            grad_output.contiguous(),
            weight,
            grad_weight,
            grad_bias)
        # forward() takes 10 inputs after ctx (the last one is only_forward),
        # so 10 gradient slots are returned.  Callers that omit only_forward
        # are unaffected: autograd drops surplus trailing None entries.
        return (
            grad_input,
            grad_weight,
            optionalTensorReturn(grad_bias),
            None,  # input_metadata
            None,  # input_spatial_size
            None,  # output_spatial_size
            None,  # dimension
            None,  # filter_size
            None,  # filter_stride
            None,  # only_forward
        )