VainF / Torch-Pruning

[CVPR 2023] DepGraph: Towards Any Structural Pruning
https://arxiv.org/abs/2301.12900
MIT License
2.73k stars 335 forks source link

RecursionError during pruning #43

Open nikhil153 opened 3 years ago

nikhil153 commented 3 years ago

Hi - thanks for a wonderful tool. I am trying to test it out with a pretrained model from here. However I am encountering the following error:

module name: encode1.conv0
pruning_idxs: [4, 5, 8, 9, 13, 14, 16, 17, 18, 19, 21, 23, 24, 26, 28, 30, 31, 32, 35, 36, 37, 38, 39, 40, 41, 43, 46, 48, 49, 53, 56, 58]
Traceback (most recent call last):
  File "/home/nikhil/projects/green_comp_neuro/FastSurfer/FastSurferCNN/torch_prune_test.py", line 159, in <module>
    load_pretrained(pretrained_ckpt, params_model, model, dummy_data, save_path)
  File "/home/nikhil/projects/green_comp_neuro/FastSurfer/FastSurferCNN/torch_prune_test.py", line 120, in load_pretrained
    model = torch_prune(model, dummy_data, params_model['prune_type'], params_model['prune_percent'])
  File "/home/nikhil/projects/green_comp_neuro/FastSurfer/FastSurferCNN/torch_prune_test.py", line 92, in torch_prune
    pruning_plan = DG.get_pruning_plan( module, tp.prune_conv, idxs=pruning_idxs )
  File "../../Torch-Pruning/torch_pruning/dependency.py", line 398, in get_pruning_plan
    _fix_denpendency_graph(root_node, pruning_fn, idxs)
  File "../../Torch-Pruning/torch_pruning/dependency.py", line 397, in _fix_denpendency_graph
    _fix_denpendency_graph(dep.broken_node, dep.handler, new_indices)
  File "../../Torch-Pruning/torch_pruning/dependency.py", line 397, in _fix_denpendency_graph
    _fix_denpendency_graph(dep.broken_node, dep.handler, new_indices)
  File "../../Torch-Pruning/torch_pruning/dependency.py", line 397, in _fix_denpendency_graph
    _fix_denpendency_graph(dep.broken_node, dep.handler, new_indices)
  [Previous line repeated 990 more times]
  File "../../Torch-Pruning/torch_pruning/dependency.py", line 387, in _fix_denpendency_graph
    new_indices = dep.index_transform(indices)
  File "../../Torch-Pruning/torch_pruning/dependency.py", line 148, in __call__
    if self.reverse==True:
RecursionError: maximum recursion depth exceeded in comparison

The network architecture is based on this paper. Here is a figure showing the details: image

Below is my test script, which uses the model definition and pretrained weights from the model repo:

# IMPORTS
import argparse
import nibabel as nib
import numpy as np
from datetime import datetime
import time
import sys
import os
import glob
import os.path as op
import logging
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

from torch.autograd import Variable
from torch.utils.data.dataloader import DataLoader
from torchvision import transforms, utils

from scipy.ndimage.filters import median_filter, gaussian_filter
from skimage.measure import label, regionprops
from skimage.measure import label

from collections import OrderedDict
from os import makedirs

from models.networks import FastSurferCNN
import pandas as pd

# torch-pruning
sys.path.append('../../Torch-Pruning')
import torch_pruning as tp

