JaveyWang / Pyramid-Attention-Networks-pytorch

Implementation of Pyramid Attention Networks for Semantic Segmentation.
GNU General Public License v3.0
235 stars 55 forks

FPA tensor size mismatch #10

Open sparkfax opened 4 years ago

sparkfax commented 4 years ago

I ran the model on another dataset with image size (3, 320, 640) and got this error:

    in FPA forward(self, x)
        131 print(x2_2.shape)
        132 print(x3_upsample.shape)
    --> 133 x2_merge = self.relu(x2_2 + x3_upsample)
        134 x2_upsample = self.relu(self.bn_upsample_2(self.conv_upsample_2(x2_merge)))
        135 x1_merge = self.relu(x1_2 + x2_upsample)

    RuntimeError: The size of tensor a (5) must match the size of tensor b (6) at non-singleton dimension 2

print (FPA channels=2048):

    # Branch 1
    x1_1 = self.conv7x7_1(x)
    x1_1 = self.bn1_1(x1_1)
    x1_1 = self.relu(x1_1)  #([16, 512, 10, 20])
    x1_2 = self.conv7x7_2(x1_1)
    x1_2 = self.bn1_2(x1_2)  #([16, 512, 10, 20])

    # Branch 2
    x2_1 = self.conv5x5_1(x1_1)
    x2_1 = self.bn2_1(x2_1)
    x2_1 = self.relu(x2_1)  #([16, 512, 5, 10])
    x2_2 = self.conv5x5_2(x2_1)
    x2_2 = self.bn2_2(x2_2)  **#([16, 512, 5, 10])**

    # Branch 3
    x3_1 = self.conv3x3_1(x2_1)
    x3_1 = self.bn3_1(x3_1)
    x3_1 = self.relu(x3_1)  #([16, 512, 3, 5])
    x3_2 = self.conv3x3_2(x3_1)
    x3_2 = self.bn3_2(x3_2)  #([16, 512, 3, 5])

    # Merge branch 1 and 2, x3_upsample: #**([16, 512, 6, 10])**
    x3_upsample = self.relu(self.bn_upsample_3(self.conv_upsample_3(x3_2)))
    x2_merge = self.relu(x2_2 + x3_upsample)
    x2_upsample = self.relu(self.bn_upsample_2(self.conv_upsample_2(x2_merge)))
    x1_merge = self.relu(x1_2 + x2_upsample)

The model is built like this:

    class Attention_Model(nn.Module):
        def __init__(self, in_features=256, num_class=4):
            super(Attention_Model, self).__init__()
            self.convnet = ResNet50(pretrained=True)
            self.pan = PAN(self.convnet.blocks[::-1])
            self.mask_classifier = Mask_Classifier(in_features=256, num_class=(num_class+1))

        def forward(self, imgs):
            fms_blob, z = self.convnet(imgs)
            out_ss = self.pan(fms_blob[::-1])
            mask_pred = self.mask_classifier(out_ss)
            return mask_pred

PS: If the images are resized to (3, 512, 512), the error does not appear.
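The shapes reported above are consistent with an odd feature-map height being halved and then doubled: 5 goes down to 3 and comes back up as 6, not 5. The sketch below reproduces that arithmetic in plain Python. The kernel, stride and padding values are assumptions chosen because they reproduce the printed shapes (a 5x5 and a 3x3 convolution with stride 2, and a transposed convolution with kernel 4, stride 2, padding 1); they are not confirmed from the repository code. The helper names `conv_out`, `deconv_out`, `fpa_sizes` and `next_safe_size` are hypothetical.

```python
def conv_out(size, kernel, stride, padding):
    """Output length of a strided convolution along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1


def deconv_out(size, kernel, stride, padding):
    """Output length of a transposed convolution (no output_padding)."""
    return (size - 1) * stride - 2 * padding + kernel


def fpa_sizes(x1_size):
    """Return (x2, x3, x3_upsampled) lengths along one axis for a given x1 length.

    All layer hyperparameters here are assumptions that match the shapes
    printed in the issue, not values read from the repository.
    """
    x2 = conv_out(x1_size, kernel=5, stride=2, padding=2)    # assumed conv5x5_1
    x3 = conv_out(x2, kernel=3, stride=2, padding=1)         # assumed conv3x3_1
    x3_up = deconv_out(x3, kernel=4, stride=2, padding=1)    # assumed conv_upsample_3
    return x2, x3, x3_up


def next_safe_size(x1_size, multiple=4):
    """Round an x1 length up to the next multiple that survives two halvings."""
    return -(-x1_size // multiple) * multiple


# x1 lengths: 10 and 20 come from the issue (320 x 640 input);
# 16 assumes the same 32x reduction applied to a 512 px side.
for label, x1 in [("height, 320 px", 10), ("width, 640 px", 20), ("side, 512 px", 16)]:
    x2, x3, x3_up = fpa_sizes(x1)
    status = "ok" if x2 == x3_up else "MISMATCH"
    print(f"{label}: x1={x1} x2={x2} x3={x3} x3_upsample={x3_up} -> {status}")

print("next safe x1 height for 10:", next_safe_size(10))
```

Under these assumptions, the additions only line up when the x1 feature map has sides divisible by 4, which is why 512 x 512 works and a height of 320 does not. Two possible workarounds are padding or resizing the input so that the x1 height becomes 12 (384 px if the 32x reduction holds), or resizing `x3_upsample` to the spatial size of `x2_2` before the addition, for example with `torch.nn.functional.interpolate`.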