microsoft / nni

An open source AutoML toolkit that automates the machine learning lifecycle, including feature engineering, neural architecture search, model compression and hyper-parameter tuning.
https://nni.readthedocs.io
MIT License

Has not supported replacing the module: `InstanceNorm2d` #4387

Open PanJinquan opened 2 years ago

PanJinquan commented 2 years ago

Describe the issue:

PyTorch's nn.InstanceNorm2d is not supported.

Below is a simple SimpleModel that uses nn.InstanceNorm2d. With use_inorm=False pruning works fine, but with use_inorm=True it fails with: Has not supported replacing the module: `InstanceNorm2d`.

Example:

# -*-coding: utf-8 -*-
import os
import copy
import torch
import torch.nn as nn
import torch.onnx
import torch.nn.functional as F
from nni.compression.pytorch.utils.counter import count_flops_params
from nni.compression.pytorch.speedup import ModelSpeedup
from nni.algorithms.compression.pytorch import pruning
from nni.compression.pytorch import apply_compression_results

def model_pruning(model: nn.Module,
                  input_size=[1, 3, 128, 128],
                  sparsity=0.2,
                  prune_mod="FPGM",
                  output_prune="pruning_output",
                  mask_file="",
                  dependency_aware=True,
                  device="cpu",
                  verbose=False,
                  **kwargs):
    info = ""
    model = model.to(device)
    if not os.path.exists(output_prune): os.makedirs(output_prune)
    prune_file = os.path.join(output_prune, 'pruned_naive_{}filter.pth'.format(prune_mod))
    onnx_file = os.path.join(output_prune, 'pruned_naive_{}filter.onnx'.format(prune_mod))
    mask_file = os.path.join(output_prune, 'mask_naive_{}filter.pth'.format(prune_mod)) if not mask_file else mask_file
    dummy_input = torch.randn(input_size).to(device)
    # FLOPs and parameter count of the original model
    flops, params, _ = count_flops_params(model, dummy_input, verbose=verbose)
    info += f"origin-Model FLOPs {flops / 1e6:.2f}M, Params {params / 1e6:.2f}M\n"
    # Prune the model; this generates a mask file (mask_naive_l1filter.pth)
    if prune_mod.lower() == "Level".lower():
        config = [{'sparsity': sparsity, 'op_types': ['Conv2d']}]
        pruner = pruning.LevelPruner(model, config)
    elif prune_mod.lower() == "L1".lower():
        # op_types : Only Conv2d is supported in L1FilterPruner.
        # config = [{'sparsity': sparsity, 'op_types': ['Conv2d'], "exclude": False}]
        config = [{'sparsity': sparsity, 'op_types': ['Conv2d']}]
        pruner = pruning.L1FilterPruner(model, config, dependency_aware, dummy_input=dummy_input)
    elif prune_mod.lower() == "L2".lower():
        # op_types : Only Conv2d is supported in L2FilterPruner.
        config = [{'sparsity': sparsity, 'op_types': ['Conv2d']}]
        pruner = pruning.L2FilterPruner(model, config, dependency_aware, dummy_input=dummy_input)
    elif prune_mod.lower() == "FPGM".lower():
        # op_types : Only Conv2d is supported in FPGM Pruner
        config = [{'sparsity': sparsity, 'op_types': ['Conv2d']}]
        pruner = pruning.FPGMPruner(model, config, dependency_aware, dummy_input=dummy_input)
    elif prune_mod.lower() == "Slim".lower():
        config = [{'sparsity': sparsity, 'op_types': ['BatchNorm2d']}]
        pruner = pruning.ActivationMeanRankFilterPruner()
    else:
        raise Exception("Error prune_mod:{}".format(prune_mod))
    # compress the model, the mask will be updated.
    pruner.compress()
    # pruner.get_pruned_weights()
    # use a dummy input to apply the sparsify.
    out = model(dummy_input)
    # FLOPs and parameter count of the pruned model
    flops, params, _ = count_flops_params(model, dummy_input, verbose=verbose)
    info += f"pruner-Model FLOPs {flops / 1e6:.2f}M, Params {params / 1e6:.2f}M\n"
    # export the sparsified and mask model
    pruner.export_model(model_path=prune_file, mask_path=mask_file,
                        onnx_path=onnx_file, input_shape=dummy_input.shape,
                        device=device,
                        opset_version=11)
    # speed up the model with the provided weight mask. If you use a wrapped model, don't forget to unwrap it.
    pruner._unwrap_model()
    # Apply the masks to the model; the model becomes smaller and inference latency drops
    # apply_compression_results(model, mask_file, device)
    if not os.path.exists(mask_file): raise Exception("not found mask file:{}".format(mask_file))
    print("load mask file to speed up:{}".format(mask_file))
    speed_up = ModelSpeedup(model, dummy_input=dummy_input, masks_file=mask_file)
    speed_up.speedup_model()
    out = model(dummy_input)
    # FLOPs and parameter count after speedup
    flops, params, _ = count_flops_params(model, dummy_input, verbose=verbose)
    info += f"speedup-Model FLOPs {flops / 1e6:.2f}M, Params {params / 1e6:.2f}M\n"
    print(info)
    # finetune the model to recover the accuracy.
    return model

class SimpleModel(nn.Module):
    def __init__(self, num_classes, use_inorm=True):
        super(SimpleModel, self).__init__()
        self.use_inorm = use_inorm
        self.conv1 = nn.Conv2d(3, 32, 3)
        if self.use_inorm:
            self.inorm1 = nn.InstanceNorm2d(32, affine=False)
        self.conv2 = nn.Conv2d(32, 64, 3)
        if self.use_inorm:
            self.inorm2 = nn.InstanceNorm2d(64, affine=False)
        self.conv3 = nn.Conv2d(64, 128, 3)
        self.fc = nn.Linear(128, 256)
        self.classifier = nn.Linear(256, num_classes)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        if self.use_inorm:
            x = self.inorm1(x)
        x = F.relu(self.conv2(x))
        if self.use_inorm:
            x = self.inorm2(x)
        x = F.relu(self.conv3(x))
        x = nn.functional.adaptive_avg_pool2d(x, 1).reshape(x.shape[0], -1)
        x = self.fc(x)
        x = self.classifier(x)
        return x

if __name__ == "__main__":
    device = "cuda:0"
    num_classes = 20
    input_size = [1, 3, 128, 128]
    # use_inorm=False prunes normally,
    # but use_inorm=True fails with: Has not supported replacing the module: `InstanceNorm2d`
    # model = SimpleModel(num_classes=num_classes, use_inorm=False)
    model = SimpleModel(num_classes=num_classes, use_inorm=True)
    model.eval()
    inputs = torch.randn(input_size)
    model = model.to(device)
    inputs = inputs.to(device)
    output = model(inputs)
    prune_model = copy.deepcopy(model)
    prune_model = model_pruning(prune_model, input_size=input_size, sparsity=0.2, dependency_aware=True, device=device)
    print("inputs:", inputs.shape)
    print("output:", output.shape)

@zheng-ningxin

ymkzpx commented 2 years ago

Hi @PanJinquan, can you show me the complete source code? I will try to reproduce this problem and fix it.

PanJinquan commented 2 years ago

Hi, I have updated the issue with the complete code. @zheng-ningxin
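
For anyone looking into a fix, here is a rough sketch of one possible direction. This is an assumption about NNI internals, not a confirmed API: the replace_module registry in nni.compression.pytorch.speedup.compress_modules and the (module, masks) callback signature may differ or not exist in your NNI version. The speedup error appears to be raised when a module type has no registered replacement, and an affine=False InstanceNorm2d carries no per-channel tensors, so a trivial replacement entry might be enough.

# Rough, unverified sketch -- assumes the replace_module registry and the
# (module, masks) callback signature; both may differ between NNI versions.
import torch.nn as nn
from nni.compression.pytorch.speedup import compress_modules as cm

def replace_instancenorm2d(norm, masks):
    # With affine=False and track_running_stats=False the layer has no
    # per-channel tensors to slice, so reuse it unchanged; a complete fix
    # would rebuild it with the pruned channel count.
    assert isinstance(norm, nn.InstanceNorm2d)
    return norm

cm.replace_module['InstanceNorm2d'] = replace_instancenorm2d  # hypothetical key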