NVIDIA / semantic-segmentation

Nvidia Semantic Segmentation monorepo
BSD 3-Clause "New" or "Revised" License

num_classes error with wider_resnet38 #72

Open kaushikb258 opened 4 years ago

kaushikb258 commented 4 years ago

I am trying to use the wider_resnet trunk by setting the following in scripts/train_cityscapes.yml:

    snapshot: "ASSETS_PATH/seg_weights/wider_resnet38.pth.tar",
    arch: wider_resnet.wider_resnet38_a2,

But I get the following error:

    File "/home/kb/semantic-segmentation/network/__init__.py", line 53, in get_model
      net = net_func(num_classes=num_classes, criterion=criterion)
    TypeError: __init__() got an unexpected keyword argument 'num_classes'
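The traceback shows get_model calling the resolved arch as net_func(num_classes=..., criterion=...), so any arch whose constructor lacks those parameters fails in exactly this way. The sketch below reproduces the mechanism with hypothetical stand-ins: TrunkA2 plays a backbone whose constructor takes structure and classes, and seg_model_w38 plays a segmentation factory with the (num_classes, criterion) signature. Neither is the repo's actual code, and the structure list is a placeholder.

```python
from functools import partial


class TrunkA2:
    """Hypothetical backbone: its constructor has no num_classes parameter."""

    def __init__(self, structure, classes=0):
        self.structure = structure
        self.classes = classes


class SegModel:
    """Hypothetical segmentation network: a trunk plus a per-pixel classifier."""

    def __init__(self, num_classes, criterion, trunk):
        self.num_classes = num_classes
        self.criterion = criterion
        self.trunk = trunk


# Hypothetical arch entries. The trunk entry is a partial with its structure
# pre-bound (placeholder values); the segmentation entry is a factory with the
# keyword signature that the failing call expects.
trunk_w38 = partial(TrunkA2, structure=[3, 3, 6, 3, 1, 1])


def seg_model_w38(num_classes, criterion):
    return SegModel(num_classes, criterion, trunk=trunk_w38())


def build(net_func, num_classes=19, criterion=None):
    # Same call shape as the traceback line; 19 is just an example class count.
    return net_func(num_classes=num_classes, criterion=criterion)


if __name__ == "__main__":
    try:
        build(trunk_w38)
    except TypeError as exc:
        print("trunk arch failed:", exc)
    model = build(seg_model_w38)
    print("segmentation arch built with", model.num_classes, "classes")
```

The keyword mismatch is only detected at call time, which is why the failure surfaces inside get_model rather than when the yml is parsed.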

ajtao commented 4 years ago

The problem is that the arch you picked is a trunk, not a segmentation architecture. Please try arch: deepv3.DeepV3PlusW38, which is a DeepLabV3+ architecture that uses wider_resnet38 as its trunk. Another option is mscale.DeepV3W38, which adds multi-scale attention.
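Arch names in this thread take the form module.ClassName, e.g. mscale.DeepV3W38. A common way to resolve such a string is to split it at the last dot, import the module, and fetch the named attribute. The sketch below assumes get_model works roughly like this with a network. package prefix prepended; that is an assumption, not something this thread confirms. It is demonstrated on a standard-library name so it runs anywhere.

```python
import importlib


def split_arch(dotted):
    """Split 'pkg.module.Name' into ('pkg.module', 'Name') at the last dot."""
    module_path, _, attr = dotted.rpartition(".")
    if not module_path:
        raise ValueError(f"expected 'module.Name', got {dotted!r}")
    return module_path, attr


def resolve_arch(arch, package="network"):
    """Import the module named in `arch` and return the callable it names.

    `package` is assumed to be the package holding the model modules;
    pass package=None to resolve an absolute dotted path instead.
    """
    dotted = f"{package}.{arch}" if package else arch
    module_path, attr = split_arch(dotted)
    module = importlib.import_module(module_path)
    return getattr(module, attr)


if __name__ == "__main__":
    print(split_arch("network.mscale.DeepV3W38"))
    print(resolve_arch("collections.OrderedDict", package=None))
```

Under this scheme the module prefix selects which file the class is looked up in, so a trunk name and a segmentation name can sit side by side in different modules.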

kaushikb258 commented 4 years ago

OK, now I have this in scripts/train_cityscapes.yml:

    snapshot: "ASSETS_PATH/seg_weights/wider_resnet38.pth.tar",
    arch: mscale.DeepV3W38,

But I get this error:

        raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size))
    ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 256, 1, 1])

The error originates from:

    File "/home/kb/semantic-segmentation/network/mscale.py", line 299, in _fwd
      aspp = self.aspp(final_features)
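For context on the message itself: in training mode, PyTorch's batch norm rejects inputs that have only one value per channel, because batch statistics computed from a single value are degenerate. The count is the batch size times the spatial size, so torch.Size([1, 256, 1, 1]) (one sample, pooled to 1×1) has exactly one. The pure-Python sketch below mirrors that check; it is not PyTorch's code. In an ASPP module a 1×1 map like this is typical of the image-level pooling branch, which is presumably where the failing norm layer sits.

```python
from math import prod


def values_per_channel(shape):
    """Number of elements a batch-norm layer averages over for each channel.

    `shape` is (N, C, *spatial); every dimension except the channel dim counts.
    """
    if len(shape) < 2:
        raise ValueError(f"expected at least (N, C), got {tuple(shape)}")
    return shape[0] * prod(shape[2:])


def check_train_batch_norm_input(shape):
    """Raise, as PyTorch does in training mode, when each channel has one value."""
    if values_per_channel(shape) == 1:
        raise ValueError(
            "Expected more than 1 value per channel when training, "
            f"got input size {tuple(shape)}"
        )


if __name__ == "__main__":
    for shape in [(1, 256, 1, 1), (2, 256, 1, 1), (1, 256, 64, 128)]:
        try:
            check_train_batch_norm_input(shape)
            print(shape, "ok,", values_per_channel(shape), "values per channel")
        except ValueError as exc:
            print(shape, "->", exc)
```

Because a 1×1 branch contributes a single spatial position, the layer needs at least two samples to normalize over in training mode; in eval mode batch norm uses its running statistics and this check does not apply.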

ajtao commented 4 years ago

I see what the problem is: wider_resnet38.pth.tar is not a pretrained model for mscale.DeepV3W38, so that combination of snapshot and arch doesn't work. We don't supply a pretrained semantic-segmentation model for this architecture; this repo provides pretrained models for HRNet-based architectures. You can, however, train your own DeepV3 model using the provided recipes.
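One way to see why a trunk-only checkpoint cannot stand in for a segmentation snapshot is to compare parameter names: the trunk checkpoint covers only backbone layers, so the segmentation-specific layers get no weights from it. The sketch below compares two key sets and reports them the way PyTorch's load_state_dict(strict=False) splits them into missing and unexpected keys. All key names are hypothetical; real checkpoints will use different names and prefixes.

```python
def compare_state_dict_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists, sorted for stable output.

    missing:    parameters the model has but the checkpoint does not provide
    unexpected: checkpoint entries the model has no slot for
    """
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected


# Hypothetical keys of a trunk-only (ImageNet) checkpoint...
TRUNK_CKPT = [
    "trunk.layer1.conv.weight",
    "trunk.layer2.bn.weight",
    "trunk.fc.weight",  # image-classification layer, unused for segmentation
]

# ...and of a segmentation model built on the same trunk.
SEG_MODEL = [
    "trunk.layer1.conv.weight",
    "trunk.layer2.bn.weight",
    "head.aspp.weight",
    "head.decoder.weight",
    "head.classifier.weight",
    "head.scale_attn.weight",
]

if __name__ == "__main__":
    missing, unexpected = compare_state_dict_keys(SEG_MODEL, TRUNK_CKPT)
    print("no weights for:", missing)
    print("unused from checkpoint:", unexpected)
```

Every "head" entry stays at its initial values, which is why such a checkpoint can only serve as a starting point for training, not as a finished segmentation model.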

One final note: wider_resnet38.pth.tar contains ImageNet-pretrained weights for the wider_resnet38 trunk. These weights are loaded automatically whenever a wider_resnet trunk is specified, unless you specify a snapshot, which overrides them.
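The precedence described here (ImageNet trunk weights by default, a snapshot overriding them) can be sketched as a two-step load in which the later step wins. The helper name and the plain-dict "weights" are hypothetical stand-ins for real state dicts; the repo's actual loading code is not shown in this thread.

```python
def initial_weights(trunk_imagenet, snapshot=None):
    """Return the parameter values a model would start training from.

    trunk_imagenet: ImageNet-pretrained trunk parameters, applied first.
    snapshot:       optional full-model parameters; any key it contains
                    overrides the trunk value, and it may add head parameters.
    """
    weights = dict(trunk_imagenet)  # step 1: default trunk initialization
    if snapshot is not None:
        weights.update(snapshot)  # step 2: the snapshot wins on every shared key
    return weights


if __name__ == "__main__":
    imagenet = {"trunk.conv1": "imagenet"}
    seg_snapshot = {"trunk.conv1": "finetuned", "head.final": "finetuned"}
    print(initial_weights(imagenet))
    print(initial_weights(imagenet, seg_snapshot))
```

In this sketch, keys present only in the trunk weights survive the override, while every key the snapshot provides replaces the default.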

kaushikb258 commented 4 years ago

Thanks!