def options_parse():
    """
    Parse the command-line flags that describe the FastSurferCNN architecture.

    Every option is an integer; only change the defaults if the model was
    trained with different settings.

    Returns:
        argparse.Namespace with the attributes num_filters, num_classes_ax_cor,
        num_classes_sag, num_channels, kernel_height, kernel_width, stride,
        stride_pool and pool.
    """
    cli = argparse.ArgumentParser()

    # (flag, default, help text) -- registered in this exact order.
    int_flags = (
        ('--num_filters', 64,
         'Filter dimensions for DenseNet (all layers same). Default=64'),
        ('--num_classes_ax_cor', 79,
         'Number of classes to predict in axial and coronal net, including background. Default=79'),
        ('--num_classes_sag', 51,
         'Number of classes to predict in sagittal net, including background. Default=51'),
        ('--num_channels', 7,
         'Number of input channels. Default=7 (thick slices)'),
        ('--kernel_height', 5, 'Height of Kernel (Default 5)'),
        ('--kernel_width', 5, 'Width of Kernel (Default 5)'),
        ('--stride', 1, 'Stride during convolution (Default 1)'),
        ('--stride_pool', 2, 'Stride during pooling (Default 2)'),
        ('--pool', 2, 'Size of pooling filter (Default 2)'),
    )
    for flag, fallback, description in int_flags:
        cli.add_argument(flag, type=int, default=fallback, help=description)

    return cli.parse_args()

def torch_prune(model, dummy_data, prune_type, prune_percent, skip_modules=()):
    """Structurally prune every ``torch.nn.Conv2d`` in *model*, in place.

    For each conv layer, a Torch-Pruning strategy picks channel indices from
    the layer's weight. A pruning plan is then built from the layer dependency
    graph, so coupled layers are pruned consistently, and the plan is executed.

    Args:
        model: network to prune; it is modified in place.
        dummy_data: example input used to trace the dependency graph. If it is
            a tensor, it is moved to the device of the model's parameters
            first, because tracing runs a forward pass.
        prune_type: ``'random'`` (case-insensitive) selects
            ``tp.strategy.RandomStrategy``; any other value uses
            ``tp.strategy.L1Strategy``.
        prune_percent: fraction passed to the strategy as ``amount``.
        skip_modules: names (as yielded by ``model.named_modules()``) of conv
            layers that must not be pruned. An example is a final classifier
            conv whose output-channel count is presumably the number of
            classes; confirm the layer name for your model.

    Returns:
        The pruned *model* (the same object that was passed in).
    """
    print(f'compressing model with prune type: {prune_type}, sparsity: {prune_percent}')

    # 1. setup strategy (L1 norm by default, random on request)
    if str(prune_type).lower() == 'random':
        strategy = tp.strategy.RandomStrategy()
    else:
        strategy = tp.strategy.L1Strategy()

    # 2. build the layer dependency graph. The example input must be on the
    # same device as the weights, otherwise tracing fails on GPU.
    first_param = next(model.parameters(), None)
    if first_param is not None and isinstance(dummy_data, torch.Tensor):
        dummy_data = dummy_data.to(first_param.device)
    DG = tp.DependencyGraph()
    DG.build_dependency(model, example_inputs=dummy_data)

    # Fix the list of layers to visit before any plan is executed.
    skip = set(skip_modules)
    conv_layers = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, torch.nn.Conv2d) and name not in skip
    ]

    # 3. get a pruning plan from the dependency graph for each conv layer.
    for name, module in conv_layers:
        print(f'module name: {name}')

        # or manually selected pruning_idxs=[2, 6, 9, ...]
        pruning_idxs = strategy(module.weight, amount=prune_percent)
        print(f'pruning_idxs: {pruning_idxs}')
        if len(pruning_idxs) == 0:
            # Nothing selected for this layer; no plan to build.
            continue

        # NOTE: older Torch-Pruning releases resolved dependencies recursively
        # inside get_pruning_plan and could raise RecursionError on large or
        # densely connected models. A newer, non-recursive release avoids that.
        pruning_plan = DG.get_pruning_plan(module, tp.prune_conv, idxs=pruning_idxs)
        print(pruning_plan)

        # 4. execute this plan (prune the model)
        pruning_plan.exec()

    return model

