zhanghang1989 / PyTorch-Encoding

A CV toolkit for my papers.
https://hangzhang.org/PyTorch-Encoding/
MIT License

Apply Context Encoding to a Custom FCN #237

Open · luistelmocosta opened this issue 4 years ago

luistelmocosta commented 4 years ago

Hello, I am developing a simple FCN for texture segmentation, and I would like to apply the Context Encoding module to see if it improves my model's performance. However, after reading the documentation and checking several examples, I could not figure out how to apply it to my network.

This is my architecture:

import torch
import torch.nn as nn


class LuisNet_v3(nn.Module):
    def __init__(self, n_class=21):
        super(LuisNet_v3, self).__init__()

        # conv1 (padding=100 is the classic FCN trick so arbitrary input
        # sizes survive the downsampling; the output is cropped back to
        # the input resolution at the end of forward())
        self.conv1_1 = nn.Conv2d(3, 64, kernel_size=2, padding=100)
        self.relu1_1 = nn.ReLU(inplace=True)
        self.bn1_1 = nn.BatchNorm2d(64, eps=0.001)
        self.conv1_2 = nn.Conv2d(64, 64, kernel_size=2, padding=1)
        self.relu1_2 = nn.ReLU(inplace=True)
        self.bn1_2 = nn.BatchNorm2d(64, eps=0.001)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)  # 1/2

        # conv2
        self.conv2_1 = nn.Conv2d(64, 128, kernel_size=2, padding=1)
        self.relu2_1 = nn.ReLU(inplace=True)
        self.bn2_1 = nn.BatchNorm2d(128, eps=0.001)
        self.conv2_2 = nn.Conv2d(128, 128, kernel_size=2, padding=1)
        self.relu2_2 = nn.ReLU(inplace=True)
        self.bn2_2 = nn.BatchNorm2d(128, eps=0.001)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)  # 1/4

        # conv3
        self.conv3_1 = nn.Conv2d(128, 256, kernel_size=2, padding=1)
        self.relu3_1 = nn.ReLU(inplace=True)
        self.bn3_1 = nn.BatchNorm2d(256, eps=0.001)
        self.conv3_2 = nn.Conv2d(256, 256, kernel_size=2, padding=1)
        self.relu3_2 = nn.ReLU(inplace=True)
        self.bn3_2 = nn.BatchNorm2d(256, eps=0.001)
        self.conv3_3 = nn.Conv2d(256, 256, kernel_size=2, padding=1)
        self.relu3_3 = nn.ReLU(inplace=True)
        self.bn3_3 = nn.BatchNorm2d(256, eps=0.001)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)  # 1/8

        # conv4
        self.conv4_1 = nn.Conv2d(256, 512, kernel_size=2, padding=1)
        self.relu4_1 = nn.ReLU(inplace=True)
        self.bn4_1 = nn.BatchNorm2d(512, eps=0.001)
        self.conv4_2 = nn.Conv2d(512, 512, kernel_size=2, padding=1)
        self.relu4_2 = nn.ReLU(inplace=True)
        self.bn4_2 = nn.BatchNorm2d(512, eps=0.001)
        self.conv4_3 = nn.Conv2d(512, 512, kernel_size=2, padding=1)
        self.relu4_3 = nn.ReLU(inplace=True)
        self.bn4_3 = nn.BatchNorm2d(512, eps=0.001)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)  # 1/16

        # conv5
        '''self.conv5_1 = nn.Conv2d(512, 512, 3, padding=1)
        self.relu5_1 = nn.ReLU(inplace=True)
        self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1)
        self.relu5_2 = nn.ReLU(inplace=True)
        self.conv5_3 = nn.Conv2d(512, 512, 3, padding=1)
        self.relu5_3 = nn.ReLU(inplace=True)
        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)'''  # 1/32

        # fc6
        self.fc6 = nn.Conv2d(512, 4096, 7)
        self.relu6 = nn.ReLU(inplace=True)
        self.drop6 = nn.Dropout2d()

        # fc7
        self.fc7 = nn.Conv2d(4096, 4096, 1)
        self.relu7 = nn.ReLU(inplace=True)
        self.drop7 = nn.Dropout2d()

        self.score_fr = nn.Conv2d(4096, n_class, 1)
        self.score_pool3 = nn.Conv2d(256, n_class, 1)
        self.score_pool4 = nn.Conv2d(512, n_class, 1)
        self.score_pool2 = nn.Conv2d(128, n_class, 1)
        self.score_pool1 = nn.Conv2d(64, n_class, 1)

        # Transposed-conv upsampling, replaced by bilinear nn.Upsample below
        '''self.upscore2 = nn.ConvTranspose2d(n_class, n_class, 4, stride=2, bias=False)
        self.upscore4 = nn.ConvTranspose2d(n_class, n_class, 8, stride=2, bias=False)
        self.upscore8 = nn.ConvTranspose2d(n_class, n_class, 16, stride=2, bias=False)
        self.upscore_pool2 = nn.ConvTranspose2d(n_class, n_class, 4, stride=2, bias=False)
        self.upscore_pool4 = nn.ConvTranspose2d(n_class, n_class, 4, stride=2, bias=False)
        self.upscore_pool3 = nn.ConvTranspose2d(n_class, n_class, 8, stride=2, bias=False)'''

        # Bilinear upsampling; align_corners=False is made explicit to
        # avoid the default-behavior warning in recent PyTorch versions.
        self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.upscore4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)
        self.upscore8 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=False)
        self.upscore_pool2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.upscore_pool4 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.upscore_pool3 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)

        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # NOTE: zeroing every conv weight only makes sense if the
                # weights are later overwritten (e.g. copied from a
                # pretrained backbone); otherwise the network cannot train.
                m.weight.data.zero_()
                if m.bias is not None:
                    m.bias.data.zero_()
            if isinstance(m, nn.ConvTranspose2d):
                # get_upsampling_weight builds a fixed bilinear kernel
                # (as in wkentaro/pytorch-fcn); this branch is only reached
                # if the transposed convs above are re-enabled.
                assert m.kernel_size[0] == m.kernel_size[1]
                initial_weight = get_upsampling_weight(
                    m.in_channels, m.out_channels, m.kernel_size[0])
                m.weight.data.copy_(initial_weight)

    def forward(self, x):

        #POOL 1 

        # Conv2d_1a_2x2
        h = x
        h = self.conv1_1(h)
        h = self.relu1_1(h)
        h = self.bn1_1(h)

        # Conv2d_1b_2x2
        h = self.conv1_2(h)
        h = self.relu1_2(h)
        h = self.bn1_2(h)
        h = self.pool1(h)

        pool1 = h

        #POOL 2 

        # Conv2d_2a_2x2
        h = self.conv2_1(h)
        h = self.relu2_1(h)
        h = self.bn2_1(h)

        # Conv2d_2b_2x2
        h = self.conv2_2(h)
        h = self.relu2_2(h)
        h = self.bn2_2(h)
        h = self.pool2(h)

        pool2 = h

        #POOL 3

        # Conv2d_3a_2x2
        h = self.conv3_1(h)
        h = self.relu3_1(h)
        h = self.bn3_1(h)

        # Conv2d_3b_2x2
        h = self.conv3_2(h)
        h = self.relu3_2(h)
        h = self.bn3_2(h)

        # Conv2d_3c_2x2
        h = self.conv3_3(h)
        h = self.relu3_3(h)
        h = self.bn3_3(h)
        h = self.pool3(h)

        pool3 = h  # 1/8

        #POOL 4

        # Conv2d_4a_2x2
        h = self.conv4_1(h)
        h = self.relu4_1(h)
        h = self.bn4_1(h)

        # Conv2d_4b_2x2
        h = self.conv4_2(h)
        h = self.relu4_2(h)
        h = self.bn4_2(h)

        # Conv2d_4c_2x2
        h = self.conv4_3(h)
        h = self.relu4_3(h)
        h = self.bn4_3(h)
        h = self.pool4(h)

        pool4 = h  # 1/16

        #FC6 
        h = self.fc6(h)
        h = self.relu6(h)
        h = self.drop6(h)

        #FC7 
        h = self.fc7(h)
        h = self.relu7(h)
        h = self.drop7(h)

        h = self.score_fr(h)  # 1x1 conv -> n_class score maps
        h = self.upscore2(h)  # upsample 2x
        upscore4 = h  # 1/16

        #print("SHAPE9: ", h.shape)
        #h = self.score_pool4(pool4)

        h = self.score_pool3(pool3)
        #h = self.upscore_pool3(h)
        h = h[:, :, 5:5 + upscore4.size()[2], 5:5 + upscore4.size()[3]]
        score_pool3c = h  # 1/16
        #print("SHAPE10: ", upscore4.shape, score_pool3c.shape)
        h = upscore4 + score_pool3c  # 1/16

        h = self.upscore_pool4(h)
        upscore_pool2 = h  # 1/8

        # Skip from pool2: score, crop, add.
        h = self.score_pool2(pool2)
        h = h[:, :, 8:8 + upscore_pool2.size()[2], 8:8 + upscore_pool2.size()[3]]
        score_pool2c = h  # 1/8
        h = upscore_pool2 + score_pool2c  # 1/8

        h = self.upscore_pool4(h)
        upscore_pool4 = h

        # Skip from pool1: score, crop, add.
        h = self.score_pool1(pool1)
        h = h[:, :, 13:13 + upscore_pool4.size()[2], 13:13 + upscore_pool4.size()[3]]
        score_pool1c = h
        h = score_pool1c + upscore_pool4

        # Final 2x upsampling, then crop back to the input resolution
        # (undoing the padding=100 applied in conv1_1).
        h = self.upscore2(h)
        h = h[:, :, 31:31 + x.size()[2], 31:31 + x.size()[3]].contiguous()
        return h

    def predict(self, x):
        with torch.no_grad():
            x = self.forward(x)

        return x

Can you please tell me whether it is possible to apply the module to my custom network? I am really looking forward to trying it.

Kind regards, and thank you for your amazing work.

qiulesun commented 4 years ago

@luistelmocosta You can drop all the FC layers and append the Encoding module to the end of the last conv layer.
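
For concreteness, a minimal sketch of that idea (not tested; it assumes this package's encoding.nn.Encoding(D, K) layer, which maps a (B, D, H, W) feature map to (B, K, D) residual encodings; the 512 channels match pool4 above, and K=32 is just a placeholder):

import torch.nn as nn
import torch.nn.functional as F
import encoding  # this repo's package (torch-encoding)

class EncodingNeck(nn.Module):
    """Replaces fc6/fc7: orderless texture encoding of the last conv
    feature map, roughly following the Deep TEN head in this repo."""
    def __init__(self, in_channels=512, ncodes=32):
        super().__init__()
        self.enc = encoding.nn.Encoding(D=in_channels, K=ncodes)

    def forward(self, x):
        e = self.enc(x)               # (B, K, D) residual encodings
        e = e.view(e.size(0), -1)     # flatten to (B, K*D)
        return F.normalize(e, dim=1)  # L2-normalize, as Deep TEN does

Because the encoding is orderless, the output is a fixed-length descriptor regardless of the input H x W.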

luistelmocosta commented 4 years ago

Well, but I would still need an FC after the Encoding module, right?

qiulesun commented 4 years ago

Yes, one FC is needed for prediction.
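
For example, continuing the hypothetical EncodingNeck sketch above (names and sizes are placeholders):

class EncodingClassifier(nn.Module):
    def __init__(self, in_channels=512, n_class=21, ncodes=32):
        super().__init__()
        self.neck = EncodingNeck(in_channels, ncodes)
        self.fc = nn.Linear(in_channels * ncodes, n_class)  # the single FC

    def forward(self, x):
        # x: (B, 512, H, W) features from the last conv block
        return self.fc(self.neck(x))  # (B, n_class) logits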

luistelmocosta commented 4 years ago

So what you suggest is to remove my fc6 layer, add the Encoding module, and then plug in the fc7 layer for prediction, right?
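
For segmentation I assume I still need a spatial score map, so maybe the channel-gating variant used by EncNet in this repo is closer to what I want. This is my rough, simplified reading of its EncModule (untested; the SE-loss branch is omitted):

import torch.nn as nn
import torch.nn.functional as F
import encoding

class SimpleEncModule(nn.Module):
    """The global encoding gates the conv feature map channel-wise,
    so the spatial map is preserved and the existing score/upsample
    path of my FCN can stay as it is."""
    def __init__(self, in_channels=512, ncodes=32):
        super().__init__()
        self.enc = nn.Sequential(
            encoding.nn.Encoding(D=in_channels, K=ncodes),
            nn.BatchNorm1d(ncodes),
            nn.ReLU(inplace=True))
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels),
            nn.Sigmoid())

    def forward(self, x):
        b, c, _, _ = x.size()
        en = self.enc(x).mean(dim=1)          # (B, C) global descriptor
        gamma = self.fc(en).view(b, c, 1, 1)  # channel gates in (0, 1)
        return F.relu(x + x * gamma)          # re-weighted feature map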