samleoqh / MSCG-Net

Multi-view Self-Constructing Graph Convolutional Networks with Adaptive Class Weighting Loss for Semantic Segmentation
MIT License
68 stars 28 forks source link

Deactivating the NIR channel #22

Closed czarmanu closed 2 years ago

czarmanu commented 2 years ago

Could you please advise how to deactivate the NIR channel and use just the RGB channels as input?

Thanks!

samleoqh commented 2 years ago

You can build a 3-channel model (e.g., MSCG-Rx50_RGB) with the following code (modify /tools/model.py):

def load_model(name='MSCG-Rx50', classes=7, node_size=(32,32)):
    """Build the network registered under ``name``.

    Args:
        name: Model identifier. Recognised values are ``'MSCG-Rx50'``,
            ``'MSCG-Rx101'`` and ``'MSCG-Rx50_RGB'`` (the 3-channel variant).
            The comparison is exact and case-sensitive.
        classes: Forwarded to the model constructor as ``out_channels``.
        node_size: Accepted for interface compatibility; this function does
            not forward it to any constructor.

    Returns:
        The constructed network, or ``-1`` (after printing a message) when
        ``name`` is not recognised. Callers must check for the sentinel.
    """
    if name == 'MSCG-Rx50':
        net = rx50_gcn_3head_4channel(out_channels=classes)
    elif name == 'MSCG-Rx101':
        net = rx101_gcn_3head_4channel(out_channels=classes)
    elif name == 'MSCG-Rx50_RGB':
        # Only takes a 3-channel input.
        net = rx50_gcn_3head_3channel(out_channels=classes)
    else:
        print('not found the net')
        return -1

    return net

class rx50_gcn_3head_3channel(nn.Module):
    """SCG/GCN segmentation network on an SE-ResNeXt50 trunk, three views.

    The backbone stem (``layer0``) is used exactly as returned by
    ``se_resnext50_32x4d()``; no extra input convolution is added here, so the
    accepted number of input channels is whatever that stem expects
    (presumably 3, i.e. RGB -- confirm against the backbone definition).

    One shared ``SCG_block`` and one shared pair of ``GCN_Layer`` modules are
    applied to three views of the backbone feature map (as-is, last two axes
    transposed, last axis flipped). The three results are mapped back to the
    original orientation and summed.

    Args:
        out_channels: Number of output classes; also used as the SCG hidden
            width and stored as ``num_cluster``.
        pretrained: Accepted for signature compatibility; not used in this
            class.
        nodes: ``(rows, cols)`` of the node grid that graph outputs are
            reshaped to.
        dropout: Passed to the first GCN layer and to the SCG block.
        enhance_diag: Passed to the SCG block as ``add_diag``.
        aux_pred: When true, the SCG block's ``z_hat`` is added to the GCN
            output of every view.
    """

    def __init__(self, out_channels=7, pretrained=True,
                 nodes=(32, 32), dropout=0,
                 enhance_diag=True, aux_pred=True):
        # Zero-argument super() on purpose: the previous call named
        # rx50_gcn_3head_4channel (copy-paste), which fails at construction.
        super().__init__()

        self.aux_pred = aux_pred
        self.node_size = nodes
        self.num_cluster = out_channels

        resnet = se_resnext50_32x4d()
        # Only the first four backbone stages are kept.
        self.layer0 = resnet.layer0
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3

        self.graph_layers1 = GCN_Layer(1024, 128, bnorm=True, activation=nn.ReLU(True), dropout=dropout)

        self.graph_layers2 = GCN_Layer(128, out_channels, bnorm=False, activation=None)

        self.scg = SCG_block(in_ch=1024,
                             hidden_ch=out_channels,
                             node_size=nodes,
                             add_diag=enhance_diag,
                             dropout=dropout)

        weight_xavier_init(self.graph_layers1, self.graph_layers2, self.scg)

    def _graph_view(self, feat, batch, channels, grid_h, grid_w):
        """Run one view through the shared SCG block and both GCN layers.

        Returns the per-class node map reshaped to
        ``(batch, num_cluster, grid_h, grid_w)`` and the loss reported by the
        SCG block for this view.
        """
        adj, nodes, view_loss, z_hat = self.scg(feat)
        nodes, _ = self.graph_layers2(
            self.graph_layers1((nodes.reshape(batch, -1, channels), adj)))
        if self.aux_pred:
            nodes += z_hat
        nodes = nodes.reshape(batch, self.num_cluster, grid_h, grid_w)
        return nodes, view_loss

    def forward(self, x):
        """Compute segmentation logits at the spatial size of ``x``.

        Returns:
            In training mode, a tuple ``(logits, loss)`` where ``loss`` is the
            sum of the SCG losses of the three views; otherwise ``logits``
            alone.
        """
        x_size = x.size()

        gx = self.layer3(self.layer2(self.layer1(self.layer0(x))))
        # Extra views are derived before gx is rebound below.
        gx90 = gx.permute(0, 1, 3, 2)   # last two axes swapped
        gx180 = gx.flip(3)              # last axis reversed
        B, C, H, W = gx.size()
        node_h, node_w = self.node_size[0], self.node_size[1]

        gx, loss = self._graph_view(gx, B, C, node_h, node_w)

        # Transposed view: reshape with the grid axes swapped, then undo the
        # transpose before accumulating.
        gx90, loss2 = self._graph_view(gx90, B, C, node_w, node_h)
        gx90 = gx90.permute(0, 1, 3, 2)
        gx += gx90

        # Flipped view: undo the flip before accumulating.
        gx180, loss3 = self._graph_view(gx180, B, C, node_h, node_w)
        gx180 = gx180.flip(3)
        gx += gx180

        # Two-step resize kept from the original: first to the backbone
        # feature size, then to the input size.
        gx = F.interpolate(gx, (H, W), mode='bilinear', align_corners=False)
        out = F.interpolate(gx, x_size[2:], mode='bilinear', align_corners=False)

        if self.training:
            return out, loss + loss2 + loss3
        else:
            return out

