roytseng-tw / Detectron.pytorch

A pytorch implementation of Detectron. Both training from scratch and inferring directly from pretrained Detectron weights are available.
MIT License

Add custom module to object detection pipeline and train jointly #214

Open: ashnair1 opened this issue 5 years ago

ashnair1 commented 5 years ago

I want to add a non-local block to the Mask R-CNN pipeline. I've already written it as a standalone module, but I'm not sure how to include it in the Detectron.pytorch codebase so that the block and the rest of the model are trained jointly. Could someone point me in the right direction? The block only adds four 1x1 convolutional layers (theta, phi, g and conv) that need to be trained, but I'm unclear on how to hook them into training.
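A minimal sketch of one common way to get joint training in plain PyTorch, under stated assumptions: the block is assigned as an attribute in the `__init__` of a parent `nn.Module`, which registers it as a submodule, so a single optimizer built from `model.parameters()` updates it together with the backbone. In Detectron.pytorch this would presumably be whichever backbone (conv body) class your config selects; the exact file is not confirmed here. `ExtraBlock` and `BackboneWithExtraBlock` are hypothetical names, the toy `nn.Sequential` stands in for the real backbone, and the squared-output loss is only a placeholder for the detection losses.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ExtraBlock(nn.Module):
    """Hypothetical stand-in for the non-local block: 1x1 conv + residual."""

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, x):
        return x + F.relu(self.conv(x))


class BackboneWithExtraBlock(nn.Module):
    """Hypothetical wrapper: runs a backbone, then refines its output."""

    def __init__(self, backbone, out_channels):
        super().__init__()
        self.backbone = backbone
        # Assigning the block as an attribute in __init__ registers it as a
        # submodule: its weights show up in parameters() and state_dict(),
        # and they move with .to(device).
        self.extra = ExtraBlock(out_channels)

    def forward(self, images):
        feats = self.backbone(images)
        return self.extra(feats)


# Toy backbone standing in for the real conv body (assumption).
backbone = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU())
model = BackboneWithExtraBlock(backbone, out_channels=8)

# One optimizer over model.parameters() covers backbone and block together.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

before = model.extra.conv.weight.detach().clone()
images = torch.randn(2, 3, 16, 16)
loss = model(images).pow(2).mean()  # placeholder loss for the demo
optimizer.zero_grad()
loss.backward()
optimizer.step()

extra_registered = any(name.startswith("extra.") for name, _ in model.named_parameters())
extra_updated = not torch.equal(before, model.extra.conv.weight)
print(extra_registered, extra_updated)
```

Constructing the block where the feature map first becomes available, for example inside `forward()`, would instead create fresh, randomly initialized weights on every call that the optimizer never sees. If the rest of the model is initialized from a pretrained checkpoint, the new block's parameters would typically not be in it and would start from their own initialization.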

Here's the non-local module:

import torch
import torch.nn as nn
import torch.nn.functional as F


class NonLocalBlock(nn.Module):
    def __init__(self, channels):
        super(NonLocalBlock, self).__init__()
        # channels = number of feature maps F in the input [N, F, H, W].
        # Taking an int (e.g. layer.shape[1]) instead of the tensor lets the
        # block be built once, in the parent model's __init__.

        self.conv = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=(1, 1))

        # Embeddings
        self.theta = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=(1, 1))
        self.phi = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=(1, 1))
        self.phi_pool = nn.MaxPool2d(2, 2)
        self.g = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=(1, 1))
        self.g_pool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        # Assume input is [N, F, H, W]
        n, f, h, w = x.shape
        theta = F.relu(self.theta(x))                      # Shape = [N, F, H, W]
        phi = self.phi_pool(F.relu(self.phi(x)))           # Shape = [N, F, H/2, W/2]
        g = self.g_pool(F.relu(self.g(x)))                 # Shape = [N, F, H/2, W/2]

        # Flatten the spatial dims per sample. The batch dim stays separate so
        # positions in one image never attend to positions in another, and
        # phi and g are flattened in the same (h, w) order.
        theta = theta.reshape(n, f, -1).permute(0, 2, 1)  # Shape = [N, HW, F]
        phi = phi.reshape(n, f, -1)                        # Shape = [N, F, H/2*W/2]
        g = g.reshape(n, f, -1).permute(0, 2, 1)          # Shape = [N, H/2*W/2, F]

        # Matrix Multiplication 1: pairwise affinities, normalized with softmax
        # over the pooled (key) positions for every query position
        prod = torch.bmm(theta, phi)                       # Shape = [N, HW, H/2*W/2]
        prod = F.softmax(prod, dim=-1)

        # Matrix Multiplication 2: weighted sum of g
        prod = torch.bmm(prod, g)                          # Shape = [N, HW, F]

        prod = prod.permute(0, 2, 1).reshape(n, f, h, w)  # Shape = [N, F, H, W]
        assert prod.shape == x.shape

        out = F.relu(self.conv(prod))                      # Shape = [N, F, H, W]
        out = out + x                                      # Residual connection

        return out


# `layer` is a feature map of shape [N, F, H, W] already on the GPU.
# For joint training, create the block once in the parent model's __init__.
nn_local = NonLocalBlock(layer.shape[1]).to('cuda')  # Send model to device
refined_layer = nn_local(layer)  # Call the module (not .forward) so hooks run
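A standalone sanity check for the attention step, written as a sketch: `batched_attention` keeps the batch dimension separate via `torch.bmm` and applies softmax over the key positions, while `flat_attention` (a hypothetical comparison helper) flattens the whole batch into one matrix so every query attends over keys from all images. For brevity, the feature map itself is used as theta and its 2x2 max-pooled version as phi and g, an assumption in place of the learned 1x1 embeddings. Perturbing the second image should leave the first image's batched output unchanged but alter its flattened output.

```python
import torch
import torch.nn.functional as F


def batched_attention(theta, phi, g):
    """theta: [N, F, H, W]; phi, g: [N, F, h, w]. Returns [N, F, H, W]."""
    n, f, h, w = theta.shape
    q = theta.reshape(n, f, -1).permute(0, 2, 1)    # [N, HW, F]
    k = phi.reshape(n, f, -1)                       # [N, F, hw]
    v = g.reshape(n, f, -1).permute(0, 2, 1)        # [N, hw, F]
    attn = F.softmax(torch.bmm(q, k), dim=-1)       # each query's weights sum to 1
    out = torch.bmm(attn, v)                        # [N, HW, F]
    return out.permute(0, 2, 1).reshape(n, f, h, w)


def flat_attention(theta, phi, g):
    """Hypothetical batch-mixing variant: all N*H*W queries see all N*h*w keys."""
    n, f, h, w = theta.shape
    q = theta.permute(0, 2, 3, 1).reshape(-1, f)    # [NHW, F]
    k = phi.permute(0, 2, 3, 1).reshape(-1, f).t()  # [F, Nhw]
    v = g.permute(0, 2, 3, 1).reshape(-1, f)        # [Nhw, F], same order as k
    attn = F.softmax(q @ k, dim=-1)
    out = attn @ v                                  # [NHW, F]
    return out.reshape(n, h, w, f).permute(0, 3, 1, 2)


torch.manual_seed(0)
x = torch.randn(2, 4, 6, 6)
pooled = F.max_pool2d(x, 2)

# Change only the second image in the batch.
x2 = x.clone()
x2[1] += 5.0
pooled2 = F.max_pool2d(x2, 2)

b1 = batched_attention(x, pooled, pooled)
b2 = batched_attention(x2, pooled2, pooled2)
f1 = flat_attention(x, pooled, pooled)
f2 = flat_attention(x2, pooled2, pooled2)

print(torch.allclose(b1[0], b2[0]))  # expected: True (image 0 is isolated)
print(torch.allclose(f1[0], f2[0]))  # expected: False (image 0 sees image 1)
```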