siyuan0 / pytorch_model_prune

Simple Python code to prune PyTorch models
MIT License

Is there any speedup? #5

Open chamecall opened 2 years ago

chamecall commented 2 years ago

Hey, thanks for your work, first of all! I'm just wondering whether we really get any speedup with the zero_padding approach, since in that case we restore the layer to the same number of parameters as before, don't we? So do we actually get any speedup with standard DL frameworks like PyTorch on ordinary GPUs?
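
For illustration, a minimal standalone sketch (a single Conv2d, not code from this repo) of why zeroed weights give no speedup on a standard PyTorch backend: the dense convolution kernels execute the same multiply-accumulates regardless of the weight values.

import torch
import torch.nn as nn
from time import time

conv = nn.Conv2d(64, 64, kernel_size=3, padding=1)
x = torch.rand(8, 64, 56, 56)

def bench(module, reps=50):
    # time repeated forward passes without autograd overhead
    with torch.no_grad():
        start = time()
        for _ in range(reps):
            module(x)
        return time() - start

t_dense = bench(conv)
with torch.no_grad():
    conv.weight[:32].zero_()  # zero out half of the output-channel filters
t_zeroed = bench(conv)
print(f"dense: {t_dense:.3f}s, half-zeroed: {t_zeroed:.3f}s")  # roughly equal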

chamecall commented 2 years ago

Did I get it right that we apply zero_padding only where shortcut connections have channel mismatches, and that for other consecutive layer pairs l and l+1 we simply remove output channels in layer l and the corresponding input channels in layer l+1?
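
For comparison, a rough sketch of that structural removal for a pair of Conv2d layers (my own illustration under simplified assumptions: no batch norm, dilation, or groups; remove_channels and keep_idx are hypothetical names, not the repo's code):

import torch
import torch.nn as nn

def remove_channels(conv_l, conv_l1, keep_idx):
    # Build smaller layers: fewer output channels in layer l,
    # correspondingly fewer input channels in layer l+1.
    new_l = nn.Conv2d(conv_l.in_channels, len(keep_idx),
                      conv_l.kernel_size, conv_l.stride, conv_l.padding,
                      bias=conv_l.bias is not None)
    new_l1 = nn.Conv2d(len(keep_idx), conv_l1.out_channels,
                       conv_l1.kernel_size, conv_l1.stride, conv_l1.padding,
                       bias=conv_l1.bias is not None)
    with torch.no_grad():
        new_l.weight.copy_(conv_l.weight[keep_idx])       # slice output channels of layer l
        if conv_l.bias is not None:
            new_l.bias.copy_(conv_l.bias[keep_idx])
        new_l1.weight.copy_(conv_l1.weight[:, keep_idx])  # slice matching input channels of layer l+1
        if conv_l1.bias is not None:
            new_l1.bias.copy_(conv_l1.bias)
    return new_l, new_l1

conv_a = nn.Conv2d(16, 32, 3, padding=1)
conv_b = nn.Conv2d(32, 64, 3, padding=1)
keep = torch.arange(16)                    # keep the first 16 of 32 channels
conv_a, conv_b = remove_channels(conv_a, conv_b, keep)
x = torch.rand(1, 16, 28, 28)
print(conv_b(conv_a(x)).shape)             # torch.Size([1, 64, 28, 28])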

chamecall commented 2 years ago

It's pretty weird: we have more than 2x fewer parameters, but the runtime is almost the same:

import torchvision.models as models
from torchvision.models.resnet import BasicBlock
import torch
from time import time

dummy = torch.rand(10, 3, 160, 160)

# time 100 forward passes before pruning
backbone = models.resnet18(pretrained=True)
s = time()
for i in range(100):
    backbone(dummy)
e = time()
print(e - s)

# prune_model is the pruning function provided by this repo
prune_model(backbone)
# print(backbone)
# print(backbone(dummy).shape)

# time 100 forward passes after pruning
s = time()
for i in range(100):
    backbone(dummy)
e = time()
print(e - s)

I guess that's really because the zero padding restores the same number of parameters.
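
One quick way to check that is to compare the total parameter count with the number of nonzero parameters (a standalone sketch; prune_model below stands for this repo's pruning function, as used above). After zero-padding-style pruning the total stays the same, so the dense kernels still do the same amount of work:

import torchvision.models as models

def param_stats(model):
    total = sum(p.numel() for p in model.parameters())
    nonzero = sum(int((p != 0).sum()) for p in model.parameters())
    return total, nonzero

backbone = models.resnet18(pretrained=True)
print(param_stats(backbone))    # total == nonzero before pruning
# prune_model(backbone)         # pruning function from this repo
# print(param_stats(backbone))  # total unchanged, nonzero count drops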

siyuan0 commented 1 year ago

Good observation! This project was done in 2019 as a proof of concept for the retention of accuracy after pruning a large number of channels in a typical convolutional neural network. As such, there was no real reduction in the operations performed, since the zero padding simply replaces the pruned weights with zeroes.

I'm happy to say that the idea has since been accepted and greatly expanded upon by the wider ML community, and there are now proper libraries from PyTorch (https://pytorch.org/tutorials/intermediate/pruning_tutorial.html) that implement this for networks that aren't limited to convolutions. On the point of performance, I imagine the real magic will happen when hardware specially designed for ML (e.g. Google's TPUs) comes with proper operation sets to handle pruned networks (perhaps implicit handling of tensor connections between layers during compilation?).

Also, I wish to apologise for the long delay in responding. I probably won't be checking this repo as much now that there are proper libraries to handle pruning, but it's still nice to see that someone found the content here useful.
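
For anyone landing here later, a minimal usage sketch of that PyTorch utility: it is mask-based, so like the zero-padding approach it keeps the parameter count and the dense computation unchanged.

import torch.nn as nn
import torch.nn.utils.prune as prune

conv = nn.Conv2d(16, 32, kernel_size=3)
# Zero out the 50% of weights with the smallest absolute value (L1 criterion).
prune.l1_unstructured(conv, name="weight", amount=0.5)
print(float((conv.weight == 0).float().mean()))  # ~0.5 sparsity
prune.remove(conv, "weight")  # fold the mask into the weight permanently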