VainF / Torch-Pruning

[CVPR 2023] DepGraph: Towards Any Structural Pruning
https://arxiv.org/abs/2301.12900
MIT License

Pruning yolov8 segmentation model #228

Open Ivantheginger opened 1 year ago

Ivantheginger commented 1 year ago

I'm trying to run the script at https://github.com/VainF/Torch-Pruning/tree/master/examples/yolov8 on a segmentation model (yolov8n-seg). I added these lines at the start of prune():

    pruning_cfg['task'] = args.task = 'segment'
    batch_size = pruning_cfg['batch'] = 4
    pruning_cfg['imgsz'] = 512 

I get this error during training:

Traceback (most recent call last):
  File "/home/ivan/Projects/ultralytics/yolov8_pruning.py", line 402, in <module>
    prune(args)
  File "/home/ivan/Projects/ultralytics/yolov8_pruning.py", line 362, in prune
    model.train_v2(pruning=True, **pruning_cfg)
  File "/home/ivan/Projects/ultralytics/yolov8_pruning.py", line 268, in train_v2
    self.trainer.train()
  File "/media/ivan/68F413A6F4137592/Projects/ultralytics/ultralytics/yolo/engine/trainer.py", line 191, in train
    self._do_train(world_size)
  File "/media/ivan/68F413A6F4137592/Projects/ultralytics/ultralytics/yolo/engine/trainer.py", line 364, in _do_train
    self.metrics, self.fitness = self.validate()
  File "/media/ivan/68F413A6F4137592/Projects/ultralytics/ultralytics/yolo/engine/trainer.py", line 464, in validate
    metrics = self.validator(self)
  File "/media/ivan/68F413A6F4137592/Projects/ultralytics/.venv/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/media/ivan/68F413A6F4137592/Projects/ultralytics/ultralytics/yolo/engine/validator.py", line 169, in __call__
    self.update_metrics(preds, batch)
  File "/media/ivan/68F413A6F4137592/Projects/ultralytics/ultralytics/yolo/v8/segment/val.py", line 82, in update_metrics
    pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=batch['img'][si].shape[1:])
  File "/media/ivan/68F413A6F4137592/Projects/ultralytics/ultralytics/yolo/utils/ops.py", line 596, in process_mask
    masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)  # CHW
RuntimeError: mat1 and mat2 shapes cannot be multiplied (300x111 and 32x13056)
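
For reference, the product in `process_mask` only works when the number of mask coefficients per detection (the columns of `masks_in`) equals the number of prototype channels (`c`). Here those are 111 and 32. Below is a minimal pure-Python sketch of that check. Both helpers are hypothetical and are not Ultralytics functions. `inferred_mask_dim` assumes the mask width is whatever is left after 4 box values and `nc` class scores are split off, which may not match the Ultralytics version in the traceback. Under that assumption, a head that still predicts 80 classes but is validated with `nc=1` would give exactly 111, so comparing the model's class count with the dataset's may be worth a try.

```python
def mask_matmul_check(masks_in_shape, protos_flat_shape):
    """Return None if (n, k) @ (c, h*w) is a valid product, else an explanation."""
    n, k = masks_in_shape
    c, hw = protos_flat_shape
    if k == c:
        return None
    return (f"{k} mask coefficients per detection, but {c} prototype channels; "
            f"({n}, {k}) @ ({c}, {hw}) is undefined")


def inferred_mask_dim(pred_channels, nc, box_dims=4):
    """Hypothetical bookkeeping: channels left after box values and class scores."""
    return pred_channels - box_dims - nc


# Shapes copied from the RuntimeError above.
print(mask_matmul_check((300, 111), (32, 13056)))

# Illustrative numbers only (not taken from the issue): an 80-class head with
# 32 mask coefficients emits 4 + 80 + 32 = 116 channels per prediction.
print(inferred_mask_dim(116, nc=80))  # 32
print(inferred_mask_dim(116, nc=1))   # 111
```

If the class counts do match, the extra columns must come from somewhere else (for example a pruned head), and this arithmetic does not apply.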

I also tried the latest version of the ultralytics repo: I took the lines that were added in train_v2 and added them to train from the main branch.

def train_v2(self, trainer=None, pruning=False, **kwargs):
        """
        Trains the model on a given dataset.

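        Args:
            trainer (BaseTrainer, optional): Customized trainer.
            pruning (bool): If True, train the already-pruned self.model as-is instead of rebuilding it from the YAML config.
            **kwargs (Any): Any number of arguments representing the training configuration.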
        """
        self._check_is_pytorch_model()
        if self.session:  # Ultralytics HUB session
            if any(kwargs):
                LOGGER.warning('WARNING ⚠️ using HUB training arguments, ignoring local training arguments.')
            kwargs = self.session.train_args
        overrides = self.overrides.copy()
        if kwargs.get('cfg'):
            LOGGER.info(f"cfg file passed. Overriding default params with {kwargs['cfg']}.")
            overrides = yaml_load(check_yaml(kwargs['cfg']))
        overrides.update(kwargs)
        overrides['mode'] = 'train'
        if not overrides.get('data'):
            raise AttributeError("Dataset required but missing, i.e. pass 'data=coco128.yaml'")
        if overrides.get('resume'):
            overrides['resume'] = self.ckpt_path
        self.task = overrides.get('task') or self.task
        trainer = trainer or self.smart_load('trainer')
        self.trainer = trainer(overrides=overrides, _callbacks=self.callbacks)
        if not overrides.get('resume'):  # manually set model only if not resuming
            self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
            self.model = self.trainer.model

        # ==== prune
        if not pruning:
            if not overrides.get('resume'):  # manually set model only if not resuming
                self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
                self.model = self.trainer.model
        else:
            # pruning mode
            self.trainer.pruning = True
            self.trainer.model = self.model

            # replace some functions to disable half precision saving
            self.trainer.save_model = save_model_v2.__get__(self.trainer)
            self.trainer.final_eval = final_eval_v2.__get__(self.trainer)
        # ==== prune

        self.trainer.hub_session = self.session  # attach optional HUB session
        self.trainer.train()
        # Update model and cfg after training
        if RANK in (-1, 0):
            self.model, _ = attempt_load_one_weight(str(self.trainer.best))
            self.overrides = self.model.args
            self.metrics = getattr(self.trainer.validator, 'metrics', None)  # TODO: no metrics returned by DDP

With this train_v2, the script works for the latest yolov8 with task=detect, but when I run it with task=segment I get a different error, which occurs on the second pruning iteration.

PRUNER STEP
Traceback (most recent call last):
  File "/home/ivan/Projects/ultralytics/yolov8_pruning 2.py", line 456, in <module>
    prune(args)
  File "/home/ivan/Projects/ultralytics/yolov8_pruning 2.py", line 398, in prune
    pruner.step()
  File "/media/ivan/68F413A6F4137592/Projects/ultralytics/.venv/lib/python3.9/site-packages/torch_pruning/pruner/algorithms/metapruner.py", line 162, in step
    for group in pruning_fn():
  File "/media/ivan/68F413A6F4137592/Projects/ultralytics/.venv/lib/python3.9/site-packages/torch_pruning/pruner/algorithms/metapruner.py", line 212, in prune_local
    imp = self.estimate_importance(group, ch_groups=ch_groups)
  File "/media/ivan/68F413A6F4137592/Projects/ultralytics/.venv/lib/python3.9/site-packages/torch_pruning/pruner/algorithms/metapruner.py", line 166, in estimate_importance
    return self.importance(group, ch_groups=ch_groups)
  File "/media/ivan/68F413A6F4137592/Projects/ultralytics/.venv/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/media/ivan/68F413A6F4137592/Projects/ultralytics/.venv/lib/python3.9/site-packages/torch_pruning/pruner/importance.py", line 145, in __call__
    local_imp = local_imp[idxs]
IndexError: index 768 is out of bounds for dimension 0 with size 384
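
Read literally, the IndexError means an index list that reaches at least channel 768 was applied to an importance vector with only 384 entries. A small guard like the one below could help locate the offending group before indexing. It is only a sketch: `out_of_range` is a hypothetical helper, not part of Torch-Pruning, and the example indices are made up apart from 768 and 384.

```python
def out_of_range(idxs, size):
    """Return the entries of `idxs` that cannot index a length-`size` vector."""
    return [i for i in idxs if not (-size <= i < size)]


# Illustrative values: only 768 and 384 come from the IndexError above.
idxs = [0, 383, 384, 768]
print(out_of_range(idxs, 384))  # [384, 768]
```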

Is there a way to make it work for segmentation?

apanand14 commented 1 year ago

Hello!! Is there any update? I'm also looking for the same.