NVIDIA / MinkowskiEngine

Minkowski Engine is an auto-diff neural network library for high-dimensional sparse tensors
https://nvidia.github.io/MinkowskiEngine

I wonder if there are corresponding implementations of PyTorch's view(), torch.bmm and torch.nn.Parameter() in MinkowskiEngine #558

Open d289760860 opened 1 year ago

d289760860 commented 1 year ago

I am trying to add an attention module to a ResNet. The original PyTorch code looks like this:

import torch
import torch.nn as nn


class NonLocalModule(nn.Module):
    def __init__(self, C, latent=8):
        super(NonLocalModule, self).__init__()
        self.inputChannel = C
        self.latentChannel = C // latent

        self.bn1 = nn.BatchNorm1d(C//latent)
        self.bn2 = nn.BatchNorm1d(C//latent)
        self.bn3 = nn.BatchNorm1d(C//latent)
        self.bn4 = nn.BatchNorm1d(C)

        self.cov1 = nn.Sequential(nn.Conv1d(in_channels=C, out_channels=C//latent, kernel_size=1, bias=False),
                                self.bn1,
                                nn.ReLU())
        self.cov2 = nn.Sequential(nn.Conv1d(in_channels=C, out_channels=C//latent, kernel_size=1, bias=False),
                                self.bn2,
                                nn.ReLU())
        self.cov3 = nn.Sequential(nn.Conv1d(in_channels=C, out_channels=C//latent, kernel_size=1, bias=False),
                                self.bn3,
                                nn.ReLU())
        self.out_conv = nn.Sequential(nn.Conv1d(in_channels=C//latent, out_channels=C, kernel_size=1, bias=False),
                                self.bn4,
                                nn.ReLU())

        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        b, c, n = x.shape

        out1 = self.cov1(x).view(b, -1, n).permute(0, 2, 1) #b,n,c/latent
        out2 = self.cov2(x).view(b, -1, n) #b, c/latent, n

        attention_matrix = self.softmax(torch.bmm(out1, out2)) # b,n,n

        out3 = self.cov3(x).view(b, -1, n) # b,c/latent,n

        attention = torch.bmm(out3, attention_matrix.permute(0, 2, 1)) # b,c/latent,n

        out = self.out_conv(attention) #b,c,n

        return self.gamma*out + x
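Tracing the tensor shapes through this forward pass helps pin down which calls actually need a sparse counterpart. The sketch below uses arbitrary, assumed sizes and reuses a single `Conv1d` for the query, key and value projections to stay short. It suggests that `.view(b, -1, n)` leaves the `Conv1d` output unchanged (it is already `(b, C // latent, n)`), so the dense-specific work is the two batched `torch.bmm` calls.

```python
import torch
import torch.nn as nn

# Illustrative sizes (assumed): batch b, channels C, points n, latent factor
b, C, n, latent = 2, 64, 100, 8
x = torch.randn(b, C, n)

# One 1x1 Conv1d stands in for cov1/cov2/cov3 to keep the sketch short
conv = nn.Conv1d(C, C // latent, kernel_size=1, bias=False)
feat = conv(x)                                   # (b, C // latent, n)
same = torch.equal(feat.view(b, -1, n), feat)    # view() changes nothing here

q = feat.permute(0, 2, 1)                        # (b, n, C // latent)
k = feat                                         # (b, C // latent, n)
attn = torch.softmax(torch.bmm(q, k), dim=-1)    # (b, n, n), each row sums to 1
out = torch.bmm(feat, attn.permute(0, 2, 1))     # (b, C // latent, n)
```

The dense module relies on every sample having the same number of points `n`; that is exactly the assumption a sparse tensor breaks.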

nn.BatchNorm1d, nn.Conv1d and nn.ReLU have corresponding implementations in MinkowskiEngine. However, I cannot find the counterparts of PyTorch's view(), torch.bmm and torch.nn.Parameter(). How should I convert this code to a MinkowskiEngine version?
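A minimal sketch of one possible approach, assuming you operate directly on the flat features of a sparse tensor with ordinary PyTorch ops. In MinkowskiEngine a `SparseTensor` stores all points of the batch as one `(N, C)` feature matrix (`x.F`), and the batch index is a column of the coordinate matrix `x.C` (the first column in MinkowskiEngine 0.5), so there is no fixed `n` to `view()` into. Here the 1x1 convolutions are replaced by per-point `nn.Linear` layers (a substitution for illustration; MinkowskiEngine's own 1x1 convolution or linear layers could play the same role), attention is computed per sample with `@` instead of `torch.bmm`, and `nn.Parameter` is used unchanged, since MinkowskiEngine layers are ordinary `torch.nn.Module` subclasses. `SparseNonLocal` and `point_mlp` are hypothetical names. To hand the result back to MinkowskiEngine layers, it would have to be wrapped in a new `ME.SparseTensor` that reuses the input's coordinates (in 0.5, by passing the input's `coordinate_map_key` and `coordinate_manager`); check the constructor of your installed version.

```python
import torch
import torch.nn as nn


def point_mlp(c_in, c_out):
    # A 1x1 Conv1d over points equals a per-point Linear on an (N, C) matrix
    return nn.Sequential(nn.Linear(c_in, c_out, bias=False),
                         nn.BatchNorm1d(c_out), nn.ReLU())


class SparseNonLocal(nn.Module):
    """Non-local attention on flat sparse features (hypothetical helper).

    feats:     (N, C) features of every point in the batch (like SparseTensor.F)
    batch_idx: (N,) sample index of each point (like the batch column of .C)
    """

    def __init__(self, C, latent=8):
        super().__init__()
        self.q = point_mlp(C, C // latent)
        self.k = point_mlp(C, C // latent)
        self.v = point_mlp(C, C // latent)
        self.out = point_mlp(C // latent, C)
        # nn.Parameter is plain PyTorch and works inside any nn.Module
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, feats, batch_idx):
        q, k, v = self.q(feats), self.k(feats), self.v(feats)
        attended = torch.empty_like(v)
        # Samples have different point counts, so attend within each sample
        # instead of reshaping with view() and calling torch.bmm
        for s in torch.unique(batch_idx):
            m = batch_idx == s
            attn = torch.softmax(q[m] @ k[m].T, dim=-1)  # (n_s, n_s)
            attended[m] = attn @ v[m]                    # (n_s, C // latent)
        return self.gamma * self.out(attended) + feats
```

Like the dense module, each sample still builds an `n_s x n_s` attention matrix, so memory grows quadratically with the number of points per sample.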

d289760860 commented 1 year ago

I have searched the documentation and could not find any corresponding function.
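Because sparse samples have different point counts, `view(b, -1, n)` cannot be applied directly to flat per-point features. One workaround that keeps a batched `torch.bmm`, sketched here under that assumption, is to pad every sample to the largest point count in the batch and mask the padded keys before the softmax. `padded_attention` is a hypothetical helper; `q`, `k`, `v` and `batch_idx` are assumed to be flat per-point tensors, such as projections of a sparse tensor's features and its batch indices. Only standard PyTorch calls are used (`pad_sequence`, `masked_fill`, `torch.bmm`).

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def padded_attention(q, k, v, batch_idx):
    """Batched attention over variable-size samples via padding + torch.bmm.

    q, k, v: (N, C') flat per-point features; batch_idx: (N,) sample ids.
    Returns (N, C') attended features in the original point order.
    """
    masks = [batch_idx == s for s in torch.unique(batch_idx)]
    counts = [int(m.sum()) for m in masks]
    # Pad each sample to n_max points -> (b, n_max, C')
    Q = pad_sequence([q[m] for m in masks], batch_first=True)
    K = pad_sequence([k[m] for m in masks], batch_first=True)
    V = pad_sequence([v[m] for m in masks], batch_first=True)
    n_max = Q.shape[1]
    lengths = torch.tensor(counts, device=q.device)
    # key_valid[i, j] is True where column j is a real (non-padded) point
    key_valid = torch.arange(n_max, device=q.device)[None, :] < lengths[:, None]
    scores = torch.bmm(Q, K.transpose(1, 2))                # (b, n_max, n_max)
    scores = scores.masked_fill(~key_valid[:, None, :], float("-inf"))
    attn = torch.softmax(scores, dim=-1)                    # padded keys get 0
    out = torch.bmm(attn, V)                                # (b, n_max, C')
    # Drop padded rows and scatter results back to the flat point order
    result = torch.empty_like(v)
    for i, m in enumerate(masks):
        result[m] = out[i, :counts[i]]
    return result
```

Padding trades memory for a single batched call: the attention tensor is `b x n_max x n_max`, which can waste a lot of space when sample sizes differ widely.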