boschresearch / ALDM

Official implementation of "Adversarial Supervision Makes Layout-to-Image Diffusion Models Thrive" (ICLR 2024)
https://yumengli007.github.io/ALDM/
GNU Affero General Public License v3.0
52 stars 3 forks source link

got an unexpected keyword argument 'is_inference' #11

Closed MalekSamet closed 5 months ago

MalekSamet commented 5 months ago

pic_1

When I remove is_inference from:

class NewSegmentationModule(SegmentationModule):
    """Segmentation module whose forward pass hands ``is_inference`` to the decoder."""

    def forward(self, image, label, is_inference=None):
        """Run the encoder and then the decoder, sizing the output from ``label``.

        Args:
            image: Input passed to ``self.encoder``.
            label: Object exposing ``shape``; only its last two entries are
                read, and they are used as the decoder's ``segSize``.
            is_inference: Forwarded unchanged to ``self.decoder``.

        Returns:
            Whatever ``self.decoder`` returns.
        """
        target_size = (label.shape[-2], label.shape[-1])
        feature_maps = self.encoder(image, return_feature_maps=True)
        return self.decoder(
            feature_maps, segSize=target_size, is_inference=is_inference
        )

The error is gone but I get another error:

pic_2

YumengLi007 commented 5 months ago

Hi @MalekSamet, please replace the UPerNet class in mit_semseg/models/models.py with the following:

class UPerNet(nn.Module):
    """Decoder head combining a PPM module on the top feature map with an FPN module.

    ``forward`` accepts an extra ``is_inference`` flag that selects the output
    form (see ``forward``); leaving it as ``None`` keeps the behaviour governed
    by ``use_softmax``.
    """

    def __init__(self, num_class=150, fc_dim=4096,
                 use_softmax=False, pool_scales=(1, 2, 3, 6),
                 fpn_inplanes=(256, 512, 1024, 2048), fpn_dim=256):
        """Build the PPM, FPN and final classification layers.

        Args:
            num_class: Number of output channels of the last 1x1 convolution.
            fc_dim: Channel count of the top-level feature map fed to the PPM.
            use_softmax: Consulted by ``forward`` only when ``is_inference``
                is ``None``.
            pool_scales: Output sizes of the adaptive average pools in the PPM.
            fpn_inplanes: Channel counts of the feature maps; the last entry
                (top layer) gets no lateral FPN convolution.
            fpn_dim: Channel count used throughout the FPN.
        """
        super().__init__()
        self.use_softmax = use_softmax
        ppm_channels = 512  # width of every PPM branch

        # PPM Module
        self.ppm_pooling = nn.ModuleList(
            [nn.AdaptiveAvgPool2d(scale) for scale in pool_scales]
        )
        self.ppm_conv = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(fc_dim, ppm_channels, kernel_size=1, bias=False),
                BatchNorm2d(ppm_channels),
                nn.ReLU(inplace=True),
            )
            for _ in pool_scales
        ])
        self.ppm_last_conv = conv3x3_bn_relu(
            fc_dim + len(pool_scales) * ppm_channels, fpn_dim, 1
        )

        # FPN Module (the top layer is skipped in both lists)
        self.fpn_in = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(in_planes, fpn_dim, kernel_size=1, bias=False),
                BatchNorm2d(fpn_dim),
                nn.ReLU(inplace=True),
            )
            for in_planes in fpn_inplanes[:-1]
        ])
        self.fpn_out = nn.ModuleList([
            nn.Sequential(conv3x3_bn_relu(fpn_dim, fpn_dim, 1))
            for _ in fpn_inplanes[:-1]
        ])

        self.conv_last = nn.Sequential(
            conv3x3_bn_relu(len(fpn_inplanes) * fpn_dim, fpn_dim, 1),
            nn.Conv2d(fpn_dim, num_class, kernel_size=1),
        )

    def forward(self, conv_out, segSize=None, is_inference=None):
        """Fuse the feature maps and produce per-class scores.

        Args:
            conv_out: Indexable collection of feature maps; the last entry is
                the top-level map fed to the PPM, and entry ``i`` before it is
                passed through ``self.fpn_in[i]``.
            segSize: Target spatial size for the final bilinear interpolation.
                Not used on the ``log_softmax`` path.
            is_inference: ``None`` defers to ``self.use_softmax``; a truthy
                value returns upsampled softmax probabilities; any other
                non-``None`` value returns the upsampled ``conv_last`` output
                with no softmax applied.

        Returns:
            Tensor of class scores in the form selected above. With
            ``is_inference=None`` and ``use_softmax`` false, the result is
            ``log_softmax`` of the un-interpolated ``conv_last`` output.
        """
        def resize(tensor, size):
            # Every resampling step in this head is bilinear, align_corners=False.
            return nn.functional.interpolate(
                tensor, size=size, mode='bilinear', align_corners=False)

        top = conv_out[-1]
        top_hw = (top.size(2), top.size(3))

        # PPM: pool the top map at each scale, resize back, concat with the input.
        pyramid = [top] + [
            branch_conv(resize(branch_pool(top), top_hw))
            for branch_pool, branch_conv in zip(self.ppm_pooling, self.ppm_conv)
        ]
        f = self.ppm_last_conv(torch.cat(pyramid, 1))

        # FPN: walk from the second-highest level down to level 0.
        fpn_feature_list = [f]
        for level in range(len(conv_out) - 2, -1, -1):
            lateral = self.fpn_in[level](conv_out[level])  # lateral branch
            f = lateral + resize(f, lateral.size()[2:])    # top-down branch
            fpn_feature_list.append(self.fpn_out[level](f))
        fpn_feature_list.reverse()  # [P2 - P5]

        # Bring every level to the finest resolution and fuse.
        finest_hw = fpn_feature_list[0].size()[2:]
        fused = torch.cat(
            [fpn_feature_list[0]]
            + [resize(feat, finest_hw) for feat in fpn_feature_list[1:]],
            1,
        )
        scores = self.conv_last(fused)

        if is_inference is None:
            if self.use_softmax:  # is True during inference
                return nn.functional.softmax(resize(scores, segSize), dim=1)
            return nn.functional.log_softmax(scores, dim=1)

        # Explicit flag: always upsample; softmax only when it is truthy.
        scores = resize(scores, segSize)
        if is_inference:
            return nn.functional.softmax(scores, dim=1)
        return scores
MalekSamet commented 5 months ago

Still same issue

YumengLi007 commented 5 months ago

Hmm, that's weird. Please make sure you modify the copy installed in the conda environment you are using to run the code.