zh320 / realtime-semantic-segmentation-pytorch

PyTorch implementation of over 30 realtime semantic segmentation models, e.g. BiSeNetv1, BiSeNetv2, CGNet, ContextNet, DABNet, DDRNet, EDANet, ENet, ERFNet, ESPNet, ESPNetv2, FastSCNN, ICNet, LEDNet, LinkNet, PP-LiteSeg, SegNet, ShelfNet, STDC, SwiftNet, with support for knowledge distillation, distributed training, Optuna, etc.

RegSeg code bug and open questions #20

Closed: felipesanmartin closed this issue 2 weeks ago

felipesanmartin commented 1 month ago

I have some doubts about the RegSeg code.

First, it doesn't run as is; I had to change a few things in the ConvBNAct module, basically adding groups to its arguments:

# Regular convolution -> batchnorm -> activation
import torch.nn as nn

class ConvBNAct(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1,
                 bias=False, act_type='relu', groups=1, **kwargs):
        # Padding is scaled by the dilation so the spatial size is preserved for stride=1.
        if isinstance(kernel_size, (list, tuple)):
            padding = ((kernel_size[0] - 1) // 2 * dilation, (kernel_size[1] - 1) // 2 * dilation)
        elif isinstance(kernel_size, int):
            padding = (kernel_size - 1) // 2 * dilation

        super().__init__(
            # groups is now forwarded to nn.Conv2d so grouped (dilated) convolutions work
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding,
                      dilation, groups=groups, bias=bias),
            nn.BatchNorm2d(out_channels),
            Activation(act_type, **kwargs)  # Activation is the activation wrapper defined elsewhere in this repo
        )
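For reference, a quick sanity check with made-up channel and group numbers (not taken from the repo's actual RegSeg config), just to show that groups now reaches nn.Conv2d:

import torch

# 3x3 dilated group convolution, the kind the RegSeg D blocks build through ConvBNAct.
# Without the fix above, the groups argument never reached nn.Conv2d.
conv = ConvBNAct(256, 256, kernel_size=3, dilation=4, groups=16, act_type='relu')

x = torch.randn(1, 256, 64, 128)
y = conv(x)
print(y.shape)  # torch.Size([1, 256, 64, 128]); padding scales with dilation, so the spatial size is kept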

Second, I have read the base code from Roland Gao, but I think your implementation doesn't match any of his experiments. Is there a reason for this? The model works really well with your configuration, but I haven't tried Roland's architectures. Also, I think the last dilation ([5, 14]) is not used because that block has stride=2 (ref). Is that ok?

Thanks a lot for sharing your really good work; I'm looking forward to any answer.

Best regards, Felipe

felipesanmartin commented 1 month ago

I can answer this myself: the config in this repo is the architecture found by DNAS (Table 2 in the paper). The stride should be 1 in the last D block according to Table 1 in the paper.

zh320 commented 2 weeks ago
1. Thank you for pointing out the bug in ConvBNAct; I have already fixed it.

2. Regarding the implementation of RegSeg, I only followed the descriptions in the paper, not its official repo, so there might be differences between our implementations.

> Also, I think the last dilation ([5, 14]) is not used because it has stride=2 (ref). Is it ok?
> Stride should be 1 in the last D block according to Table 1 in the paper.

You are right, the stride should be 1 in the last D block, but here is the issue. According to Figure 4 of the paper, the D block always uses a skip connection, which requires the input channels to equal the output channels. If they are not equal, one way to stay compatible with the architectures in the paper is to use the structure in panel (c), which requires stride=2. The other option is to drop the skip connection, but that goes against Figure 4.
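To make the trade-off concrete, here is a rough sketch of what I mean; it is not the actual DBlock in this repo, just an illustration of the shortcut constraint, with a plain dilated 3x3 conv standing in for the D block body:

import torch
import torch.nn as nn

class SketchDBlock(nn.Module):
    # Simplified stand-in for a D block: one dilated 3x3 conv as the body plus a skip connection.
    # Only the handling of the shortcut matters here.
    def __init__(self, in_channels, out_channels, dilation=1, stride=1):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, stride, padding=dilation,
                      dilation=dilation, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        if stride == 1 and in_channels == out_channels:
            # Figure 4 case: identity skip, which requires matching channels.
            self.shortcut = nn.Identity()
        else:
            # Panel (c)-style projection shortcut: (optional) downsampling plus a 1x1 conv,
            # so the skip connection survives a change in channels or resolution.
            # Note: a stride=1 projection like this is not drawn in Figure 4 of the paper;
            # it is presumably what a separate Shortcut module would have to do.
            layers = []
            if stride > 1:
                layers.append(nn.AvgPool2d(stride, stride, ceil_mode=True))
            layers += [nn.Conv2d(in_channels, out_channels, 1, bias=False),
                       nn.BatchNorm2d(out_channels)]
            self.shortcut = nn.Sequential(*layers)

    def forward(self, x):
        return self.body(x) + self.shortcut(x)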

Best,

felipesanmartin commented 2 weeks ago

Yes, Roland Gao's implementation uses a Shortcut module that handles different input/output channels in the DBlock, but it isn't detailed in the paper. Anyway, your implementation works really well. I pruned some blocks and it runs in less than 1 ms with torch_tensorrt on an RTX 3060!
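In case it is useful to anyone else, a minimal Torch-TensorRT sketch along those lines; the 1x3x512x1024 input shape, the fp16 setting and the Conv2d stand-in are placeholders, not my actual model or settings:

import torch
import torch_tensorrt

# Stand-in module so the snippet is self-contained; replace it with the trained (pruned) RegSeg model.
model = torch.nn.Conv2d(3, 19, 3, padding=1).eval().cuda()

# Compile with Torch-TensorRT, letting it pick fp16 kernels where possible.
trt_model = torch_tensorrt.compile(
    model,
    inputs=[torch_tensorrt.Input((1, 3, 512, 1024))],
    enabled_precisions={torch.float, torch.half},
)

# Rough latency check: warm up, then time 100 forward passes with CUDA events.
x = torch.randn(1, 3, 512, 1024, device='cuda')
with torch.no_grad():
    for _ in range(10):
        trt_model(x)
    torch.cuda.synchronize()
    start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(100):
        trt_model(x)
    end.record()
    torch.cuda.synchronize()
print(start.elapsed_time(end) / 100, 'ms per forward pass')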

Best.