ChengpengChen / RepGhost

RepGhost: A Hardware-Efficient Ghost Module via Re-parameterization
MIT License

Question about parameter count and FLOPs calculation #10

Closed Linaom1214 closed 1 year ago

Linaom1214 commented 1 year ago

I tested the parameter count and FLOPs calculation script you provide and found no obvious difference before and after re-parameterization, yet the weight file is clearly smaller. Could you update the calculation script?

    from tools import cal_flops_params
    input = torch.randn(1, 3, 224, 224)
    print("[ Train Model ]")
    flops, params = cal_flops_params(train_model, input_size=input.shape)
    print("[ Infer Model ]")
    flops, params = cal_flops_params(infer_model, input_size=input.shape)

Output:

[ Train Model ]
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
flops = 43317244.0, params = 2313968.0
flops = 43.32M, params = 2.31M
[ Infer Model ]
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
flops = 44165532.0, params = 2306528.0
flops = 44.17M, params = 2.31M
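The `tools.cal_flops_params` script itself is not shown in this thread, but the `[INFO] Register count_convNd() ...` lines indicate a hook-based counter that only instruments `Conv2d` and `Linear` modules. A minimal sketch of such a counter (a hypothetical illustration, not the repo's actual implementation) makes it visible that BatchNorm layers contribute nothing to either total:

```python
import torch
import torch.nn as nn


def count_flops_params(model: nn.Module, x: torch.Tensor):
    """Count MAC-style FLOPs and params, but only for Conv2d and Linear,
    mirroring the [INFO] registration lines above. BN layers are ignored."""
    flops, params = 0, 0
    handles = []

    def conv_hook(m, inp, out):
        nonlocal flops
        # MACs per output element = kernel elements * input channels per group
        kernel_ops = m.weight.shape[2] * m.weight.shape[3] * (m.in_channels // m.groups)
        flops += out.numel() * kernel_ops

    def linear_hook(m, inp, out):
        nonlocal flops
        flops += out.numel() * m.in_features

    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            handles.append(m.register_forward_hook(conv_hook))
            params += sum(p.numel() for p in m.parameters())
        elif isinstance(m, nn.Linear):
            handles.append(m.register_forward_hook(linear_hook))
            params += sum(p.numel() for p in m.parameters())
        # other module types (e.g. nn.BatchNorm2d) are skipped entirely

    with torch.no_grad():
        model(x)
    for h in handles:
        h.remove()
    return flops, params
```

With such a counter, a `Conv2d` + `BatchNorm2d` pair and the equivalent fused single conv report essentially the same numbers, which matches the near-identical totals printed above.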

Full code:


import importlib
import os

import torch


# `parser` and `repghost_model_convert` are defined elsewhere in the script
def convert():
    args = parser.parse_args()

    m = importlib.import_module(f"model.{args.model.split('.')[0]}")
    train_model = getattr(m, args.model.split('.')[1])()
    train_model.eval()

    if os.path.isfile(args.load):
        print("=> loading checkpoint '{}'".format(args.load))
        checkpoint = torch.load(args.load, map_location='cpu')
        if args.ema_model and 'state_dict_ema' in checkpoint:
            checkpoint = checkpoint['state_dict_ema']
        else:
            checkpoint = checkpoint['state_dict']

        try:
            train_model.load_state_dict(checkpoint)
        except Exception as e:
            ckpt = {k.replace('module.', ''): v for k, v in checkpoint.items()}  # strip the names
            # print(ckpt.keys())
            train_model.load_state_dict(ckpt)
    else:
        print("=> no checkpoint found at '{}'".format(args.load))

    infer_model = repghost_model_convert(train_model, save_path=args.save)
    print("=> saved checkpoint to '{}'".format(args.save))

    if args.sanity_check:
        data = torch.randn(5, 3, 224, 224)
        out = train_model(data)
        out2 = infer_model(data)
        print('=> The diff is', ((out - out2) ** 2).sum())

    from tools import cal_flops_params
    input = torch.randn(1, 3, 224, 224)
    print("[ Train Model ]")
    flops, params = cal_flops_params(train_model, input_size=input.shape)
    print("[ Infer Model ]")
    flops, params = cal_flops_params(infer_model, input_size=input.shape)
ChengpengChen commented 1 year ago

Hi! Thanks for the suggestion!

One thing to clarify: our calculation script only counts the parameters and FLOPs of conv and fc layers (BN is ignored, since it can be fused into conv). As a result, RepGhostNet's parameters and FLOPs barely change before and after re-parameterization (the 3x3 dwconv absorbs two BNs: the one following the conv and the one on the re-parameterization branch). The weight file, on the other hand, contains all parameters, so it becomes noticeably smaller after re-parameterization: the BNs fused into the 3x3 dwconv are gone.
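The fusion described above can be sketched in a few lines of plain PyTorch. This is a generic conv-BN fusion illustration, not the RepGhost conversion code: it shows why the conv-only parameter count barely moves (only a bias is added) while the checkpoint loses the BN tensors entirely.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    """Return a single Conv2d equivalent to conv followed by bn (eval mode)."""
    fused = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                      conv.stride, conv.padding, groups=conv.groups, bias=True)
    # BN: y = gamma * (x - running_mean) / sqrt(running_var + eps) + beta
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    fused.weight.data = conv.weight * scale.reshape(-1, 1, 1, 1)
    bias = conv.bias if conv.bias is not None else torch.zeros(conv.out_channels)
    fused.bias.data = (bias - bn.running_mean) * scale + bn.bias
    return fused


conv = nn.Conv2d(8, 8, 3, padding=1, groups=8, bias=False)  # a 3x3 dwconv
bn = nn.BatchNorm2d(8)
conv.eval(); bn.eval()

x = torch.randn(1, 8, 16, 16)
fused = fuse_conv_bn(conv, bn)
assert torch.allclose(nn.Sequential(conv, bn)(x), fused(x), atol=1e-5)

# Conv-only parameter count is nearly unchanged (just the new bias term):
print(sum(p.numel() for p in conv.parameters()),   # 72 (8x1x3x3 weight)
      sum(p.numel() for p in fused.parameters()))  # 80 (weight + 8 biases)
# The state_dict, however, drops the BN tensors (weight, bias, running_mean,
# running_var: 4 x 8 = 32 values here) - hence the smaller checkpoint file.
```

Since the counter only sees conv/fc layers, both models report almost the same FLOPs and params, while the saved file shrinks by all the fused BN tensors.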

Linaom1214 commented 1 year ago

OK, thanks for the reply.