charlesCXK / TorchSemiSeg

[CVPR 2021] CPS: Semi-Supervised Semantic Segmentation with Cross Pseudo Supervision

How can I disable BatchNormalization layer? #30

Closed JihwanEom closed 3 years ago

JihwanEom commented 3 years ago

Hi again, I have another question about the BN layer.

I want to disable the batch normalization layers, so I tried changing the model status from model.train() to model.eval() in train.py. (I also tried replacing the BatchNorm layers with Identity layers, but the same error occurred.)

The training script raised this error:

    ValueError: Expected input batch_size (21) to match target batch_size (2).

How can I disable the BN layer? Thank you!

charlesCXK commented 3 years ago

Hi, this seems to work well:

class MyBN(SyncBatchNorm):    # SyncBatchNorm as already imported in train.py
    def forward(self, x):
        # Bypass normalization: return the input unchanged.
        return x

The class could also inherit from nn.BatchNorm2d. You could use this to replace the BN function in the code, e.g.:

    if engine.distributed:
        BatchNorm2d = MyBN

in train.py.
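For a single-GPU (non-distributed) run, a parallel sketch inheriting from nn.BatchNorm2d might look like this (MyBN2d is a hypothetical name; the engine.distributed branch mirrors the snippet above):

    import torch.nn as nn

    class MyBN2d(nn.BatchNorm2d):
        def forward(self, x):
            # Bypass normalization: return the input unchanged.
            return x

    if engine.distributed:
        BatchNorm2d = MyBN      # SyncBatchNorm-based identity from above
    else:
        BatchNorm2d = MyBN2d    # single-GPU identity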

However, note that if you disable BN in the backbone, the pretrained weights will perform very poorly, since they were trained with BN.

JihwanEom commented 3 years ago

Thank you for your kind reply!

I'm sorry for splitting this into a separate question, but may I ask how to freeze the pretrained BN layers?

charlesCXK commented 3 years ago

First, the BN in the backbone should not be replaced by MyBN; you could pass the correct BN function to self.backbone in network.py. Second, regarding how to freeze the BN in the backbone: this link provides some solutions that may help you.
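For example, one common approach is a small helper like this (a minimal sketch assuming standard PyTorch BN modules; freeze_bn is a hypothetical name, not part of this repo):

    import torch.nn as nn

    def freeze_bn(model):
        # Put every BatchNorm layer into eval mode so the running statistics
        # from the pretrained weights are used instead of batch statistics,
        # and stop gradient updates to the affine parameters.
        for m in model.modules():
            if isinstance(m, nn.modules.batchnorm._BatchNorm):
                m.eval()
                if m.weight is not None:
                    m.weight.requires_grad = False
                if m.bias is not None:
                    m.bias.requires_grad = False

Note that a later call to model.train() switches BN layers back to training mode, so freeze_bn would need to be re-applied after each model.train() call.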

JihwanEom commented 3 years ago

Noted. Thank you!

JihwanEom commented 3 years ago

The link works well, so I'll close this issue. Thank you for your kind help.