Lyken17 / pytorch-OpCounter

Count the MACs / FLOPs of your PyTorch model.
MIT License

different results when counting nn.BatchNorm1d in nn.Sequential #162

Open · razIove opened 2 years ago

razIove commented 2 years ago
```python
import torch
import torch.nn as nn
from thop import profile, clever_format


class test_model1(nn.Module):
    def __init__(self):
        super(test_model1, self).__init__()
        self.block = nn.Sequential(nn.BatchNorm1d(64))

    def forward(self, x):
        x = self.block(x)
        return x


class test_model2(nn.Module):
    def __init__(self):
        super(test_model2, self).__init__()
        # The same BatchNorm1d instance is registered twice:
        # once as self.bn and once as self.block[0].
        self.bn = nn.BatchNorm1d(64)
        self.block = nn.Sequential(
            self.bn,
        )

    def forward(self, x):
        x = self.block(x)
        return x


if __name__ == "__main__":
    # Input of shape (N, C, L) = (1, 64, 1); both models are profiled in eval() mode.
    data = torch.rand(1, 64, 1)

    model = test_model1().eval()
    macs, params = profile(model, inputs=(data,))
    macs, params = clever_format([macs, params], "%.3f")
    print(f"model1: MACs {macs} Params {params}")

    model2 = test_model2().eval()
    macs, params = profile(model2, inputs=(data,))
    macs, params = clever_format([macs, params], "%.3f")
    print(f"model2: MACs {macs} Params {params}")
```
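The two models differ only in how the module is registered: test_model2 stores one BatchNorm1d instance as both `self.bn` and `self.block[0]`. The sketch below uses only standard PyTorch calls (nothing from thop) to show that this instance is reachable by two paths, that `modules()` and `parameters()` de-duplicate it (so the true parameter count is 128), and that `nn.Module.apply()` passes it to its callback twice. Whether thop attaches its hooks through `apply()` is an assumption here, not something confirmed in this thread; `named_modules(remove_duplicate=False)` also requires a reasonably recent PyTorch.

```python
import torch.nn as nn


class test_model2(nn.Module):
    def __init__(self):
        super().__init__()
        self.bn = nn.BatchNorm1d(64)
        self.block = nn.Sequential(self.bn)  # same instance, second path

    def forward(self, x):
        return self.block(x)


model2 = test_model2().eval()

# Every path to each submodule, duplicates kept: bn is reachable twice.
paths = [name for name, _ in model2.named_modules(remove_duplicate=False)]
print(paths)  # ['', 'bn', 'block', 'block.0']

# modules() and parameters() de-duplicate by object identity.
n_bn = sum(isinstance(m, nn.BatchNorm1d) for m in model2.modules())
n_params = sum(p.numel() for p in model2.parameters())
print(n_bn, n_params)  # 1 128

# apply() recurses through each parent's children(), so the shared
# BatchNorm1d is handed to the callback once per parent: twice here.
visits = []


def record(m):
    if isinstance(m, nn.BatchNorm1d):
        visits.append(m)


model2.apply(record)
print(len(visits), visits[0] is visits[1])  # 2 True
```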

Running the script prints:

```
[INFO] Register count_bn() for <class 'torch.nn.modules.batchnorm.BatchNorm1d'>.
[WARN] Cannot find rule for <class 'torch.nn.modules.container.Sequential'>. Treat it as zero Macs and zero Params.
[WARN] Cannot find rule for <class '__main__.test_model1'>. Treat it as zero Macs and zero Params.
model1: MACs 128.000B Params 128.000B
[INFO] Register count_bn() for <class 'torch.nn.modules.batchnorm.BatchNorm1d'>.
[WARN] Cannot find rule for <class 'torch.nn.modules.container.Sequential'>. Treat it as zero Macs and zero Params.
[WARN] Cannot find rule for <class '__main__.test_model2'>. Treat it as zero Macs and zero Params.
model2: MACs 512.000B Params 256.000B
```

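Both models contain a single `BatchNorm1d(64)` with 128 parameters (64 weights and 64 biases), yet model2 reports 4× the MACs and 2× the params of model1. Please check it.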
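One hypothesis that reproduces these exact numbers is sketched below as a toy in plain Python; it is not thop's code, and `Node`, `walk`, and `toy_profile` are hypothetical names. It assumes that the counting hook is attached once per path to the shared module, that each hook call adds 2 ops per input element, that the parameter hook assigns rather than adds, and that totals are summed over a walk that also visits the shared module once per path. Under those assumptions the duplicate hook (2×) and the duplicate summation (2×) compound to 4× MACs, while params double only once.

```python
class Node:
    """Toy stand-in for an nn.Module: a parameter count and children."""

    def __init__(self, name, params=0, children=()):
        self.name = name
        self.params = params
        self.children = list(children)
        self.hooks = 0          # number of counting hooks attached
        self.total_ops = 0
        self.total_params = 0


def walk(node):
    """Yield every node once per path from the root, duplicates kept."""
    yield node
    for child in node.children:
        yield from walk(child)


def toy_profile(root, leaf, numel):
    # Assumption 1: one hook is attached per visit of a duplicate-keeping walk.
    for n in walk(root):
        if n.params:
            n.hooks += 1
    # The forward pass runs the leaf once, and every attached hook fires.
    for _ in range(leaf.hooks):
        leaf.total_ops += 2 * numel      # Assumption 2: 2 ops per element
        leaf.total_params = leaf.params  # Assumption 3: assigned, not added
    # Assumption 4: totals are summed over the same duplicate-keeping walk.
    ops = sum(n.total_ops for n in walk(root))
    params = sum(n.total_params for n in walk(root))
    return ops, params


bn1 = Node("bn", params=128)
model1 = Node("model1", children=[Node("block", children=[bn1])])

bn2 = Node("bn", params=128)  # shared: reachable as model2.bn and block.0
model2 = Node("model2", children=[bn2, Node("block", children=[bn2])])

r1 = toy_profile(model1, bn1, numel=64)
r2 = toy_profile(model2, bn2, numel=64)
print(r1)  # (128, 128)
print(r2)  # (512, 256)
```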

HaoKang-Timmy commented 2 years ago

It seems that our counter has counted this module twice; I will find a way to fix it. As for the MACs count, I will check it as well.
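A minimal sketch of a de-duplicated count, assuming the goal is to count each unique module once per actual forward call. `count_bn_once` is a hypothetical helper, not thop's API, and it only handles BatchNorm1d at 2 ops per input element, the rate implied by model1's result (128 MACs for a 64-element input). Hooks are registered over `model.modules()`, which yields each module once however many names it has, and parameters come from the de-duplicated `model.parameters()`.

```python
import torch
import torch.nn as nn


def count_bn_once(model, x):
    """Hypothetical helper: one hook per unique BatchNorm1d, 2 ops/element/call."""
    ops = {}

    def hook(m, inputs, output):
        # Accumulate per forward call, keyed by module identity.
        ops[id(m)] = ops.get(id(m), 0) + 2 * inputs[0].numel()

    # modules() yields each module once, even if it is registered under two names.
    handles = [m.register_forward_hook(hook)
               for m in model.modules() if isinstance(m, nn.BatchNorm1d)]
    try:
        with torch.no_grad():
            model(x)
    finally:
        for h in handles:
            h.remove()
    params = sum(p.numel() for p in model.parameters())  # de-duplicated
    return sum(ops.values()), params


class test_model1(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = nn.Sequential(nn.BatchNorm1d(64))

    def forward(self, x):
        return self.block(x)


class test_model2(nn.Module):
    def __init__(self):
        super().__init__()
        self.bn = nn.BatchNorm1d(64)
        self.block = nn.Sequential(self.bn)

    def forward(self, x):
        return self.block(x)


data = torch.rand(1, 64, 1)
print(count_bn_once(test_model1().eval(), data))  # (128, 128)
print(count_bn_once(test_model2().eval(), data))  # (128, 128)
```

Because ops are accumulated per forward call, a module that genuinely runs twice in `forward` would still be counted twice, which is the desired behavior; only the duplicate registration is ignored.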