LiheYoung / UniMatch

[CVPR 2023] Revisiting Weak-to-Strong Consistency in Semi-Supervised Semantic Segmentation
https://arxiv.org/abs/2208.09910
MIT License

ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 256, 1, 1]) #8

Closed: StevenLu1204 closed this issue 2 years ago

StevenLu1204 commented 2 years ago

I was training on Cityscapes with the same config, but with only 2 GPUs, and got this error. It seems to suggest that I cannot train the model with a batch size of 1 per GPU because of the BatchNorm layer. However, your Cityscapes training log shows that you trained successfully with batch size 1. How can this be fixed? Thank you!
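A minimal reproduction of the error, assuming only that PyTorch is installed: a standalone nn.BatchNorm2d(256) fed a tensor with the shape from the error message fails in training mode but not in evaluation mode. This is an illustrative sketch and does not use UniMatch's own model code; try_bn is a hypothetical helper.

```python
import torch
import torch.nn as nn

# Shape taken from the error message: 1 image per GPU, 256 channels,
# 1x1 spatial size (what global average pooling produces).
pooled = torch.randn(1, 256, 1, 1)
bn = nn.BatchNorm2d(256)


def try_bn(module, x):
    """Hypothetical helper: run module on x, return None or the error text."""
    try:
        module(x)
        return None
    except ValueError as err:
        return str(err)


# Training mode: BN must compute batch statistics from one value per channel.
bn.train()
train_error = try_bn(bn, pooled)
print(train_error)

# Evaluation mode uses the running statistics, so the same input is accepted.
bn.eval()
eval_error = try_bn(bn, pooled)
print(eval_error)  # None
```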

LiheYoung commented 2 years ago

Did you use the supervised or FixMatch algorithm instead of our UniMatch? UniMatch should not encounter this problem, because each unlabeled sample yields two strongly augmented images that are concatenated into one batch, so the pooled features always contain more than one value per channel. If you used one of the former, you may be able to solve this problem by updating PyTorch to 1.8.1.
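A sketch of why the concatenation avoids the error, under stated assumptions: img_u_s1 and img_u_s2 are illustrative names for the two strong views, and the small Sequential below is a hypothetical stand-in for the ASPP pooling branch (GAP, 1x1 conv, BN, ReLU) acting on a 3-channel input rather than on real backbone features.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative stand-ins for the two strongly augmented views of ONE unlabeled
# image on one GPU (hypothetical names and sizes, not UniMatch's data pipeline).
img_u_s1 = torch.randn(1, 3, 64, 64)
img_u_s2 = torch.randn(1, 3, 64, 64)

# Stand-in for the image-pooling branch in its original order:
# global average pooling -> 1x1 conv -> BN -> ReLU (BN sees a 1x1 map).
pool_branch = nn.Sequential(
    nn.AdaptiveAvgPool2d(1),
    nn.Conv2d(3, 256, 1, bias=False),
    nn.BatchNorm2d(256),
    nn.ReLU(inplace=True),
).train()

# One view alone: a single value per channel reaches BN, which raises.
try:
    pool_branch(img_u_s1)
    single_view_ok = True
except ValueError:
    single_view_ok = False

# Both views concatenated along the batch dimension: two values per channel.
both_views = torch.cat((img_u_s1, img_u_s2))
out = pool_branch(both_views)

print(single_view_ok)  # False
print(out.shape)       # torch.Size([2, 256, 1, 1])
```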

StevenLu1204 commented 2 years ago

My PyTorch version is already 1.8.1+cu111, and the problem remains. Would you mind sharing your training log for fixmatch.py?

LiheYoung commented 2 years ago

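You can change the ASPPPooling in https://github.com/LiheYoung/UniMatch/blob/7292bace1bda8f64ca05ae42ca8c0eef045aa392/model/semseg/deeplabv3plus.py#L80-L91 to the following version, which applies BatchNorm2d after the upsampling instead of directly on the 1x1 pooled map: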

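import torch.nn as nn
import torch.nn.functional as F


class ASPPPooling(nn.Module):
    """ASPP image-pooling branch with BatchNorm applied after upsampling."""

    def __init__(self, in_channels, out_channels):
        super(ASPPPooling, self).__init__()
        # Global average pooling + 1x1 conv, with no BN here: the pooled map is 1x1,
        # so with one image per GPU BN would see a single value per channel.
        self.gap = nn.Sequential(nn.AdaptiveAvgPool2d(1),
                                 nn.Conv2d(in_channels, out_channels, 1, bias=False))
        # BN + ReLU are applied to the upsampled (h, w) map instead.
        self.post = nn.Sequential(nn.BatchNorm2d(out_channels),
                                  nn.ReLU(True))

    def forward(self, x):
        h, w = x.shape[-2:]
        pool = self.gap(x)
        # Upsample first, then normalize, so BN sees h * w values per channel.
        return self.post(F.interpolate(pool, (h, w), mode="bilinear", align_corners=True))

StevenLu1204 commented 2 years ago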

Thanks for the update, it works fine now. So the problem comes from the order of the interpolation and BatchNorm2d?

LiheYoung commented 2 years ago

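It is because BN, in training mode, estimates each channel's variance over the batch and spatial dimensions. After global average pooling the spatial size is 1x1, so with one image per GPU there is only a single value per channel and the variance cannot be calculated. Applying BN after the interpolation gives it h x w values per channel.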
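The counting behind this answer can be sketched as follows, assuming PyTorch's training-mode check is on N x H x W values per channel (consistent with the shape reported in the error). values_per_channel is a hypothetical helper for illustration. The second half shows a side effect of the patched order that follows from the same arithmetic: upsampling a 1x1 map produces a constant map, so with one image per GPU the BN batch variance is zero and the normalized output collapses to the BN bias.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def values_per_channel(shape):
    """Hypothetical helper: elements BN averages over per channel, N * H * W."""
    n, _, *spatial = shape
    return n * math.prod(spatial)


# Original order, one image per GPU: BN sees the 1x1 pooled map.
print(values_per_channel((1, 256, 1, 1)))    # 1 -> BN raises in training mode
# Patched order: BN sees the map upsampled back to, e.g., 64x64.
print(values_per_channel((1, 256, 64, 64)))  # 4096

pooled = torch.randn(1, 256, 1, 1)
upsampled = F.interpolate(pooled, (64, 64), mode="bilinear", align_corners=True)

bn = nn.BatchNorm2d(256).train()
out = bn(upsampled)  # no ValueError: 4096 values per channel

# Caveat: the upsampled map is constant per channel, so with a single image the
# batch variance is 0 and BN outputs its bias everywhere (0 at initialization).
print(torch.allclose(upsampled, pooled.expand_as(upsampled)))  # True
print(out.abs().max().item())  # ~0.0, up to floating-point rounding
```

So the patch removes the error, but with a per-GPU batch of 1 this branch carries no per-image information during training; with more than one image per GPU the batch statistics vary across images again.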