Then change train_args in train_R50.py as follows:

# Training configuration for the RGB-only model variant.
train_args = agriculture_configs(
    net_name='MSCG-Rx50_RGB',  # must match the name handled in load_model
    data='Agriculture',
    bands_list=['RGB'],        # NIR band left out
    kf=0,
    k_folder=0,
    note='only_use_RGB',
)
czarmanu commented 2 years ago

Perfect! Thanks

czarmanu commented 2 years ago

While testing:

Traceback (most recent call last):
  File "/scratch/manu/MSCG-Net-master_selftrained/./tools/test_submission.py", line 241, in <module>
    main()
  File "/scratch/manu/MSCG-Net-master_selftrained/./tools/test_submission.py", line 30, in main
    net4 = get_net(ckpt4)
  File "/scratch/manu/MSCG-Net-master_selftrained/tools/ckpt.py", line 56, in get_net
    net.load_state_dict(torch.load(ckpt['snapshot']))
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1497, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for rx50_gcn_3head_4channel:
    Missing key(s) in state_dict: "layer0.0.weight", "layer0.1.weight", "layer0.1.bias", "layer0.1.running_mean", "layer0.1.running_var", "conv0.weight".
    Unexpected key(s) in state_dict: "layer0.conv1.weight", "layer0.bn1.weight", "layer0.bn1.bias", "layer0.bn1.running_mean", "layer0.bn1.running_var", "layer0.bn1.num_batches_tracked".

czarmanu commented 2 years ago

In test_submission.py, I updated the line:

test_files = get_real_test_list(bands=['NIR','RGB']) to test_files = get_real_test_list(bands=['RGB'])

The error still persists. Are any more updates needed?

czarmanu commented 2 years ago

I updated the following line in configs_kf.py:

bands = ['NIR', 'RGB']

as

bands = ['RGB']

czarmanu commented 2 years ago

Also, I updated the following in the ckpt.py file: ckpt4 = { 'net': 'MSCG-Rx50_RGB', 'data': 'Agriculture', 'bands': ['RGB'], 'nodes': (32,32), 'snapshot': 'ckpt/epoch_6_loss_0.50475_acc_0.93916_acc-cls_0.95160_mean-iu_0.87998_fwavacc_0.88707_f1_0.93597_lr_0.0000845121.pth' }