sacmehta / ESPNet

ESPNet: Efficient Spatial Pyramid of Dilated Convolutions for Semantic Segmentation
https://sacmehta.github.io/ESPNet/
MIT License

Decoder model construction with 3 classes #13

Closed frk2 closed 6 years ago

frk2 commented 6 years ago

Hi! I'm trying to use this model to segment drivable regions and only have 3 classes. I'm able to train and use ESPNet-C just fine, but when I try to use the ESPNet light decoder there's a problem in model construction at https://github.com/sacmehta/ESPNet/blob/master/train/Model.py#L183

The rounding down makes n go to 0, which is a problem.

I've tried hacking around this by forcing n = n1 = 1 in the worst case, but clearly I don't fully understand the construction of the model, because that doesn't work out.

Any tips would be appreciated. Thanks!

frk2 commented 6 years ago

It seems the minimum number of classes I can train with is 5. Hoping you can shed more light on that divide-by-5, which seems to be critical. Thanks! :bowing_man:

sacmehta commented 6 years ago

This is because the ESP block has five parallel branches, and the number of output channels is divided by 5 (see Line 183). When you have fewer than 5 classes, this division returns 0.

See pages 4 and 5 for more details here
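For concreteness, the per-branch split at Line 183 behaves roughly like this (a sketch of the integer arithmetic with a hypothetical helper name, not the repo's exact code):

```python
# Sketch of the channel split inside DilatedParllelResidualBlockB
# (hypothetical helper mirroring the integer division at Model.py#L183):
def branch_depths(n_out, branches=5):
    n = n_out // branches              # depth of each of the 5 parallel branches
    n1 = n_out - (branches - 1) * n    # first branch absorbs the remainder
    return n, n1

print(branch_depths(20))  # (4, 4)  -- fine for 20 output channels
print(branch_depths(3))   # (0, 3)  -- n == 0, so the block cannot be built
```

With 3 classes the per-branch depth truncates to zero, which is why construction fails.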

To make it work, you can do either of the following:

1) Replace the DilatedParllelResidualBlockB with a normal CBR block:

self.combine_l2_l3 = nn.Sequential(BR(2*classes), CBR(2*classes, classes, 3, 1))#DilatedParllelResidualBlockB(2*classes , classes, add=False))

2) If you want to keep the structure as it is, create a classes1 variable and set its value to classes. Then update the value of the classes variable to 20, so that each branch in DilatedParallelResidualBlockB gets a nonzero depth (20 / 5 = 4 channels per branch). And update the classification layer too. A snippet is shown below:

def __init__(self, classes=20, p=2, q=3, encoderFile=None):
        '''
        :param classes: number of classes in the dataset. Default is 20 for the cityscapes
        :param p: depth multiplier
        :param q: depth multiplier
        :param encoderFile: pretrained encoder weights. Recall that we first trained the ESPNet-C and then attached the
                            RUM-based light weight decoder. See paper for more details.
        '''
        super().__init__()
        self.encoder = ESPNet_Encoder(classes, p, q)
#==============Change 1==========
        classes1 = classes
        classes = 20
#==============================
        if encoderFile != None:
            self.encoder.load_state_dict(torch.load(encoderFile))
            print('Encoder loaded!')
        # load the encoder modules
        self.modules = []
        for i, m in enumerate(self.encoder.children()):
            self.modules.append(m)

        # light-weight decoder
        self.level3_C = C(128 + 3, classes, 1, 1)
        self.br = nn.BatchNorm2d(classes, eps=1e-03)
        self.conv = CBR(19 + classes, classes, 3, 1)

        self.up_l3 = nn.Sequential(nn.ConvTranspose2d(classes, classes, 2, stride=2, padding=0, output_padding=0, bias=False))
        self.combine_l2_l3 = nn.Sequential(BR(2*classes), DilatedParllelResidualBlockB(2*classes , classes, add=False))

        self.up_l2 = nn.Sequential(nn.ConvTranspose2d(classes, classes, 2, stride=2, padding=0, output_padding=0, bias=False), BR(classes))
#==============Change 2==========
        self.classifier = nn.ConvTranspose2d(classes, classes1, 2, stride=2, padding=0, output_padding=0, bias=False)
frk2 commented 6 years ago

Thanks! Using my own ESPNet-C trained on 3 classes, I get:

Traceback (most recent call last):
  File "main.py", line 410, in <module>
    trainValidateSegmentation(parser.parse_args())
  File "main.py", line 186, in trainValidateSegmentation
    y = model.forward(x)
  File "/home/faraz/opencaret/ESPNet/train/Model.py", line 383, in forward
    output2_c = self.up_l3(self.br(self.modules[10](output2_cat))) #RUM
  File "/home/faraz/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/faraz/.local/lib/python3.6/site-packages/torch/nn/modules/batchnorm.py", line 49, in forward
    self.training or not self.track_running_stats, self.momentum, self.eps)
  File "/home/faraz/.local/lib/python3.6/site-packages/torch/nn/functional.py", line 1194, in batch_norm
    training, momentum, eps, torch.backends.cudnn.enabled
RuntimeError: running_mean should contain 3 elements not 20

with the above modification. Just to clarify: the modification is only to change the classifier at the end, right?

frk2 commented 6 years ago

Update: Changing to CBR works just fine. Thanks! :rocket:

Any drawbacks to using it this way?

sacmehta commented 6 years ago

With 3 classes, you are already working in a low-dimensional space, so I don't think it will hurt.

Anyway, here is the updated code for you:

class ESPNet(nn.Module):
    '''
    This class defines the ESPNet network
    '''

    def __init__(self, classes=20, p=2, q=3, encoderFile=None):
        '''
        :param classes: number of classes in the dataset. Default is 20 for the cityscapes
        :param p: depth multiplier
        :param q: depth multiplier
        :param encoderFile: pretrained encoder weights. Recall that we first trained the ESPNet-C and then attached the
                            RUM-based light weight decoder. See paper for more details.
        '''
        super().__init__()
        self.encoder = ESPNet_Encoder(classes, p, q)
        if encoderFile != None:
            self.encoder.load_state_dict(torch.load(encoderFile))
            print('Encoder loaded!')

        classes1 = classes
        classes = 20

        # load the encoder modules
        self.modules = []
        for i, m in enumerate(self.encoder.children()):
            self.modules.append(m)

        # light-weight decoder
        self.level3_C = C(128 + 3, classes, 1, 1)
        self.br = nn.BatchNorm2d(classes1, eps=1e-03)
        self.conv = CBR(19 + classes, classes, 3, 1)

        self.up_l3 = nn.Sequential(
            nn.ConvTranspose2d(classes1, classes, 2, stride=2, padding=0, output_padding=0, bias=False))
        self.combine_l2_l3 = nn.Sequential(BR(2 * classes),
                                           DilatedParllelResidualBlockB(2 * classes, classes, add=False))

        self.up_l2 = nn.Sequential(
            nn.ConvTranspose2d(classes, classes, 2, stride=2, padding=0, output_padding=0, bias=False), BR(classes))

        self.classifier = nn.ConvTranspose2d(classes, classes1, 2, stride=2, padding=0, output_padding=0, bias=False)

    def forward(self, input):
        '''
        :param input: RGB image
        :return: transformed feature map
        '''
        output0 = self.modules[0](input)
        inp1 = self.modules[1](input)
        inp2 = self.modules[2](input)

        output0_cat = self.modules[3](torch.cat([output0, inp1], 1))
        output1_0 = self.modules[4](output0_cat)  # down-sampled

        for i, layer in enumerate(self.modules[5]):
            if i == 0:
                output1 = layer(output1_0)
            else:
                output1 = layer(output1)

        output1_cat = self.modules[6](torch.cat([output1, output1_0, inp2], 1))

        output2_0 = self.modules[7](output1_cat)  # down-sampled
        for i, layer in enumerate(self.modules[8]):
            if i == 0:
                output2 = layer(output2_0)
            else:
                output2 = layer(output2)

        output2_cat = self.modules[9](torch.cat([output2_0, output2], 1))  # concatenate for feature map width expansion
        output2_c = self.up_l3(self.br(self.modules[10](output2_cat)))  # RUM

        output1_C = self.level3_C(output1_cat)  # project to C-dimensional space
        comb_l2_l3 = self.up_l2(self.combine_l2_l3(torch.cat([output1_C, output2_c], 1)))  # RUM

        concat_features = self.conv(torch.cat([comb_l2_l3, output0_cat], 1))

        classifier = self.classifier(concat_features)
        return classifier

Sorry, I forgot to update the batch norm (self.br) and upsampling layer (self.up_l3) parameters that take ESPNet-C's output.
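As a sanity check, the channel counts in the patched decoder can be traced by hand. The sketch below (hypothetical bookkeeping, with (in_channels, out_channels) tuples read off the snippet above) verifies that adjacent layers compose:

```python
# Channel bookkeeping for the patched decoder: classes1 = 3 real classes,
# classes = 20 internal channels. Keys mirror the attribute names above;
# values are hand-derived (in_channels, out_channels) pairs.
classes1, classes = 3, 20

layers = {
    "br":            (classes1, classes1),   # BatchNorm2d sees ESPNet-C's classes1 maps
    "up_l3":         (classes1, classes),    # ConvTranspose2d(classes1, classes)
    "level3_C":      (128 + 3, classes),     # 1x1 projection to classes channels
    "combine_l2_l3": (2 * classes, classes), # concat of up_l3 and level3_C outputs
    "up_l2":         (classes, classes),
    "conv":          (19 + classes, classes),
    "classifier":    (classes, classes1),    # back to the real class count
}

# The decoder chain must compose channel-wise:
assert layers["br"][1] == layers["up_l3"][0]
assert layers["up_l3"][1] + layers["level3_C"][1] == layers["combine_l2_l3"][0]
assert layers["combine_l2_l3"][1] == layers["up_l2"][0]
assert layers["classifier"][1] == classes1
print("channel counts compose")
```

The earlier RuntimeError came from the first two entries: with br and up_l3 built on 20 channels, the 3-channel encoder output could not pass through.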

frk2 commented 6 years ago

Nice, this works after I change the classifier to output classes1 as well. Thanks a lot!