Serge-weihao / CCNet-Pure-Pytorch

Criss-Cross Attention (2d&3d) for Semantic Segmentation in pure Pytorch with a faster and more precise implementation.
MIT License

About the network's final return[x, x_dsn] #3

Open yearing1017 opened 4 years ago

yearing1017 commented 4 years ago

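Hello, thanks for your implementation. I have a few questions:

  1. What does the network's final return[x, x_dsn] mean? Does it return a list? In the paper, shouldn't the end be a concatenation (cat)?
  2. If the network's final output is not upsampled, how is it restored to the original image size?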
yearing1017 commented 4 years ago

@Serge-weihao

Serge-weihao commented 4 years ago

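x and x_dsn each get their own loss, and those two losses are combined by a weighted sum into the final loss. That is why the outputs should not be concatenated.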
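A minimal sketch of that deep-supervision loss in PyTorch, under stated assumptions: both outputs are raw logits at reduced resolution, each is upsampled to the label size before cross-entropy, the auxiliary weight is an assumed 0.4 (the thread does not say which weight the repo uses), and ignore_index=255 is likewise an assumption. The function name deep_supervision_loss is hypothetical.

```python
import torch
import torch.nn.functional as F


def deep_supervision_loss(outputs, labels, aux_weight=0.4, ignore_index=255):
    """Weighted sum of separate losses for a [x, x_dsn] pair (hypothetical helper).

    outputs: list [x, x_dsn] of logits shaped (N, C, h, w)
    labels:  integer class map shaped (N, H, W)
    aux_weight and ignore_index are assumed values, not taken from the repo.
    """
    x, x_dsn = outputs
    size = labels.shape[-2:]
    # Upsample each prediction to label resolution before computing its loss.
    x = F.interpolate(x, size=size, mode="bilinear", align_corners=True)
    x_dsn = F.interpolate(x_dsn, size=size, mode="bilinear", align_corners=True)
    loss_main = F.cross_entropy(x, labels, ignore_index=ignore_index)
    loss_aux = F.cross_entropy(x_dsn, labels, ignore_index=ignore_index)
    # The losses are combined; the two outputs themselves are never concatenated.
    return loss_main + aux_weight * loss_aux


torch.manual_seed(0)
logits = [torch.randn(4, 4, 40, 40), torch.randn(4, 4, 40, 40)]
labels = torch.randint(0, 4, (4, 320, 320))
print(deep_supervision_loss(logits, labels).item())
```

Because each branch is supervised on its own, the auxiliary branch still receives a direct training signal even though only x is used at inference time.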

yearing1017 commented 4 years ago

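Thanks for the reply. If I only want to use the CCNet network in my own program, for example with input [4,3,320,320] and output [4,4,320,320] for segmentation into 4 classes, should I modify the network's forward code so that it upsamples at the end to produce the output? @Serge-weihao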
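Some resize step is needed somewhere to go from the network's reduced-resolution logits back to 320×320. A minimal sketch, assuming the backbone has output stride 8 (so a 320×320 input yields 40×40 logits) and using F.interpolate for bilinear upsampling; the helper name upsample_to_input is hypothetical.

```python
import torch
import torch.nn.functional as F


def upsample_to_input(logits, image):
    """Resize (N, C, h, w) logits to the spatial size of an (N, 3, H, W) image."""
    return F.interpolate(logits, size=image.shape[-2:], mode="bilinear", align_corners=True)


image = torch.randn(4, 3, 320, 320)
# Stand-in for the network output: 4 classes at an assumed 1/8 resolution.
logits = torch.randn(4, 4, 320 // 8, 320 // 8)
out = upsample_to_input(logits, image)
print(tuple(out.shape))  # (4, 4, 320, 320)
pred = out.argmax(dim=1)  # per-pixel class indices, shape (4, 320, 320)
```

The same resize can live at the end of forward or in the loss and inference code; either way the logits are upsampled before being compared with full-resolution labels.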

yearing1017 commented 4 years ago

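Hello, I modified the forward code in ccnet.py as shown below. Instead of computing separate losses, I fuse x_dsn and x, predict the segmentation from the fused result, and compute the loss only between this single output and the label. Would that work?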

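    def forward(self, x, labels=None):
        # Assumes ccnet.py imports torch, torch.nn as nn and torch.nn.functional as F,
        # and that self.conv4 is a new layer mapping 2 * num_classes -> num_classes channels.
        size = (x.shape[2], x.shape[3])  # original input H, W
        # Stem and ResNet stages
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.relu3(self.bn3(self.conv3(x)))
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        # Auxiliary (deep-supervision) branch on the layer3 features
        x_dsn = self.dsn(x)
        x = self.layer4(x)
        # Criss-cross attention head, applied self.recurrence times
        x = self.head(x, self.recurrence)
        # Fuse both predictions along the channel dimension, then mix them with conv4
        outs = torch.cat([x, x_dsn], 1)
        outs = self.conv4(outs)
        # Resize the logits back to the input resolution
        outs = F.interpolate(outs, size=size, mode='bilinear', align_corners=True)
        return outs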
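To check the shape bookkeeping of this modified forward, here is a small self-contained sketch. It assumes that self.head and self.dsn both end in num_classes output channels and produce maps of the same spatial size (true for a dilated backbone where layer3 and layer4 share a resolution); toy 1×1 convs stand in for them, the 1024/2048 channel counts are assumed ResNet values, and the class name FusedOutput is hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FusedOutput(nn.Module):
    """Hypothetical fusion stage mirroring the modified forward."""

    def __init__(self, num_classes=4, c3=1024, c4=2048):
        super().__init__()
        self.dsn = nn.Conv2d(c3, num_classes, kernel_size=1)   # stand-in for self.dsn
        self.head = nn.Conv2d(c4, num_classes, kernel_size=1)  # stand-in for self.head
        # conv4 must accept the concatenated 2 * num_classes channels.
        self.conv4 = nn.Conv2d(2 * num_classes, num_classes, kernel_size=1)

    def forward(self, feat3, feat4, size):
        x_dsn = self.dsn(feat3)
        x = self.head(feat4)
        # torch.cat along dim 1 requires both maps to have the same H and W.
        outs = torch.cat([x, x_dsn], dim=1)
        outs = self.conv4(outs)
        return F.interpolate(outs, size=size, mode="bilinear", align_corners=True)


model = FusedOutput(num_classes=4)
feat3 = torch.randn(2, 1024, 40, 40)
feat4 = torch.randn(2, 2048, 40, 40)
out = model(feat3, feat4, size=(320, 320))
print(tuple(out.shape))  # (2, 4, 320, 320)
```

With this change the auxiliary branch no longer has a loss of its own: it only receives gradients through the fused output, which is then trained with a single cross-entropy against the labels.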
Serge-weihao commented 4 years ago

deeplabv3plus uses a similar structure, concatenating features from different levels. If it works for your case, go with it.
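For comparison, a rough sketch of the DeepLabv3+-style multi-level fusion mentioned in the reply. A coarse high-level map is upsampled to the size of a low-level map, the low-level channels are reduced with a 1×1 conv, and the two are concatenated and refined. The channel counts (256 high-level, 256 low-level reduced to 48) and the class name MultiLevelFusion are assumptions for illustration, and ASPP and the backbone are omitted entirely.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiLevelFusion(nn.Module):
    """Hypothetical decoder: upsample high-level features, concat with low-level ones."""

    def __init__(self, high_ch=256, low_ch=256, reduced_ch=48, num_classes=4):
        super().__init__()
        # 1x1 conv shrinks low-level channels so they do not outweigh the high-level ones.
        self.reduce = nn.Sequential(
            nn.Conv2d(low_ch, reduced_ch, kernel_size=1, bias=False),
            nn.BatchNorm2d(reduced_ch),
            nn.ReLU(inplace=True),
        )
        self.refine = nn.Sequential(
            nn.Conv2d(high_ch + reduced_ch, 256, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, kernel_size=1),
        )

    def forward(self, high, low, out_size):
        # Bring the coarse high-level map to the low-level map's resolution.
        high = F.interpolate(high, size=low.shape[-2:], mode="bilinear", align_corners=True)
        fused = torch.cat([high, self.reduce(low)], dim=1)
        logits = self.refine(fused)
        return F.interpolate(logits, size=out_size, mode="bilinear", align_corners=True)


decoder = MultiLevelFusion()
high = torch.randn(2, 256, 20, 20)  # e.g. 1/16 resolution of a 320x320 input
low = torch.randn(2, 256, 80, 80)   # e.g. 1/4 resolution
out = decoder(high, low, out_size=(320, 320))
print(tuple(out.shape))  # (2, 4, 320, 320)
```

Unlike the same-resolution fusion in the modified forward above, this variant mixes features of different resolutions, so one of them has to be resized before torch.cat.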