VainF / Torch-Pruning

[CVPR 2023] Towards Any Structural Pruning; LLMs / SAM / Diffusion / Transformers / YOLOv8 / CNNs
https://arxiv.org/abs/2301.12900
MIT License

PixelShuffle #420

Open AvrahamRaviv opened 4 weeks ago

AvrahamRaviv commented 4 weeks ago

Hi, torch.nn.PixelShuffle changes the number of channels, which Torch-Pruning does not handle well. For example, I have a layer whose output has shape 1×16×576×1024, and after PixelShuffle the output is 1×4×1152×2048. With Torch-Pruning, the dependency graph treats it as a layer with the same number of output channels, which causes two errors (see the sketch after this list):

  1. The number of PixelShuffle's output channels should be reduced/increased by the relevant factor.
  2. The number of the next layer's input channels should be reduced/increased by the relevant factor as well.

How can I handle this? Thanks!
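For reference, the channel change can be reproduced directly; a minimal sketch, where the upscale factor of 2 is an assumption inferred from the shapes reported above:

```python
import torch

# PixelShuffle(r) maps (N, C*r^2, H, W) -> (N, C, H*r, W*r),
# so the channel count shrinks by r^2 while H and W grow by r.
x = torch.randn(1, 16, 576, 1024)      # shape reported in the issue
ps = torch.nn.PixelShuffle(2)          # upscale factor assumed to be 2
y = ps(x)
print(y.shape)                         # torch.Size([1, 4, 1152, 2048])
```
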
angelinimattia commented 1 week ago

I implemented this for my thesis. Intrinsically, the PixelShuffle itself is never pruned; it only expects an input of a given depth. So I bundled the PixelShuffle operator together with the sub-pixel convolution that is usually placed right before it for the SR operation. The question then becomes: which output channels of the conv layer should be pruned for a given pruned channel of the bundled block? You can find my implementation here: https://github.com/MaGiiK02/sr_structured_pruning/blob/main/pruners/UpsamplePruner.py, where the Upsample block represents the standard upsample operation (see the sketch after this list):

  1. Conv2d(n, n*(s^2)), where s is the scaling factor
  2. A PixelShuffle layer following it
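A minimal sketch of the channel mapping this bundling implies; the helper name is hypothetical, and the contiguous grouping follows standard PixelShuffle semantics rather than the exact code in the linked UpsamplePruner:

```python
import torch.nn as nn

def conv_channels_for_pruned_outputs(pruned_out_channels, s):
    """Map pruned output channels of the bundled block to the sub-pixel
    conv's output channels.

    PixelShuffle(s) builds its output channel c from input channels
    [c*s^2, (c+1)*s^2), so pruning output channel c of the block means
    pruning that whole contiguous group in the preceding conv.
    """
    groups = []
    for c in pruned_out_channels:
        groups.extend(range(c * s * s, (c + 1) * s * s))
    return groups

# Example: an SR upsample block Conv2d(n, n*s^2) followed by PixelShuffle(s)
n, s = 8, 2
conv = nn.Conv2d(n, n * s * s, kernel_size=3, padding=1)
shuffle = nn.PixelShuffle(s)

# Pruning output channel 3 of the block requires removing conv channels 12..15
print(conv_channels_for_pruned_outputs([3], s))  # [12, 13, 14, 15]
```
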
AvrahamRaviv commented 1 week ago

@angelinimattia Wow, thanks for the reply! It’s great to see others have tackled this implementation. I’ll be sure to integrate it into my code soon.

angelinimattia commented 1 week ago

@AvrahamRaviv Feel free to ask me anything while I still remember what I did, since I only submitted my master's thesis a few weeks ago.