Tramac / awesome-semantic-segmentation-pytorch

Semantic Segmentation on PyTorch (including FCN, PSPNet, Deeplabv3, Deeplabv3+, DANet, DenseASPP, BiSeNet, EncNet, DUNet, ICNet, ENet, OCNet, CCNet, PSANet, CGNet, ESPNet, LEDNet, DFANet)
Apache License 2.0
2.85k stars, 582 forks

deeplabv3_plus-error:norm_kwargs #91

Closed by yangninghua 4 years ago

yangninghua commented 4 years ago
Traceback (most recent call last):
  File "/home/spple/paddle/DeepGlint/deepglint-adv/semantic-segmentation/scripts/train.py", line 352, in <module>
    trainer = Trainer(args)
  File "/home/spple/paddle/DeepGlint/deepglint-adv/semantic-segmentation/scripts/train.py", line 171, in __init__
    aux=args.aux, norm_layer=BatchNorm2d).to(self.device)
  File "/home/spple/paddle/DeepGlint/deepglint-adv/semantic-segmentation/core/models/model_zoo.py", line 122, in get_segmentation_model
    return models[model](**kwargs)
  File "/home/spple/paddle/DeepGlint/deepglint-adv/semantic-segmentation/core/models/deeplabv3_plus.py", line 127, in get_deeplabv3_plus
    model = DeepLabV3Plus(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
  File "/home/spple/paddle/DeepGlint/deepglint-adv/semantic-segmentation/core/models/deeplabv3_plus.py", line 41, in __init__
    self.head = _DeepLabHead(nclass, **kwargs)
  File "/home/spple/paddle/DeepGlint/deepglint-adv/semantic-segmentation/core/models/deeplabv3_plus.py", line 100, in __init__
    self.aspp = _ASPP(2048, [12, 24, 36], norm_layer=norm_layer, **kwargs)
TypeError: __init__() missing 1 required positional argument: 'norm_kwargs'
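The traceback comes down to ordinary Python argument binding: `_ASPP.__init__` declares `norm_kwargs` with no default, and the call at deeplabv3_plus.py line 100 only passes `norm_layer` plus whatever is in `**kwargs`. A minimal sketch of that mechanism, assuming hypothetical stand-in classes (`ASPPStub`, `DeepLabHeadStub`, `build_head_error` are illustration only, not the repo's code):

```python
# Hypothetical stand-ins: ASPPStub and DeepLabHeadStub only mimic the
# argument passing shown in the traceback; they are not the repo's classes.


class ASPPStub:
    # Same signature shape as _ASPP: norm_kwargs has no default value.
    def __init__(self, in_channels, atrous_rates, norm_layer, norm_kwargs, **kwargs):
        self.norm_kwargs = norm_kwargs


class DeepLabHeadStub:
    def __init__(self, nclass, norm_layer=None, **kwargs):
        # Mirrors the call at deeplabv3_plus.py line 100: norm_kwargs is never
        # named here, so it only reaches ASPPStub if it is inside **kwargs.
        self.aspp = ASPPStub(2048, [12, 24, 36], norm_layer=norm_layer, **kwargs)


def build_head_error():
    """Return the TypeError message raised while building the head, or None."""
    try:
        DeepLabHeadStub(nclass=21)
    except TypeError as err:
        return str(err)
    return None


print(build_head_error())
```

The printed message ends with `missing 1 required positional argument: 'norm_kwargs'`; the prefix (`__init__()` vs `ASPPStub.__init__()`) depends on the Python version. Passing `norm_kwargs={}` to the head happens to work only because it then travels through `**kwargs`.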
yangninghua commented 4 years ago
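import torch.nn as nn

class _ASPP(nn.Module):
    # norm_kwargs has no default value, so any caller that omits it raises the TypeError above
    def __init__(self, in_channels, atrous_rates, norm_layer, norm_kwargs, **kwargs):
        super(_ASPP, self).__init__()
        out_channels = 256
        self.b0 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            # The body already tolerates norm_kwargs=None by expanding an empty dict
            norm_layer(out_channels, **({} if norm_kwargs is None else norm_kwargs)),
            nn.ReLU(True)
        )
        # ... remaining ASPP branches not shown in this excerpt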
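Note that the body of `_ASPP` already handles `norm_kwargs=None`: the expression `**({} if norm_kwargs is None else norm_kwargs)` turns `None` into "no extra keyword arguments". Only the signature forces callers to supply the value. A small sketch of that expression in isolation, assuming a hypothetical `RecordingNorm` stand-in (its `eps`/`momentum` defaults copied from `nn.BatchNorm2d`) and a hypothetical `make_norm` helper, so it runs without PyTorch:

```python
# Hypothetical helper and stand-in layer, used only to show how the
# expression from the _ASPP snippet treats norm_kwargs=None.


class RecordingNorm:
    # Stand-in for a layer such as nn.BatchNorm2d; it just stores its arguments.
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum


def make_norm(norm_layer, num_features, norm_kwargs=None):
    # Same expansion as in _ASPP: None becomes an empty dict, so the layer
    # is built with its own default keyword arguments.
    return norm_layer(num_features, **({} if norm_kwargs is None else norm_kwargs))


default = make_norm(RecordingNorm, 256)
custom = make_norm(RecordingNorm, 256, {"momentum": 0.01})
print(default.momentum, custom.momentum)  # 0.1 0.01
```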
yangninghua commented 4 years ago

def __init__(self, in_channels, atrous_rates, norm_layer, norm_kwargs=None, **kwargs):
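With `norm_kwargs=None` as the default, the unchanged call from the traceback binds successfully, and the `None` check already in the body takes care of the rest. A sketch of the effect, again assuming hypothetical stand-ins (`ASPPFixed`, `DeepLabHeadStub`) rather than the repo's real classes:

```python
# Hypothetical stand-ins; ASPPFixed applies the one-line change from the
# thread (norm_kwargs=None), everything else is illustration only.


class ASPPFixed:
    def __init__(self, in_channels, atrous_rates, norm_layer, norm_kwargs=None, **kwargs):
        # None is normalised to an empty dict, as the _ASPP body does inline.
        self.norm_kwargs = {} if norm_kwargs is None else norm_kwargs


class DeepLabHeadStub:
    def __init__(self, nclass, norm_layer=None, **kwargs):
        # Same call shape as the failing line: norm_kwargs is still not named.
        self.aspp = ASPPFixed(2048, [12, 24, 36], norm_layer=norm_layer, **kwargs)


head = DeepLabHeadStub(nclass=21)
print(head.aspp.norm_kwargs)  # {}

tuned = DeepLabHeadStub(nclass=21, norm_kwargs={"momentum": 0.01})
print(tuned.aspp.norm_kwargs)  # {'momentum': 0.01}
```

The default keeps every existing call site working. The other option would be to give the deeplabv3_plus `_DeepLabHead` its own `norm_kwargs=None` parameter and forward it explicitly to `_ASPP`; which one fits better depends on whether callers are expected to tune the norm layer per model.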