def load_pretrained(pretrained_ckpt, params_model, model):
    """Load checkpoint weights into *model* and switch it to eval mode.

    The checkpoint's ``"model_state_dict"`` entry is loaded after adapting
    each key's ``"module."`` prefix (as added by wrappers such as
    ``nn.DataParallel``) to ``params_model["model_parallel"]``:

    * the flag is false: a leading ``"module."`` is stripped;
    * the flag is true: ``"module."`` is prepended to keys that lack it;
    * otherwise the key is kept unchanged.

    Args:
        pretrained_ckpt: path to a checkpoint readable by ``torch.load``.
            ``torch.load`` unpickles data, so only load trusted files.
        params_model: dict providing ``"device"`` (used as ``map_location``)
            and ``"model_parallel"``.
        model: module to receive the weights.

    Returns:
        The same *model*, with weights loaded and ``eval()`` applied.
    """
    checkpoint = torch.load(pretrained_ckpt, map_location=params_model["device"])
    prefix = "module."

    def _adapt_key(key):
        # FastSurfer model specific configs: match the wrapper prefix
        # convention expected by the target model.
        wrapped = params_model["model_parallel"]
        if key.startswith(prefix):
            return key if wrapped else key[len(prefix):]
        return prefix + key if wrapped else key

    remapped = OrderedDict(
        (_adapt_key(key), weights)
        for key, weights in checkpoint["model_state_dict"].items()
    )

    model.load_state_dict(remapped)
    model.eval()
    return model

if __name__ == "__main__":

    args = options_parse()

    plane = "Axial"
    pretrained_ckpt = f'../checkpoints/{plane}_Weights_FastSurferCNN/ckpts/Epoch_30_training_state.pkl'

    # Put it onto the GPU or CPU
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # Set up model for axial and coronal networks
    params_model = {'num_channels': args.num_channels, 'num_filters': args.num_filters,
                    'kernel_h': args.kernel_height, 'kernel_w': args.kernel_width,
                    'stride_conv': args.stride, 'pool': args.pool,
                    'stride_pool': args.stride_pool, 'num_classes': args.num_classes_ax_cor,
                    'kernel_c': 1, 'kernel_d': 1,
                    'model_parallel': False,
                    'device': device
                    }

    # Select the model
    model = FastSurferCNN(params_model)
    model.to(device)

    # Load pretrained weights
    model = load_pretrained(pretrained_ckpt, params_model, model)

    # Dummy input used only to trace the layer graph. It must have the same
    # channel count the model was built with and live on the model's device,
    # otherwise tracing fails (e.g. CPU input vs. CUDA weights).
    # 256x256 is the slice size used here; TODO confirm for other planes.
    dummy_data = torch.ones(1, args.num_channels, 256, 256, device=device)

    # Prune model. Pruning happens in place; don't rebind `model` to the
    # return value, since torch_prune may not return the model.
    torch_prune(model, dummy_data, prune_type='L1', prune_percent=0.5)

    # Save pruned model
    # save_path = f'./{plane}_pruned.pth'
    # torch.save(model, save_path)

I will appreciate any help or suggestions! Thanks!

nikhil153 commented 3 years ago

HI @VainF - Just checking if you have any update for this issue. Thanks!

Puranjay-del-Mishra commented 2 years ago

This is not a bug in the code per se. Python's default recursion limit is 1000 calls; the limit exists to avoid a stack overflow.

File "../../Torch-Pruning/torch_pruning/dependency.py", line 397, in _fix_denpendency_graph _fix_denpendency_graph(dep.broken_node, dep.handler, new_indices) [Previous line repeated 990 more times]

As you can see, the RecursionError was raised after `_fix_denpendency_graph` called itself about 990 more times, i.e. once the recursion limit was reached.

You can try raising Python's recursion limit (`sys.setrecursionlimit`) or lowering the pruning percentage. Pruning 50% of the smallest weights is why so many indices are chosen for pruning, which leads to the stack overflow.

vinayak-sharan commented 2 years ago

If you have a huge model or combination of models, then it throws a recursion error. Would it be possible to implement a dependency graph without recursion? @VainF

VainF commented 2 years ago

If you have a huge model or combination of models, then it throws a recursion error. Would it be possible to implement a dependency graph without recursion? @VainF

Hi @vinayak-sharan , thank you for your advice. I will try to re-implement it in the next version.

VainF commented 2 years ago

Hi everyone, the non-recursive implementation of dependency graph has been uploaded. I will keep the issue open for further discussion!