VainF / Torch-Pruning

[CVPR 2023] Towards Any Structural Pruning; LLMs / SAM / Diffusion / Transformers / YOLOv8 / CNNs
https://arxiv.org/abs/2301.12900
MIT License

AttributeError: Can't get attribute '__main__' on <module 'builtins' (built-in)> while loading the pruned model #399

Open jenyaMi opened 2 months ago

jenyaMi commented 2 months ago

I use the YOLOv8 pruning code, or more precisely the version updated for newer ultralytics releases: https://github.com/chbw818/yolov8-prune-using-torch-pruning-/tree/main

I am struggling with loading the pruned model...

AttributeError: Can't get attribute '__main__' on <module 'builtins' (built-in)>

AttributeError                            Traceback (most recent call last)

in <module>
     16 args = parser.parse_args()
     17
---> 18 prune(args)

in prune(args)
    432 pruning_cfg['name'] = f"step_{i}_finetune"
    433 pruning_cfg['batch'] = batch_size  # restore batch size
--> 434 model.train_v2(pruning=True, **pruning_cfg)
    435
    436 # post fine-tuning validation

in train_v2(self, pruning, **kwargs)
    337
    338 self.trainer.hub_session = self.session  # attach optional HUB session
--> 339 self.trainer.train()
    340 # Update model and cfg after training
    341 if RANK in (-1, 0):

/workspace/data/notebooks/belt_detection/ultralytics/ultralytics/engine/trainer.py in train(self)
    202
    203 else:
--> 204 self._do_train(world_size)
    205
    206 def _setup_scheduler(self):

/workspace/data/notebooks/belt_detection/ultralytics/ultralytics/engine/trainer.py in _do_train(self, world_size)
    467 f"{(time.time() - self.train_time_start) / 3600:.3f} hours."
    468 )
--> 469 self.final_eval()
    470 if self.args.plots:
    471 self.plot_metrics()

in final_eval_v2(self)
    272 for f in self.last, self.best:
    273 if f.exists():
--> 274 strip_optimizer_v2(f)  # strip optimizers
    275 if f is self.best:
    276 LOGGER.info(f'\nValidating {f}...')

in strip_optimizer_v2(f, s)
    284 Disabled half precision saving. originated from ultralytics/yolo/utils/torch_utils.py
    285 """
--> 286 x = torch.load(f, map_location=torch.device('cpu'))
    287 args = {**DEFAULT_CFG_DICT, **x['train_args']}  # combine model args with default args, preferring model args
    288 if x.get('ema'):

/opt/conda/lib/python3.8/site-packages/torch/serialization.py in load(f, map_location, pickle_module, **pickle_load_args)
    605 opened_file.seek(orig_position)
    606 return torch.jit.load(opened_file)
--> 607 return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
    608 return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)

/opt/conda/lib/python3.8/site-packages/torch/serialization.py in _load(zip_file, map_location, pickle_module, pickle_file, **pickle_load_args)
    880 unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
    881 unpickler.persistent_load = persistent_load
--> 882 result = unpickler.load()
    883
    884 torch._utils._validate_loaded_sparse_tensors()

/opt/conda/lib/python3.8/site-packages/torch/serialization.py in find_class(self, mod_name, name)
    873 def find_class(self, mod_name, name):
    874 mod_name = load_module_mapping.get(mod_name, mod_name)
--> 875 return super().find_class(mod_name, name)
    876
    877 # Load the data (which may in turn use `persistent_load` to load tensors)

AttributeError: Can't get attribute '__main__' on <module 'builtins' (built-in)>

Maybe someone knows what the problem could be? I saw that many people were able to run the code successfully. I use the latest version of ultralytics, 8.2.50.

Ramzes30765 commented 1 month ago

Same for me. If you find a solution, please tag me.

jenyaMi commented 1 month ago

@Ramzes30765 Hey, explicitly passing pickle_module=dill to torch.load, like this: torch.load('your_model.pt', pickle_module=dill), helped me. I added it in the strip_optimizer_v2 function and in the torch_safe_load function in the ultralytics/nn/tasks.py file. Let me know if you find something else.
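
For context, here is a minimal sketch of why the choice of pickle module matters. Standard pickle stores a class by reference (module name plus class name) rather than by value, so the loader must be able to find the class under that recorded path. dill, by contrast, can serialize classes defined in `__main__` by value, which is presumably why saving and loading need to agree on the pickle module. The module name `fake_script` below is hypothetical and only stands in for the pruning script; the snippet uses the standard library only and does not reproduce the exact error from this issue.

```python
import pickle
import sys
import types

# Hypothetical stand-in for the training script, where C2f_v2 is defined.
script = types.ModuleType("fake_script")
sys.modules["fake_script"] = script
exec("class C2f_v2:\n    def __init__(self):\n        self.c = 8\n", script.__dict__)
cls = script.C2f_v2

# pickle records the path "fake_script.C2f_v2" plus the instance state,
# not the class body itself.
payload = pickle.dumps(cls())

# Simulate a loader that cannot see the class definition.
del script.C2f_v2
try:
    pickle.loads(payload)
    error = None
except AttributeError as exc:
    error = str(exc)
print("load without class:", error)

# Once the class is reachable under the recorded path again, loading works.
script.C2f_v2 = cls
restored = pickle.loads(payload)
print("load with class:", restored.c)
```

The first load fails with an AttributeError of the same "Can't get attribute ..." family as the one in this issue; the second succeeds because the class is importable again.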

Ramzes30765 commented 1 month ago

@jenyaMi Hello! Thank you for the information.

Should I add the pickle_module=dill parameter to torch.save() too?

Ramzes30765 commented 1 month ago

I still get errors:

Image sizes 640 train, 640 val
Using 8 dataloader workers
Logging results to runs/detect/step_1_finetune
Starting training for 10 epochs...
Closing dataloader mosaic

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
  0%|          | 0/8 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/dalekseenko/models_optimizing/ultralytics/yolov8_pruning.py", line 396, in <module>
    prune(args)
  File "/home/dalekseenko/models_optimizing/ultralytics/yolov8_pruning.py", line 357, in prune
    model.train_v2(pruning=True, **pruning_cfg)
  File "/home/dalekseenko/models_optimizing/ultralytics/yolov8_pruning.py", line 266, in train_v2
    self.trainer.train()
  File "/home/dalekseenko/models_optimizing/ultralytics/ultralytics/engine/trainer.py", line 204, in train
    self._do_train(world_size)
  File "/home/dalekseenko/models_optimizing/ultralytics/ultralytics/engine/trainer.py", line 381, in _do_train
    self.loss, self.loss_items = self.model(batch)
                                 ^^^^^^^^^^^^^^^^^
  File "/home/dalekseenko/models_optimizing/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dalekseenko/models_optimizing/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dalekseenko/models_optimizing/ultralytics/ultralytics/nn/tasks.py", line 101, in forward
    return self.loss(x, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dalekseenko/models_optimizing/ultralytics/ultralytics/nn/tasks.py", line 283, in loss
    return self.criterion(preds, batch)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dalekseenko/models_optimizing/ultralytics/ultralytics/utils/loss.py", line 229, in __call__
    pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dalekseenko/models_optimizing/ultralytics/ultralytics/utils/loss.py", line 201, in bbox_decode
    pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat in method wrapper_CUDA_addmv_)
Exception in thread Thread-42 (_pin_memory_loop):
Traceback (most recent call last):
  File "/usr/lib/python3.12/threading.py", line 1073, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.12/threading.py", line 1010, in run
    self._target(*self._args, **self._kwargs)
  File "/home/dalekseenko/models_optimizing/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/pin_memory.py", line 54, in _pin_memory_loop
    do_one_step()
  File "/home/dalekseenko/models_optimizing/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/pin_memory.py", line 31, in do_one_step
    r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/queues.py", line 122, in get
    return _ForkingPickler.loads(res)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dalekseenko/models_optimizing/.venv/lib/python3.12/site-packages/torch/multiprocessing/reductions.py", line 495, in rebuild_storage_fd
    fd = df.detach()
         ^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/resource_sharer.py", line 57, in detach
    with _resource_sharer.get_connection(self._id) as conn:
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/resource_sharer.py", line 86, in get_connection
    c = Client(address, authkey=process.current_process().authkey)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/connection.py", line 525, in Client
    answer_challenge(c, authkey)
  File "/usr/lib/python3.12/multiprocessing/connection.py", line 953, in answer_challenge
    message = connection.recv_bytes(256)         # reject large message
              ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/connection.py", line 216, in recv_bytes
    buf = self._recv_bytes(maxlength)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/connection.py", line 430, in _recv_bytes
    buf = self._recv(4)
          ^^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/connection.py", line 395, in _recv
    chunk = read(handle, remaining)
            ^^^^^^^^^^^^^^^^^^^^^^^
ConnectionResetError: [Errno 104] Connection reset by peer
YMCAlan commented 1 month ago

I encountered the same issue, but after making modifications, it was able to run successfully. However, I am now facing an "index error".

# This code is adapted from Issue [#147](https://github.com/VainF/Torch-Pruning/issues/147), implemented by @Hyunseok-Kim0.
import argparse
import math
import os
from copy import deepcopy
from datetime import datetime
from pathlib import Path
from typing import List, Union

import dill
import numpy as np
import torch
import torch.nn as nn
from matplotlib import pyplot as plt
from ultralytics import YOLO, __version__
from ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir
from ultralytics.nn.modules import Detect, C2f, Conv, Bottleneck
from ultralytics.nn.tasks import attempt_load_one_weight
from ultralytics.engine.trainer import BaseTrainer
from ultralytics.utils import yaml_load, LOGGER, RANK, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS
from ultralytics.utils import (
    ARGV,
    ASSETS,
    DEFAULT_CFG_DICT,
    LOGGER,
    RANK,
    callbacks,
    checks,
    emojis,
    yaml_load,
)
from ultralytics.utils.torch_utils import initialize_weights, de_parallel
from ultralytics.engine.model import Model
from ultralytics.models.yolo.model import DetectionModel
import torch_pruning as tp

def save_pruning_performance_graph(x, y1, y2, y3):
    """
    Draw performance change graph
    Parameters
    ----------
    x : List
        Parameter numbers of all pruning steps
    y1 : List
        mAPs after fine-tuning of all pruning steps
    y2 : List
        MACs of all pruning steps
    y3 : List
        mAPs after pruning (not fine-tuned) of all pruning steps

    Returns
    -------

    """
    try:
        plt.style.use("ggplot")
    except:
        pass

    x, y1, y2, y3 = np.array(x), np.array(y1), np.array(y2), np.array(y3)
    y2_ratio = y2 / y2[0]

    # create the figure and the axis object
    fig, ax = plt.subplots(figsize=(8, 6))

    # plot the pruned mAP and recovered mAP
    ax.set_xlabel('Pruning Ratio')
    ax.set_ylabel('mAP')
    ax.plot(x, y1, label='recovered mAP')
    ax.scatter(x, y1)
    ax.plot(x, y3, color='tab:gray', label='pruned mAP')
    ax.scatter(x, y3, color='tab:gray')

    # create a second axis that shares the same x-axis
    ax2 = ax.twinx()

    # plot the second set of data
    ax2.set_ylabel('MACs')
    ax2.plot(x, y2_ratio, color='tab:orange', label='MACs')
    ax2.scatter(x, y2_ratio, color='tab:orange')

    # add a legend
    lines, labels = ax.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax2.legend(lines + lines2, labels + labels2, loc='best')

    ax.set_xlim(105, -5)
    ax.set_ylim(0, max(y1) + 0.05)
    ax2.set_ylim(0.05, 1.05)

    # calculate the highest and lowest points for each set of data
    max_y1_idx = np.argmax(y1)
    min_y1_idx = np.argmin(y1)
    max_y2_idx = np.argmax(y2)
    min_y2_idx = np.argmin(y2)
    max_y1 = y1[max_y1_idx]
    min_y1 = y1[min_y1_idx]
    max_y2 = y2_ratio[max_y2_idx]
    min_y2 = y2_ratio[min_y2_idx]

    # add text for the highest and lowest values near the points
    ax.text(x[max_y1_idx], max_y1 - 0.05, f'max mAP = {max_y1:.2f}', fontsize=10)
    ax.text(x[min_y1_idx], min_y1 + 0.02, f'min mAP = {min_y1:.2f}', fontsize=10)
    ax2.text(x[max_y2_idx], max_y2 - 0.05, f'max MACs = {max_y2 * y2[0] / 1e9:.2f}G', fontsize=10)
    ax2.text(x[min_y2_idx], min_y2 + 0.02, f'min MACs = {min_y2 * y2[0] / 1e9:.2f}G', fontsize=10)

    plt.title('Comparison of mAP and MACs with Pruning Ratio')
    plt.savefig('pruning_perf_change.png')

def infer_shortcut(bottleneck):
    c1 = bottleneck.cv1.conv.in_channels
    c2 = bottleneck.cv2.conv.out_channels
    return c1 == c2 and hasattr(bottleneck, 'add') and bottleneck.add

class C2f_v2(nn.Module):
    # CSP Bottleneck with 2 convolutions
    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        self.c = int(c2 * e)  # hidden channels
        self.cv0 = Conv(c1, self.c, 1, 1)
        self.cv1 = Conv(c1, self.c, 1, 1)
        self.cv2 = Conv((2 + n) * self.c, c2, 1)  # optional act=FReLU(c2)
        self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))

    def forward(self, x):
        # y = list(self.cv1(x).chunk(2, 1))
        y = [self.cv0(x), self.cv1(x)]
        y.extend(m(y[-1]) for m in self.m)
        return self.cv2(torch.cat(y, 1))

def transfer_weights(c2f, c2f_v2):
    c2f_v2.cv2 = c2f.cv2
    c2f_v2.m = c2f.m

    state_dict = c2f.state_dict()
    state_dict_v2 = c2f_v2.state_dict()

    # Transfer cv1 weights from C2f to cv0 and cv1 in C2f_v2
    old_weight = state_dict['cv1.conv.weight']
    half_channels = old_weight.shape[0] // 2
    state_dict_v2['cv0.conv.weight'] = old_weight[:half_channels]
    state_dict_v2['cv1.conv.weight'] = old_weight[half_channels:]

    # Transfer cv1 batchnorm weights and buffers from C2f to cv0 and cv1 in C2f_v2
    for bn_key in ['weight', 'bias', 'running_mean', 'running_var']:
        old_bn = state_dict[f'cv1.bn.{bn_key}']
        state_dict_v2[f'cv0.bn.{bn_key}'] = old_bn[:half_channels]
        state_dict_v2[f'cv1.bn.{bn_key}'] = old_bn[half_channels:]

    # Transfer remaining weights and buffers
    for key in state_dict:
        if not key.startswith('cv1.'):
            state_dict_v2[key] = state_dict[key]

    # Transfer all non-method attributes
    for attr_name in dir(c2f):
        attr_value = getattr(c2f, attr_name)
        if not callable(attr_value) and '_' not in attr_name:
            setattr(c2f_v2, attr_name, attr_value)

    c2f_v2.load_state_dict(state_dict_v2)

def replace_c2f_with_c2f_v2(module):
    for name, child_module in module.named_children():
        if isinstance(child_module, C2f):
            # Replace C2f with C2f_v2 while preserving its parameters
            shortcut = infer_shortcut(child_module.m[0])
            c2f_v2 = C2f_v2(child_module.cv1.conv.in_channels, child_module.cv2.conv.out_channels,
                            n=len(child_module.m), shortcut=shortcut,
                            g=child_module.m[0].cv2.conv.groups,
                            e=child_module.c / child_module.cv2.conv.out_channels)
            transfer_weights(child_module, c2f_v2)
            setattr(module, name, c2f_v2)
        else:
            replace_c2f_with_c2f_v2(child_module)

def save_model_v2(self: BaseTrainer):
    """Save model training checkpoints with additional metadata."""
    import io

    import pandas as pd  # scope for faster 'import ultralytics'

    # Serialize ckpt to a byte buffer once (faster than repeated torch.save() calls)
    buffer = io.BytesIO()
    torch.save(
        {
            "epoch": self.epoch,
            "best_fitness": self.best_fitness,
            "model": None,  # resume and final checkpoints derive from EMA
            "ema": deepcopy(self.ema.ema),
            "updates": self.ema.updates,
            "optimizer": deepcopy(self.optimizer.state_dict()),
            "train_args": vars(self.args),  # save as dict
            "train_metrics": {**self.metrics, **{"fitness": self.fitness}},
            "train_results": {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()},
            "date": datetime.now().isoformat(),
            "version": __version__,
            "license": "AGPL-3.0 (https://ultralytics.com/license)",
            "docs": "https://docs.ultralytics.com",
        },
        buffer,
    )
    serialized_ckpt = buffer.getvalue()  # get the serialized content to save

    # Save checkpoints
    self.last.write_bytes(serialized_ckpt)  # save last.pt
    if self.best_fitness == self.fitness:
        self.best.write_bytes(serialized_ckpt)  # save best.pt
    if (self.save_period > 0) and (self.epoch > 0) and (self.epoch % self.save_period == 0):
        (self.wdir / f"epoch{self.epoch}.pt").write_bytes(serialized_ckpt)  # save epoch, i.e. 'epoch3.pt'

def final_eval_v2(self: BaseTrainer):
    """Performs final evaluation and validation for object detection YOLO model."""
    for f in self.last, self.best:
        if f.exists():
            strip_optimizer_v2(f)  # strip optimizers
            if f is self.best:
                LOGGER.info(f"\nValidating {f}...")
                self.validator.args.plots = self.args.plots
                self.metrics = self.validator(model=f)
                self.metrics.pop("fitness", None)
                self.run_callbacks("on_fit_epoch_end")

def strip_optimizer_v2(f: Union[str, Path] = "best.pt", s: str = "") -> None:
    """
    Strip optimizer from 'f' to finalize training, optionally save as 's'.

    Args:
        f (str): file path to model to strip the optimizer from. Default is 'best.pt'.
        s (str): file path to save the model with stripped optimizer to. If not provided, 'f' will be overwritten.

    Returns:
        None

    Example:
        ```python
        from pathlib import Path
        from ultralytics.utils.torch_utils import strip_optimizer

        for f in Path('path/to/model/checkpoints').rglob('*.pt'):
            strip_optimizer(f)
        ```
    """
    try:
        x = torch.load(f, map_location=torch.device("cpu"))
        assert isinstance(x, dict), "checkpoint is not a Python dictionary"
        assert "model" in x, "'model' missing from checkpoint"
    except Exception as e:
        LOGGER.warning(f"WARNING ⚠️ Skipping {f}, not a valid Ultralytics model: {e}")
        return

    updates = {
        "date": datetime.now().isoformat(),
        "version": __version__,
        "license": "AGPL-3.0 License (https://ultralytics.com/license)",
        "docs": "https://docs.ultralytics.com",
    }

    # Update model
    if x.get("ema"):
        x["model"] = x["ema"]  # replace model with EMA
    if hasattr(x["model"], "args"):
        x["model"].args = dict(x["model"].args)  # convert from IterableSimpleNamespace to dict
    if hasattr(x["model"], "criterion"):
        x["model"].criterion = None  # strip loss criterion
    # x["model"].half()  # to FP16
    for p in x["model"].parameters():
        p.requires_grad = False

    # Update other keys
    args = {**DEFAULT_CFG_DICT, **x.get("train_args", {})}  # combine args
    for k in "optimizer", "best_fitness", "ema", "updates":  # keys
        x[k] = None
    x["epoch"] = -1
    x["train_args"] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS}  # strip non-default keys
    # x['model'].args = x['train_args']

    # Save
    torch.save({**updates, **x}, s or f, use_dill=False)  # combine dicts (prefer to the right)
    mb = os.path.getsize(s or f) / 1e6  # file size
    LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")

def train_v2(self: YOLO, trainer=None, pruning=False, **kwargs):
    """
    Disabled loading new model when pruning flag is set. originated from ultralytics/yolo/engine/model.py
    """

    self._check_is_pytorch_model()
    if hasattr(self.session, "model") and self.session.model.id:  # Ultralytics HUB session with loaded model
        if any(kwargs):
            LOGGER.warning("WARNING ⚠️ using HUB training arguments, ignoring local training arguments.")
        kwargs = self.session.train_args  # overwrite kwargs

    checks.check_pip_update_available()

    overrides = yaml_load(checks.check_yaml(kwargs["cfg"])) if kwargs.get("cfg") else self.overrides
    custom = {
        # NOTE: handle the case when 'cfg' includes 'data'.
        "data": overrides.get("data") or DEFAULT_CFG_DICT["data"] or TASK2DATA[self.task],
        "model": self.overrides["model"],
        "task": self.task,
    }  # method defaults
    args = {**overrides, **custom, **kwargs, "mode": "train"}  # highest priority args on the right
    if args.get("resume"):
        args["resume"] = self.ckpt_path

    self.trainer = (trainer or self._smart_load("trainer"))(overrides=args, _callbacks=self.callbacks)

    if not pruning:
        if not overrides.get('resume'):  # manually set model only if not resuming
            self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
            self.model = self.trainer.model

    else:
        # pruning mode
        self.trainer.pruning = True
        self.trainer.model = self.model

        # replace some functions to disable half precision saving
        self.trainer.__setattr__("save_model", save_model_v2.__get__(self.trainer))
        self.trainer.__setattr__("final_eval", final_eval_v2.__get__(self.trainer))
        # self.trainer.save_model = save_model_v2.__get__(self.trainer)
        # self.trainer.final_eval = final_eval_v2.__get__(self.trainer)

    self.trainer.hub_session = self.session  # attach optional HUB session
    self.trainer.train()
    # Update model and cfg after training
    if RANK in {-1, 0}:
        ckpt = self.trainer.best if self.trainer.best.exists() else self.trainer.last
        self.model, _ = attempt_load_one_weight(ckpt)
        self.overrides = self.model.args
        self.metrics = getattr(self.trainer.validator, "metrics", None)  # TODO: no metrics returned by DDP
    return self.metrics

def prune(args):
    # load trained yolov8 model
    model = YOLO(args.model)
    model.__setattr__("train", train_v2.__get__(model))
    pruning_cfg = yaml_load(checks.check_yaml(args.cfg))
    batch_size = pruning_cfg['batch']

    # use coco128 dataset for 10 epochs fine-tuning each pruning iteration step
    # this part is only for sample code, number of epochs should be included in config file
    pruning_cfg['data'] = "data.yaml"
    pruning_cfg['epochs'] = 1

    model.model.train()
    replace_c2f_with_c2f_v2(model.model)
    initialize_weights(model.model)  # set BN.eps, momentum, ReLU.inplace

    for name, param in model.model.named_parameters():
        param.requires_grad = True

    example_inputs = torch.randn(1, 3, pruning_cfg["imgsz"], pruning_cfg["imgsz"]).to(model.device)
    macs_list, nparams_list, map_list, pruned_map_list = [], [], [], []
    base_macs, base_nparams = tp.utils.count_ops_and_params(model.model, example_inputs)

    # do validation before pruning model
    pruning_cfg['name'] = f"baseline_val"
    pruning_cfg['batch'] = 1
    validation_model = deepcopy(model)
    metric = validation_model.val(**pruning_cfg)
    init_map = metric.box.map
    macs_list.append(base_macs)
    nparams_list.append(100)
    map_list.append(init_map)
    pruned_map_list.append(init_map)
    print(f"Before Pruning: MACs={base_macs / 1e9: .5f} G, #Params={base_nparams / 1e6: .5f} M, mAP={init_map: .5f}")

    # prune same ratio of filter based on initial size
    pruning_ratio = 1 - math.pow((1 - args.target_prune_rate), 1 / args.iterative_steps)

    for i in range(args.iterative_steps):
        model.info()
        model.model.train()
        for name, param in model.model.named_parameters():
            param.requires_grad = True

        ignored_layers = []
        unwrapped_parameters = []
        for m in model.model.modules():
            if isinstance(m, (Detect,)):
                ignored_layers.append(m)

        example_inputs = example_inputs.to(model.device)
        pruner = tp.pruner.MagnitudePruner(
            model.model,
            example_inputs,
            importance=tp.importance.MagnitudeImportance(p=2),
            iterative_steps=1,
            pruning_ratio=pruning_ratio,  # remove 50% channels
            ignored_layers=ignored_layers,
            unwrapped_parameters=unwrapped_parameters
        )

        # Test regularization
        # output = model.model(example_inputs)
        # (output[0].sum() + sum([o.sum() for o in output[1]])).backward()
        # pruner.regularize(model.model)

        pruner.step()

        # pre fine-tuning validation
        pruning_cfg['name'] = f"step_{i}_pre_val"
        pruning_cfg['batch'] = 1
        validation_model.model = deepcopy(model.model)
        metric = validation_model.val(**pruning_cfg)
        pruned_map = metric.box.map
        pruned_macs, pruned_nparams = tp.utils.count_ops_and_params(pruner.model, example_inputs.to(model.device))
        current_speed_up = float(macs_list[0]) / pruned_macs
        print(f"After pruning iter {i + 1}: MACs={pruned_macs / 1e9} G, #Params={pruned_nparams / 1e6} M, "
              f"mAP={pruned_map}, speed up={current_speed_up}")

        # fine-tuning
        for name, param in model.model.named_parameters():
            param.requires_grad = True
        pruning_cfg['name'] = f"step_{i}_finetune"
        pruning_cfg['batch'] = batch_size  # restore batch size
        model.train(**pruning_cfg)

        # post fine-tuning validation
        pruning_cfg['name'] = f"step_{i}_post_val"
        pruning_cfg['batch'] = 1
        validation_model = YOLO(model.trainer.best)
        metric = validation_model.val(**pruning_cfg)
        current_map = metric.box.map
        print(f"After fine tuning mAP={current_map}")

        macs_list.append(pruned_macs)
        nparams_list.append(pruned_nparams / base_nparams * 100)
        pruned_map_list.append(pruned_map)
        map_list.append(current_map)

        # remove pruner after single iteration
        del pruner

        save_pruning_performance_graph(nparams_list, map_list, macs_list, pruned_map_list)

        if init_map - current_map > args.max_map_drop:
            print("Pruning early stop")
            break

    model.export(format='ncnn', half=True)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', default='runs/drink_auto/weights/best.pt',
                        help='Pretrained pruning target model file')
    parser.add_argument('--cfg', default='default.yaml',
                        help='Pruning config file.'
                             ' This file should have same format with ultralytics/yolo/cfg/default.yaml')
    parser.add_argument('--iterative-steps', default=16, type=int, help='Total pruning iteration step')
    parser.add_argument('--target-prune-rate', default=0.5, type=float, help='Target pruning rate')
    parser.add_argument('--max-map-drop', default=1.0, type=float, help='Allowed maximum map drop after fine-tuning')

    args = parser.parse_args()

    prune(args)
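
As a side note on the per-step ratio used in prune(): the expression 1 - (1 - target_prune_rate) ** (1 / iterative_steps) picks a ratio that, if each step removed that fraction of whatever is left, would compound to the overall target. Here is a small sketch of that arithmetic using the script's argparse defaults (0.5 and 16). The helper name `per_step_ratio` is made up for illustration, and the sketch only models the idealized compounding, not how Torch-Pruning actually distributes channel removal across layers.

```python
import math


def per_step_ratio(target_prune_rate: float, iterative_steps: int) -> float:
    """Fraction to remove at each step so the compounded removal reaches the target."""
    return 1 - math.pow(1 - target_prune_rate, 1 / iterative_steps)


target, steps = 0.5, 16  # defaults of --target-prune-rate and --iterative-steps
ratio = per_step_ratio(target, steps)

remaining = 1.0
for _ in range(steps):
    remaining *= 1 - ratio  # each step keeps (1 - ratio) of what is left

print(f"per-step ratio: {ratio:.4f}")
print(f"removed after {steps} steps: {1 - remaining:.4f}")
```

With these defaults the per-step ratio comes out at roughly 4.2%, so an early stop triggered by --max-map-drop leaves the model well short of the 50% target.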