samleoqh / MSCG-Net

Multi-view Self-Constructing Graph Convolutional Networks with Adaptive Class Weighting Loss for Semantic Segmentation
MIT License

IJRS 2021 paper code #20

Open · czarmanu opened this issue 2 years ago

czarmanu commented 2 years ago

Do you have a version of the code optimised for the IJRS 2021 paper?

"Self-constructing graph neural networks to model long-range pixel dependencies for semantic segmentation of remote sensing images" (https://www.tandfonline.com/doi/full/10.1080/01431161.2021.1936267?scroll=top&needAccess=true)

samleoqh commented 2 years ago

The model for the IJRS 2021 paper can easily be built on top of /lib/net/scg_gcn.py, as shown below. The training pipeline for the Vaihingen dataset is almost the same as for DDCM-Net.

```python
# The star import is assumed to bring in nn, F, models, load_state_dict_from_url,
# model_urls, GCN_Layer, SCG_block and weight_xavier_init.
from lib.net.scg_gcn import *

class SCG_Net_R50(nn.Module):
    def __init__(self, out_channels=6, pretrained=True,
                 nodes=(28, 28), dropout=0,
                 enhance_diag=True, aux_pred=True):
        super(SCG_Net_R50, self).__init__()

        self.aux_pred = aux_pred
        self.node_size = nodes
        self.num_cluster = out_channels

        resnet = models.resnet50()

        if pretrained:
            # Alternative: load from a local file
            # resnet.load_state_dict(torch.load(res50_path))
            state_dict = load_state_dict_from_url(model_urls['resnet50'],
                                                  progress=True)
            resnet.load_state_dict(state_dict)

        # Backbone: ResNet-50 up to layer3 (1024 channels, output stride 16)
        self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
        self.layer1, self.layer2, self.layer3 = resnet.layer1, resnet.layer2, resnet.layer3

        # Two-layer GCN head: 1024 -> 128 -> out_channels
        self.graph_layers1 = GCN_Layer(1024, 128, bnorm=True, activation=nn.ReLU(True), dropout=dropout)
        self.graph_layers2 = GCN_Layer(128, out_channels, bnorm=False, activation=None)

        # Self-constructing graph block over a nodes[0] x nodes[1] grid of nodes
        self.scg = SCG_block(in_ch=1024,
                             hidden_ch=out_channels,
                             node_size=nodes,
                             add_diag=enhance_diag,
                             dropout=dropout)

        weight_xavier_init(self.graph_layers1, self.graph_layers2, self.scg)

    def forward(self, x):
        x_size = x.size()
        gx = self.layer3(self.layer2(self.layer1(self.layer0(x))))
        B, C, H, W = gx.size()

        A, gx, loss, z_hat, gamma = self.scg(gx)

        gx, A, _ = self.graph_layers2(
            self.graph_layers1((gx.reshape(B, -1, C), A, False)))

        if self.aux_pred:
            # Out-of-place add avoids modifying a tensor autograd may still need
            gx = gx + gamma * z_hat

        gx = gx.reshape(B, self.num_cluster, self.node_size[0], self.node_size[1])

        # Upsample node grid -> backbone feature size -> input size
        gx = F.interpolate(gx, (H, W), mode='bilinear', align_corners=False)
        out = F.interpolate(gx, x_size[2:], mode='bilinear', align_corners=False)

        if self.training:
            # The SCG loss is returned so the training loop can add it to the segmentation loss
            return out, loss
        return out
```
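The graph part of the forward pass above can be followed on a toy graph without the repo. The NumPy code below is a minimal sketch under stated assumptions. It assumes the adjacency is formed as ReLU(Z Zᵀ) from node embeddings Z, boosted on the diagonal by an adaptive factor γ = sqrt(1 + 1/mean(diag A)), given self-loops and symmetrically normalized, roughly in the spirit of the SCG papers. The helpers `scg_adjacency` and `gcn_layer`, the γ formula, the toy sizes and the use of Z in place of `z_hat` are all hypothetical. This is not the code in `lib/net/scg_gcn.py`.

```python
import numpy as np


def scg_adjacency(z, add_diag=True, eps=1e-7):
    """Hypothetical self-constructed adjacency from node embeddings z (N x k).

    Assumed form: A = ReLU(z z^T), optional adaptive diagonal enhancement,
    then symmetric normalization D^-1/2 (A + I) D^-1/2.
    """
    a = np.maximum(z @ z.T, 0.0)                      # non-negative node similarities
    diag = np.diag(a).copy()
    gamma = np.sqrt(1.0 + 1.0 / (diag.mean() + eps))  # assumed adaptive factor (> 1)
    if add_diag:
        a = a + gamma * np.diag(diag)                 # strengthen self-connections
    a = a + np.eye(a.shape[0])                        # self-loops
    d_inv_sqrt = 1.0 / np.sqrt(a.sum(axis=1))         # row sums are >= 1, so no division by 0
    return d_inv_sqrt[:, None] * a * d_inv_sqrt[None, :], gamma


def gcn_layer(a_norm, h, w, relu=True):
    """One graph convolution: aggregate neighbours with a_norm, then project with w."""
    out = a_norm @ h @ w
    return np.maximum(out, 0.0) if relu else out


rng = np.random.default_rng(0)
C, n0, n1, k = 16, 4, 4, 6           # toy sizes: channels, node grid, classes
N = n0 * n1

feats = rng.standard_normal((N, C))  # pooled backbone features, one row per node
z = rng.standard_normal((N, k))      # hypothetical node embeddings (stand-in for SCG output)

a_norm, gamma = scg_adjacency(z)
w1 = 0.1 * rng.standard_normal((C, 8))
w2 = 0.1 * rng.standard_normal((8, k))

h = gcn_layer(a_norm, feats, w1)               # analogue of graph_layers1 (ReLU)
logits = gcn_layer(a_norm, h, w2, relu=False)  # analogue of graph_layers2 (no activation)
logits = logits + gamma * z                    # assumed analogue of gx + gamma * z_hat
grid = logits.T.reshape(k, n0, n1)             # one n0 x n1 score map per class
print(grid.shape)  # (6, 4, 4)
```

This sketch transposes the (N, k) node scores before reshaping so that each of the k maps holds one class over the node grid. In the real model, that n0 × n1 grid is then bilinearly upsampled to the input resolution, as in `forward`.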