Closed czarmanu closed 2 years ago
You can build a 3-channel model (e.g., MSCG-Rx50_RGB) with the following code (modify /tools/model.py):
def load_model(name='MSCG-Rx50', classes=7, node_size=(32, 32)):
    """Build the network registered under ``name``.

    Args:
        name: Model identifier. Recognised values are ``'MSCG-Rx50'``,
            ``'MSCG-Rx101'`` and ``'MSCG-Rx50_RGB'``.
        classes: Number of output channels passed to the network constructor.
        node_size: Accepted for interface compatibility; it is not forwarded
            to any constructor in this function.

    Returns:
        The constructed network, or ``-1`` (after printing a message) when
        ``name`` is not recognised.
    """
    if name == 'MSCG-Rx50':
        net = rx50_gcn_3head_4channel(out_channels=classes)
    elif name == 'MSCG-Rx101':
        net = rx101_gcn_3head_4channel(out_channels=classes)
    elif name == 'MSCG-Rx50_RGB':  # closing quote was missing in the original snippet
        net = rx50_gcn_3head_3channel(out_channels=classes)  # only take 3-channel as input
    else:
        # Kept as a sentinel return (not an exception) so existing callers
        # that check for -1 keep working.
        print('not found the net')
        return -1
    return net
class rx50_gcn_3head_3channel(nn.Module):
    """Three-view SCG/GCN segmentation head on an SE-ResNeXt50 backbone.

    This variant uses the backbone's ``layer0`` unchanged, so it is meant for
    3-channel input (per the accompanying discussion).

    Args:
        out_channels: Number of output channels; also used as the hidden
            channel count of the SCG block and stored as ``num_cluster``.
        pretrained: Accepted for interface compatibility; not used in this
            constructor.
        nodes: Node grid size passed to the SCG block and used to reshape the
            graph output back to a 2-D map.
        dropout: Dropout value passed to the first GCN layer and the SCG block.
        enhance_diag: Passed to the SCG block as ``add_diag``.
        aux_pred: When true, the SCG block's ``z_hat`` output is added to each
            view's graph output.
    """

    def __init__(self, out_channels=7, pretrained=True,
                 nodes=(32, 32), dropout=0,
                 enhance_diag=True, aux_pred=True):
        # The original snippet named rx50_gcn_3head_4channel here, which fails
        # because ``self`` is not an instance of that class.
        super(rx50_gcn_3head_3channel, self).__init__()  # same with res_fdcs_v5

        self.aux_pred = aux_pred
        self.node_size = nodes
        self.num_cluster = out_channels

        resnet = se_resnext50_32x4d()
        # Only the first four backbone stages are kept.
        self.layer0, self.layer1, self.layer2, self.layer3 = \
            resnet.layer0, resnet.layer1, resnet.layer2, resnet.layer3

        self.graph_layers1 = GCN_Layer(1024, 128, bnorm=True, activation=nn.ReLU(True), dropout=dropout)
        self.graph_layers2 = GCN_Layer(128, out_channels, bnorm=False, activation=None)

        self.scg = SCG_block(in_ch=1024,
                             hidden_ch=out_channels,
                             node_size=nodes,
                             add_diag=enhance_diag,
                             dropout=dropout)

        weight_xavier_init(self.graph_layers1, self.graph_layers2, self.scg)

    def forward(self, x):
        """Run the backbone and three graph views, then upsample to input size.

        Returns:
            In training mode, a tuple ``(output, loss + loss2 + loss3)`` where
            the losses come from the three SCG block calls. Otherwise only
            the upsampled output.
        """
        x_size = x.size()

        gx = self.layer3(self.layer2(self.layer1(self.layer0(x))))
        # Two extra views of the same features: last two dims swapped, and
        # flipped along the last dim.
        gx90 = gx.permute(0, 1, 3, 2)
        gx180 = gx.flip(3)

        B, C, H, W = gx.size()

        # View 1: original orientation.
        A, gx, loss, z_hat = self.scg(gx)
        gx, _ = self.graph_layers2(
            self.graph_layers1((gx.reshape(B, -1, C), A)))  # + gx.reshape(B, -1, C)
        if self.aux_pred:
            gx += z_hat
        gx = gx.reshape(B, self.num_cluster, self.node_size[0], self.node_size[1])

        # View 2: swapped spatial dims; swapped back before accumulation.
        A, gx90, loss2, z_hat = self.scg(gx90)
        gx90, _ = self.graph_layers2(
            self.graph_layers1((gx90.reshape(B, -1, C), A)))  # + gx.reshape(B, -1, C)
        if self.aux_pred:
            gx90 += z_hat
        gx90 = gx90.reshape(B, self.num_cluster, self.node_size[1], self.node_size[0])
        gx90 = gx90.permute(0, 1, 3, 2)
        gx += gx90

        # View 3: flipped along the last dim; flipped back before accumulation.
        A, gx180, loss3, z_hat = self.scg(gx180)
        gx180, _ = self.graph_layers2(
            self.graph_layers1((gx180.reshape(B, -1, C), A)))  # + gx.reshape(B, -1, C)
        if self.aux_pred:
            gx180 += z_hat
        gx180 = gx180.reshape(B, self.num_cluster, self.node_size[0], self.node_size[1])
        gx180 = gx180.flip(3)
        gx += gx180

        # First resize to the backbone feature size, then to the input size.
        gx = F.interpolate(gx, (H, W), mode='bilinear', align_corners=False)

        if self.training:
            return F.interpolate(gx, x_size[2:], mode='bilinear', align_corners=False), loss + loss2 + loss3
        else:
            return F.interpolate(gx, x_size[2:], mode='bilinear', align_corners=False)
Then change train_args in train_R50.py as follows:
# Training configuration for the RGB-only model: selects the network by the
# name 'MSCG-Rx50_RGB' and requests only the 'RGB' band.
# NOTE(review): agriculture_configs is defined elsewhere; the meaning of
# kf / k_folder / note is not visible here — confirm against its definition.
train_args = agriculture_configs(net_name='MSCG-Rx50_RGB',
data='Agriculture',
bands_list=['RGB'],
kf=0, k_folder=0,
note='only_use_RGB'
)
Perfect! Thanks
While testing:
Traceback (most recent call last):
File "/scratch/manu/MSCG-Net-master_selftrained/./tools/test_submission.py", line 241, in
In test_submission.py, I updated the line:
test_files = get_real_test_list(bands=['NIR','RGB']) to test_files = get_real_test_list(bands=['RGB'])
The error still persists. Are any more updates needed?
I updated the following line in configs_kf.py:
bands = ['NIR', 'RGB']
as
bands = ['RGB']
I also updated the following in the ckpt.py file: ckpt4 = { 'net': 'MSCG-Rx50_RGB', 'data': 'Agriculture', 'bands': ['RGB'], 'nodes': (32,32), 'snapshot': 'ckpt/epoch_6_loss_0.50475_acc_0.93916_acc-cls_0.95160_mean-iu_0.87998_fwavacc_0.88707_f1_0.93597_lr_0.0000845121.pth' }
Could you please advise how to deactivate the NIR channel and use just the RGB channels as input?
Thanks!