VainF / Torch-Pruning

[CVPR 2023] DepGraph: Towards Any Structural Pruning
https://arxiv.org/abs/2301.12900
MIT License

Does the tool support pruning yolov5/v7/v8? #100

Closed: songkq closed this issue 1 year ago

songkq commented 1 year ago

@VainF Hi, when I prune the yolov7 backbone, it fails with:

group_imp = torch.stack(group_imp, dim=0)
RuntimeError: stack expects each tensor to be equal size, but got [512] at entry 0 and [2048] at entry 2

Could you please give some advice?

VainF commented 1 year ago

Hi @songkq. torch.stack requires all inputs to have the same shape, but Torch-Pruning cannot detect this internal restriction. Please use torch.cat if possible.
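
For illustration, a standalone snippet (not from Torch-Pruning itself) that reproduces the restriction:

import torch

a = torch.randn(512)
b = torch.randn(2048)

torch.cat([a, b], dim=0)    # works: cat only needs the non-concatenated dims to match
torch.stack([a, b], dim=0)  # RuntimeError: stack expects each tensor to be equal size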

songkq commented 1 year ago

@VainF Hi, torch.stack is used at line 97 of torch_pruning/importance.py for tp.importance.MagnitudeImportance(p=1).
When I set the pruning method to l1 for yolov7, it fails with the above error, so it seems to be a bug. But when I set the pruning method to random, lamp, or slim, it works well.

VainF commented 1 year ago

Thank you! Let me take a look.

labderrafie commented 1 year ago

@songkq Hi, I'm trying to prune yolov7 and I'm facing the same error. How did you change the pruning method so that it works? Thank you!

songkq commented 1 year ago

@labderrafie Just use tp.importance.MagnitudeImportance(p=2) or tp.importance.BNScaleImportance() for pruning.
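
Concretely, that means swapping the importance line in the pruning script, e.g. (a sketch assuming the usual import torch_pruning as tp alias):

imp = tp.importance.MagnitudeImportance(p=2)  # L2 magnitude
# imp = tp.importance.BNScaleImportance()     # BN scaling factors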

labderrafie commented 1 year ago

@songkq @VainF None of them seem to work for me... This is the code I use:

import torch
import torch_pruning as tp

from models.experimental import attempt_load
from utils.torch_utils import select_device

def main():
    # Global metrics
    # opt = Opts().parse()
    # model_path = opt.modelPath
    sparsity = 0.5
    shapes_list = [3, 640, 640]

    device = select_device('cpu')
    model = attempt_load('yolov7.pt', map_location=device)  # load FP32 model
    labels = model.names

    # print(model)

    c, w, h = shapes_list
    dummy_input = torch.randn(1, c, w, h)

    imp = tp.importance.MagnitudeImportance(p=1)
    # imp = tp.importance.BNScaleImportance()
    ignored_layers = []
    for m in model.modules():
        if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
            ignored_layers.append(m)

    iterative_steps = 1
    pruner = tp.pruner.MagnitudePruner(
        model,
        dummy_input,
        importance=imp,
        iterative_steps=iterative_steps,
        ch_sparsity=sparsity,
        ignored_layers=ignored_layers,
    )

    for i in range(iterative_steps):
        ori_size = tp.utils.count_params(model)
        pruner.step()
        print(
            "  Params: %.2i  => %.2i "
            % (ori_size , tp.utils.count_params(model))
        )

When I use tp.importance.MagnitudeImportance(p=1) I get the following error:

group_imp = torch.stack(group_imp, dim=0)
RuntimeError: stack expects each tensor to be equal size, but got [512] at entry 0 and [2048] at entry 1

Same error for tp.importance.MagnitudeImportance(p=2).

But with tp.importance.BNScaleImportance() it returns a different error:

group_imp = torch.stack(group_imp, dim=0)
RuntimeError: stack expects a non-empty TensorList

I don't know which version of YOLO you managed to prune; I'm trying yolov7 and I don't see what I'm doing wrong...

VainF commented 1 year ago

@labderrafie Could you provide the GitHub repository of your yolov7? Thanks!

labderrafie commented 1 year ago

@VainF I used the official yolov7 GitHub repo: https://github.com/WongKinYiu/yolov7

I first cloned the repo:

!pip --quiet install onnx onnxruntime onnxsim
!git clone https://github.com/WongKinYiu/yolov7.git
%cd yolov7
!wget https://github.com/WongKinYiu/yolov7/releases/download/v0.1/yolov7.pt

Then I loaded the model the same way export.py does when loading yolov7 for ONNX export:

import argparse
import sys
import time
import warnings

sys.path.append('./')  # to run '$ python *.py' files in subdirectories

import torch
import torch.nn as nn
from torch.utils.mobile_optimizer import optimize_for_mobile

import models
from models.experimental import attempt_load, End2End
from utils.activations import Hardswish, SiLU
from utils.general import set_logging, check_img_size
from utils.torch_utils import select_device
from utils.add_nms import RegisterNMS

# Load model
device = select_device('cpu')
model = attempt_load('yolov7.pt', map_location=device)  # load FP32 model

The following steps are for pruning; it's the code I provided in my previous comment. Thank you for your help!

Ledgero commented 1 year ago

@labderrafie @VainF Hi, I'm trying to prune yolov7 too, with the same code as you, so I'm facing the same error. Have you fixed this so that it works? Could you please give some advice? Thank you!

labderrafie commented 1 year ago

@Ledgero I didn't find a solution either...

VainF commented 1 year ago

Hi All. The RandomImportance works with the attached yolov7 script (modified from the official detect.py). But there are still some problems with the MagnitudeImportance.

python detect_pruned.py --weights yolov7.pt --conf 0.25 --img-size 640 --source inference/images/horses.jpg &> output.log

Outputs:

...
Before Pruning: MACs=6413721083.000000, #Params=36905341.000000
After Pruning: MACs=1639894907.000000, #Params=9347133.000000
...

The detect_pruned.py:

import argparse
import time
from pathlib import Path

import cv2
import torch
import torch.backends.cudnn as cudnn
from numpy import random

from models.experimental import attempt_load
from utils.datasets import LoadStreams, LoadImages
from utils.general import check_img_size, check_requirements, check_imshow, non_max_suppression, apply_classifier, \
    scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path
from utils.plots import plot_one_box
from utils.torch_utils import select_device, load_classifier, time_synchronized, TracedModel

import torch_pruning as tp

def detect(save_img=False):
    source, weights, view_img, save_txt, imgsz, trace = opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size, not opt.no_trace
    save_img = not opt.nosave and not source.endswith('.txt')  # save inference images
    webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith(
        ('rtsp://', 'rtmp://', 'http://', 'https://'))

    # Directories
    save_dir = Path(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))  # increment run
    (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make dir

    # Initialize
    set_logging()
    device = select_device(opt.device)
    half = device.type != 'cpu'  # half precision only supported on CUDA

    # Load model
    model = attempt_load(weights, map_location=device)  # load FP32 model
    print(model)

    # Pruning
    example_inputs = torch.randn(1, 3, 224, 224).to(device)
    imp = tp.importance.RandomImportance()

    ignored_layers = []
    from models.yolo import Detect
    for m in model.modules():
        if isinstance(m, Detect):
            ignored_layers.append(m)
    print(ignored_layers)

    iterative_steps = 1 # progressive pruning
    pruner = tp.pruner.MagnitudePruner(
        model,
        example_inputs,
        importance=imp,
        iterative_steps=iterative_steps,
        ch_sparsity=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
        ignored_layers=ignored_layers,
    )
    base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
    print("Before Pruning: MACs=%f, #Params=%f"%(base_macs, base_nparams))

    pruner.step()
    pruned_macs, pruned_nparams = tp.utils.count_ops_and_params(model, example_inputs)
    print("After Pruning: MACs=%f, #Params=%f"%(pruned_macs, pruned_nparams))
    print(model)

    stride = int(model.stride.max())  # model stride
    imgsz = check_img_size(imgsz, s=stride)  # check img_size

    if trace:
        model = TracedModel(model, device, opt.img_size)

    if half:
        model.half()  # to FP16

    # Second-stage classifier
    classify = False
    if classify:
        modelc = load_classifier(name='resnet101', n=2)  # initialize
        modelc.load_state_dict(torch.load('weights/resnet101.pt', map_location=device)['model'])
        modelc.to(device).eval()

    # Set Dataloader
    vid_path, vid_writer = None, None
    if webcam:
        view_img = check_imshow()
        cudnn.benchmark = True  # set True to speed up constant image size inference
        dataset = LoadStreams(source, img_size=imgsz, stride=stride)
    else:
        dataset = LoadImages(source, img_size=imgsz, stride=stride)

    # Get names and colors
    names = model.module.names if hasattr(model, 'module') else model.names
    colors = [[random.randint(0, 255) for _ in range(3)] for _ in names]

    # Run inference
    if device.type != 'cpu':
        model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters())))  # run once
    old_img_w = old_img_h = imgsz
    old_img_b = 1

    t0 = time.time()
    for path, img, im0s, vid_cap in dataset:
        img = torch.from_numpy(img).to(device)
        img = img.half() if half else img.float()  # uint8 to fp16/32
        img /= 255.0  # 0 - 255 to 0.0 - 1.0
        if img.ndimension() == 3:
            img = img.unsqueeze(0)

        # Warmup
        if device.type != 'cpu' and (old_img_b != img.shape[0] or old_img_h != img.shape[2] or old_img_w != img.shape[3]):
            old_img_b = img.shape[0]
            old_img_h = img.shape[2]
            old_img_w = img.shape[3]
            for i in range(3):
                model(img, augment=opt.augment)[0]

        # Inference
        t1 = time_synchronized()
        with torch.no_grad():   # Calculating gradients would cause a GPU memory leak
            pred = model(img, augment=opt.augment)[0]
        t2 = time_synchronized()

        # Apply NMS
        pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms)
        t3 = time_synchronized()

        # Apply Classifier
        if classify:
            pred = apply_classifier(pred, modelc, img, im0s)

        # Process detections
        for i, det in enumerate(pred):  # detections per image
            if webcam:  # batch_size >= 1
                p, s, im0, frame = path[i], '%g: ' % i, im0s[i].copy(), dataset.count
            else:
                p, s, im0, frame = path, '', im0s, getattr(dataset, 'frame', 0)

            p = Path(p)  # to Path
            save_path = str(save_dir / p.name)  # img.jpg
            txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}')  # img.txt
            gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
            if len(det):
                # Rescale boxes from img_size to im0 size
                det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()

                # Print results
                for c in det[:, -1].unique():
                    n = (det[:, -1] == c).sum()  # detections per class
                    s += f"{n} {names[int(c)]}{'s' * (n > 1)}, "  # add to string

                # Write results
                for *xyxy, conf, cls in reversed(det):
                    if save_txt:  # Write to file
                        xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
                        line = (cls, *xywh, conf) if opt.save_conf else (cls, *xywh)  # label format
                        with open(txt_path + '.txt', 'a') as f:
                            f.write(('%g ' * len(line)).rstrip() % line + '\n')

                    if save_img or view_img:  # Add bbox to image
                        label = f'{names[int(cls)]} {conf:.2f}'
                        plot_one_box(xyxy, im0, label=label, color=colors[int(cls)], line_thickness=1)

            # Print time (inference + NMS)
            print(f'{s}Done. ({(1E3 * (t2 - t1)):.1f}ms) Inference, ({(1E3 * (t3 - t2)):.1f}ms) NMS')

            # Stream results
            if view_img:
                cv2.imshow(str(p), im0)
                cv2.waitKey(1)  # 1 millisecond

            # Save results (image with detections)
            if save_img:
                if dataset.mode == 'image':
                    cv2.imwrite(save_path, im0)
                    print(f" The image with the result is saved in: {save_path}")
                else:  # 'video' or 'stream'
                    if vid_path != save_path:  # new video
                        vid_path = save_path
                        if isinstance(vid_writer, cv2.VideoWriter):
                            vid_writer.release()  # release previous video writer
                        if vid_cap:  # video
                            fps = vid_cap.get(cv2.CAP_PROP_FPS)
                            w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                            h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                        else:  # stream
                            fps, w, h = 30, im0.shape[1], im0.shape[0]
                            save_path += '.mp4'
                        vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
                    vid_writer.write(im0)

    if save_txt or save_img:
        s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
        #print(f"Results saved to {save_dir}{s}")

    print(f'Done. ({time.time() - t0:.3f}s)')

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', nargs='+', type=str, default='yolov7.pt', help='model.pt path(s)')
    parser.add_argument('--source', type=str, default='inference/images', help='source')  # file/folder, 0 for webcam
    parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
    parser.add_argument('--conf-thres', type=float, default=0.25, help='object confidence threshold')
    parser.add_argument('--iou-thres', type=float, default=0.45, help='IOU threshold for NMS')
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--view-img', action='store_true', help='display results')
    parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
    parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
    parser.add_argument('--nosave', action='store_true', help='do not save images/videos')
    parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3')
    parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
    parser.add_argument('--augment', action='store_true', help='augmented inference')
    parser.add_argument('--update', action='store_true', help='update all models')
    parser.add_argument('--project', default='runs/detect', help='save results to project/name')
    parser.add_argument('--name', default='exp', help='save results to project/name')
    parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
    parser.add_argument('--no-trace', action='store_true', help='don`t trace model')
    opt = parser.parse_args()
    print(opt)
    #check_requirements(exclude=('pycocotools', 'thop'))

    #with torch.no_grad():
    if opt.update:  # update all models (to fix SourceChangeWarning)
        for opt.weights in ['yolov7.pt']:
            detect()
            strip_optimizer(opt.weights)
    else:
        detect()
VainF commented 1 year ago

Here is the pruned model:

Model(
  (model): Sequential(
    (0): Conv(
      (conv): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act): SiLU(inplace=True)
    )
    (1): Conv(
      (conv): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (act): SiLU(inplace=True)
    )
    (2): Conv(
      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act): SiLU(inplace=True)
    )
    (3): Conv(
      (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (act): SiLU(inplace=True)
    )
    (4): Conv(
      (conv): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
      (act): SiLU(inplace=True)
    )
    (5): Conv(
      (conv): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
      (act): SiLU(inplace=True)
    )
    (6): Conv(
      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act): SiLU(inplace=True)
    )
    (7): Conv(
      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act): SiLU(inplace=True)
    )
    (8): Conv(
      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act): SiLU(inplace=True)
    )
    (9): Conv(
      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act): SiLU(inplace=True)
    )
    (10): Concat()
    (11): Conv(
      (conv): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))
      (act): SiLU(inplace=True)
    )
    (12): MP(
      (m): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (13): Conv(
      (conv): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
      (act): SiLU(inplace=True)
    )
    (14): Conv(
      (conv): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
      (act): SiLU(inplace=True)
    )
    (15): Conv(
      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (act): SiLU(inplace=True)
    )
    (16): Concat()
    (17): Conv(
      (conv): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
      (act): SiLU(inplace=True)
    )
    (18): Conv(
      (conv): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
      (act): SiLU(inplace=True)
    )
    (19): Conv(
      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act): SiLU(inplace=True)
    )
    (20): Conv(
      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act): SiLU(inplace=True)
    )
    (21): Conv(
      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act): SiLU(inplace=True)
    )
    (22): Conv(
      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act): SiLU(inplace=True)
    )
    (23): Concat()
    (24): Conv(
      (conv): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
      (act): SiLU(inplace=True)
    )
    (25): MP(
      (m): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (26): Conv(
      (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
      (act): SiLU(inplace=True)
    )
    (27): Conv(
      (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
      (act): SiLU(inplace=True)
    )
    (28): Conv(
      (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (act): SiLU(inplace=True)
    )
    (29): Concat()
    (30): Conv(
      (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
      (act): SiLU(inplace=True)
    )
    (31): Conv(
      (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
      (act): SiLU(inplace=True)
    )
    (32): Conv(
      (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act): SiLU(inplace=True)
    )
    (33): Conv(
      (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act): SiLU(inplace=True)
    )
    (34): Conv(
      (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act): SiLU(inplace=True)
    )
    (35): Conv(
      (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act): SiLU(inplace=True)
    )
    (36): Concat()
    (37): Conv(
      (conv): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
      (act): SiLU(inplace=True)
    )
    (38): MP(
      (m): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (39): Conv(
      (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
      (act): SiLU(inplace=True)
    )
    (40): Conv(
      (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
      (act): SiLU(inplace=True)
    )
    (41): Conv(
      (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (act): SiLU(inplace=True)
    )
    (42): Concat()
    (43): Conv(
      (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))
      (act): SiLU(inplace=True)
    )
    (44): Conv(
      (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))
      (act): SiLU(inplace=True)
    )
    (45): Conv(
      (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act): SiLU(inplace=True)
    )
    (46): Conv(
      (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act): SiLU(inplace=True)
    )
    (47): Conv(
      (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act): SiLU(inplace=True)
    )
    (48): Conv(
      (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act): SiLU(inplace=True)
    )
    (49): Concat()
    (50): Conv(
      (conv): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
      (act): SiLU(inplace=True)
    )
    (51): SPPCSPC(
      (cv1): Conv(
        (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
        (act): SiLU(inplace=True)
      )
      (cv2): Conv(
        (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
        (act): SiLU(inplace=True)
      )
      (cv3): Conv(
        (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (act): SiLU(inplace=True)
      )
      (cv4): Conv(
        (conv): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        (act): SiLU(inplace=True)
      )
      (m): ModuleList(
        (0): MaxPool2d(kernel_size=5, stride=1, padding=2, dilation=1, ceil_mode=False)
        (1): MaxPool2d(kernel_size=9, stride=1, padding=4, dilation=1, ceil_mode=False)
        (2): MaxPool2d(kernel_size=13, stride=1, padding=6, dilation=1, ceil_mode=False)
      )
      (cv5): Conv(
        (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
        (act): SiLU(inplace=True)
      )
      (cv6): Conv(
        (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (act): SiLU(inplace=True)
      )
      (cv7): Conv(
        (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
        (act): SiLU(inplace=True)
      )
    )
    (52): Conv(
      (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
      (act): SiLU(inplace=True)
    )
    (53): Upsample(scale_factor=2.0, mode=nearest)
    (54): Conv(
      (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))
      (act): SiLU(inplace=True)
    )
    (55): Concat()
    (56): Conv(
      (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
      (act): SiLU(inplace=True)
    )
    (57): Conv(
      (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
      (act): SiLU(inplace=True)
    )
    (58): Conv(
      (conv): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act): SiLU(inplace=True)
    )
    (59): Conv(
      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act): SiLU(inplace=True)
    )
    (60): Conv(
      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act): SiLU(inplace=True)
    )
    (61): Conv(
      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act): SiLU(inplace=True)
    )
    (62): Concat()
    (63): Conv(
      (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))
      (act): SiLU(inplace=True)
    )
    (64): Conv(
      (conv): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
      (act): SiLU(inplace=True)
    )
    (65): Upsample(scale_factor=2.0, mode=nearest)
    (66): Conv(
      (conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
      (act): SiLU(inplace=True)
    )
    (67): Concat()
    (68): Conv(
      (conv): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
      (act): SiLU(inplace=True)
    )
    (69): Conv(
      (conv): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
      (act): SiLU(inplace=True)
    )
    (70): Conv(
      (conv): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act): SiLU(inplace=True)
    )
    (71): Conv(
      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act): SiLU(inplace=True)
    )
    (72): Conv(
      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act): SiLU(inplace=True)
    )
    (73): Conv(
      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act): SiLU(inplace=True)
    )
    (74): Concat()
    (75): Conv(
      (conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
      (act): SiLU(inplace=True)
    )
    (76): MP(
      (m): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (77): Conv(
      (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
      (act): SiLU(inplace=True)
    )
    (78): Conv(
      (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
      (act): SiLU(inplace=True)
    )
    (79): Conv(
      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (act): SiLU(inplace=True)
    )
    (80): Concat()
    (81): Conv(
      (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
      (act): SiLU(inplace=True)
    )
    (82): Conv(
      (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
      (act): SiLU(inplace=True)
    )
    (83): Conv(
      (conv): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act): SiLU(inplace=True)
    )
    (84): Conv(
      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act): SiLU(inplace=True)
    )
    (85): Conv(
      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act): SiLU(inplace=True)
    )
    (86): Conv(
      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act): SiLU(inplace=True)
    )
    (87): Concat()
    (88): Conv(
      (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))
      (act): SiLU(inplace=True)
    )
    (89): MP(
      (m): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (90): Conv(
      (conv): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))
      (act): SiLU(inplace=True)
    )
    (91): Conv(
      (conv): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))
      (act): SiLU(inplace=True)
    )
    (92): Conv(
      (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (act): SiLU(inplace=True)
    )
    (93): Concat()
    (94): Conv(
      (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
      (act): SiLU(inplace=True)
    )
    (95): Conv(
      (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
      (act): SiLU(inplace=True)
    )
    (96): Conv(
      (conv): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act): SiLU(inplace=True)
    )
    (97): Conv(
      (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act): SiLU(inplace=True)
    )
    (98): Conv(
      (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act): SiLU(inplace=True)
    )
    (99): Conv(
      (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act): SiLU(inplace=True)
    )
    (100): Concat()
    (101): Conv(
      (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
      (act): SiLU(inplace=True)
    )
    (102): RepConv(
      (act): SiLU(inplace=True)
      (rbr_reparam): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (103): RepConv(
      (act): SiLU(inplace=True)
      (rbr_reparam): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (104): RepConv(
      (act): SiLU(inplace=True)
      (rbr_reparam): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (105): Detect(
      (m): ModuleList(
        (0): Conv2d(128, 255, kernel_size=(1, 1), stride=(1, 1))
        (1): Conv2d(256, 255, kernel_size=(1, 1), stride=(1, 1))
        (2): Conv2d(512, 255, kernel_size=(1, 1), stride=(1, 1))
      )
    )
  )
)
Ledgero commented 1 year ago

Hi All. The RandomImportance works with the attached yolov7 script (modified from the official detect.py). But there are still some problems with the MagnitudeImportance. [...]

Hello, I used the code you provided (the original model size is 75.6 MB), but I ran into a problem. When iterative_steps = 1, the pruning result is:

Before Pruning: MACs=6413721083.000000, #Params=36905341.000000
After Pruning: MACs=1639894907.000000, #Params=9347133.000000

and the output model traced_model.pt is 37.8 MB. When iterative_steps = 5, the pruning result is:

Before Pruning: MACs=6413721083.000000, #Params=36905341.000000
After Pruning: MACs=5165346270.000000, #Params=29835453.000000

and the output model traced_model.pt is 119.8 MB.

This confuses me. I haven't changed ch_sparsity=0.5, but when iterative_steps gets larger, the model gets larger. Looking forward to your reply, thank you!

VainF commented 1 year ago

Hi @Ledgero, for iterative_steps=n, you need to call pruner.step() n times. For example:

iterative_steps = 5 # progressive pruning
for i in range(iterative_steps):
    pruner.step()
    # finetune your model here
    # finetune(model)
    # ...
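
A sketch of the full loop with per-step statistics (count_ops_and_params as used in the script above; the per-step fraction assumes torch_pruning's default linear schedule):

iterative_steps = 5
for i in range(iterative_steps):
    pruner.step()
    # with the default schedule, roughly (i + 1) / iterative_steps of the
    # target ch_sparsity has been removed after this step
    macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
    print("Step %d: MACs=%d, #Params=%d" % (i + 1, macs, nparams))
    # finetune(model)  # recover accuracy between steps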
ducanhluu commented 1 year ago

Hi @VainF, can you look at yolov8 as well? I have tried your pruning code above on yolov8, but it does not work. Here is the link to the official yolov8 repo: https://github.com/ultralytics/ultralytics

VainF commented 1 year ago

Hi @VainF, can you look at yolov8 as well? I have tried your pruning code above on yolov8, but it does not work. Here is the link to the official yolov8 repo: https://github.com/ultralytics/ultralytics

We will try our best to support these awesome models.

HankYe commented 1 year ago

Hi, @songkq. I have successfully applied the torch_pruning tool to prune YOLOv5. Maybe you can also refer to https://github.com/HankYe/PAGCP for more helpful information. Hope my experience can give you some help.

VainF commented 1 year ago

Hi, @songkq. I have successfully applied the torch_pruning tool to prune YOLOv5. Maybe you can also refer to https://github.com/HankYe/PAGCP for more helpful information. Hope my experience can give you some help.

Wow! Your work is truly impressive! Thanks a lot for the kind sharing.

Would you mind adding your repository's URL to the homepage of Torch-Pruning? I think it would be a useful resource for other people.

HankYe commented 1 year ago

Hi, @songkq. I have successfully applied the torch_pruning tool to prune YOLOv5. Maybe you can also refer to https://github.com/HankYe/PAGCP for more helpful information. Hope my experience can give you some help.

Wow! Your work is truly impressive! Thanks a lot for the kind sharing.

Would you mind adding your repository's URL to the homepage of Torch-Pruning? I think it would be a useful resource for other people.

Not at all, do as you like. But one thing to note is that the torch_pruning version in my repository is 0.2.7, so it may differ from the latest version in some cases. Finally, thanks for your great work as well!

Ledgero commented 1 year ago

Hi, @songkq. I have successfully applied the torch_pruning tool to prune YOLOv5. Maybe you can also refer to https://github.com/HankYe/PAGCP for more helpful information. Hope my experience can give you some help.

Oh! Thank you very much for sharing. This will be of great help to me. @HankYe

HankYe commented 1 year ago

Hi, @songkq. I have successfully applied the torch_pruning tool to prune YOLOv5. Maybe you can also refer to https://github.com/HankYe/PAGCP for more helpful information. Hope my experience can give you some help.

Oh! Thank you very much for sharing. This will be of great help to me. @HankYe

My pleasure! Also, thanks for the great work on torch_pruning again!

songkq commented 1 year ago

@HankYe @VainF Good job. Thanks. https://github.com/tianyic/only_train_once is also a good choice.

VainF commented 1 year ago

@HankYe @VainF Good job. Thanks. https://github.com/tianyic/only_train_once is also a good choice.

Thank you!

VainF commented 1 year ago

Hi All, the magnitude pruner has been fixed for yolov7. A simple pruning script can be found at benchmarks/prunability/yolov7_detect_pruned.py
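
For reference, a condensed sketch of the core calls, based on the RandomImportance script earlier in this thread but swapping in the now-fixed magnitude criterion (see benchmarks/prunability/yolov7_detect_pruned.py for the maintained version):

import torch
import torch_pruning as tp
from models.experimental import attempt_load
from models.yolo import Detect

model = attempt_load('yolov7.pt', map_location='cpu')  # load FP32 model
example_inputs = torch.randn(1, 3, 224, 224)
ignored_layers = [m for m in model.modules() if isinstance(m, Detect)]  # keep detection heads intact

pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    importance=tp.importance.MagnitudeImportance(p=2),
    iterative_steps=1,
    ch_sparsity=0.5,  # remove 50% of the channels
    ignored_layers=ignored_layers,
)
pruner.step()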

labderrafie commented 1 year ago

@VainF Just tested it, all is good. Thanks for this great work!

songkq commented 1 year ago

@HankYe @VainF Hi, is there a deterministic way to quickly estimate the maximum pruning ratio that can be applied while keeping the pruned model's accuracy almost unchanged compared with the unpruned model?

VainF commented 1 year ago

Hello @songkq. There is no one-size-fits-all criterion for determining the optimal pruning ratio. However, one approach we could take is to gradually prune the model with small pruning ratios until we achieve the desired trade-off between accuracy and efficiency.
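
A rough sketch of that procedure (evaluate, the accuracy budget max_drop, and the 50% scan ceiling are placeholders for your own setup):

import copy
import torch_pruning as tp

def scan_pruning_ratio(model, example_inputs, evaluate, max_drop=0.01, steps=10):
    base_acc = evaluate(model)
    pruner = tp.pruner.MagnitudePruner(
        model,
        example_inputs,
        importance=tp.importance.MagnitudeImportance(p=2),
        iterative_steps=steps,
        ch_sparsity=0.5,  # scan up to 50% channel sparsity in small increments
    )
    best_model, best_ratio = copy.deepcopy(model), 0.0
    for i in range(1, steps + 1):
        pruner.step()
        # in practice, finetune the model here before evaluating
        if base_acc - evaluate(model) > max_drop:
            break  # accuracy budget exceeded; keep the previous snapshot
        best_model, best_ratio = copy.deepcopy(model), 0.5 * i / steps
    return best_model, best_ratio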

HankYe commented 1 year ago

@HankYe @VainF Hi, is there a deterministic way to quickly estimate the maximum pruning ratio that can be applied while keeping the pruned model's accuracy almost unchanged compared with the unpruned model?

Hi, @songkq. I agree with @VainF that the optimal pruning ratio varies across structures. An empirical observation from our experiments is that the global maximum pruning ratio can reach 30%~40% for downstream models with a negligible performance drop, while for ResNets, i.e., classification models, it can reach 50%+ with a negligible drop. But finding the optimal layer-wise sparsity ratio remains challenging for manually designed criteria, to the best of my knowledge. Hope this reply gives you some insights.

songkq commented 1 year ago

@VainF @HankYe Thanks for your kind advice. I'm wondering if we could construct a model that predicts the optimal pruning ratio for various models on different downstream tasks from the typical distribution of the model weights, much like the idea of Once-for-All (https://arxiv.org/pdf/1908.09791.pdf), which predicts the sub-net accuracy of a super-network for diverse deployment platforms, and the idea of NNLQP, which predicts the target latency (A Multi-Platform Neural Network Latency Query and Prediction System, https://github.com/ModelTC/NNLQP).

VainF commented 1 year ago

@VainF @HankYe Thanks for your kind advice. I'm wondering if we could construct a model that predicts the optimal pruning ratio for various models on different downstream tasks from the typical distribution of the model weights, much like the idea of Once-for-All (https://arxiv.org/pdf/1908.09791.pdf), which predicts the sub-net accuracy of a super-network for diverse deployment platforms, and the idea of NNLQP, which predicts the target latency (A Multi-Platform Neural Network Latency Query and Prediction System, https://github.com/ModelTC/NNLQP).

Thank you for sharing the great idea. I believe that training a predictor network to assess the importance of weights is definitely feasible. However, there are still some obstacles to overcome, such as the lack of a suitable "model" dataset. One possible solution could be to download open-source models from model hubs like Hugging Face, which would provide a diverse set of models for training and testing the predictor network.

HankYe commented 1 year ago

@VainF @HankYe Thanks for your kind advice. I'm wondering if we could construct a model that predicts the optimal pruning ratio for various models on different downstream tasks from the typical distribution of the model weights, much like the idea of Once-for-All (https://arxiv.org/pdf/1908.09791.pdf), which predicts the sub-net accuracy of a super-network for diverse deployment platforms, and the idea of NNLQP, which predicts the target latency (A Multi-Platform Neural Network Latency Query and Prediction System, https://github.com/ModelTC/NNLQP).

Hi, @songkq. I also think it is possible to construct such a unified predictor, working as the saliency criterion, to predict how different pruning ratios affect the performance of various downstream models. Personally, I think such a framework may be an ideal but demanding paradigm in network pruning, since there are no explicit ground-truth labels for the predictor, only the trade-off between model performance and pruning ratio. Thus it is hard to define in absolute terms which compression result is optimal. Once the criterion for the optimal pruning result is determined, I think this method would be achievable.