VainF / Torch-Pruning

[CVPR 2023] Towards Any Structural Pruning; LLMs / SAM / Diffusion / Transformers / YOLOv8 / CNNs
https://arxiv.org/abs/2301.12900
MIT License

Pruning with concatenation failing #146

Open · WAvery4 opened this issue 1 year ago

WAvery4 commented 1 year ago

Not sure if this is intended behavior, but it looks like there might be an issue with concatenation based on the following test.

Code:

import torch
import torch.nn as nn
import torch_pruning as tp

class TestModule(nn.Module):
    def __init__(self, in_dim):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(in_dim, in_dim, 1),
            nn.BatchNorm2d(in_dim),
            nn.GELU(),
            nn.Conv2d(in_dim, in_dim, 1),
            nn.BatchNorm2d(in_dim)
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(in_dim * 2, in_dim, 1),
            nn.BatchNorm2d(in_dim)
        )

    def forward(self, x):
        x = self.block1(x)
        x = torch.cat([x, x], dim=1)
        x = self.block2(x)
        return x

model = TestModule(512)

dummy_input = torch.randn(1, 512, 7, 7)

pruner = tp.pruner.MagnitudePruner(
    model,
    dummy_input,
    importance=tp.importance.MagnitudeImportance(p=2),
    iterative_steps=6,
    ch_sparsity=0.75,
    ignored_layers=None
)

for step in range(6):
    pruner.step()

    model = model.eval()
    model(dummy_input)

Expected Behavior: After the first pruning step, block2 should have 896 input channels and 448 output channels.
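
For reference, the expected numbers follow directly from the pruner settings, assuming channels are removed uniformly across the 6 steps:

per_step = 0.75 / 6                      # ch_sparsity / iterative_steps = 0.125
block1_out = int(512 * (1 - per_step))   # 448 channels left at block1's output
block2_in = 2 * block1_out               # 896, since forward() concatenates x with itself
block2_out = int(512 * (1 - per_step))   # 448 channels left at block2's output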

Actual Behavior:

RuntimeError: Given groups=1, weight of size [448, 960, 1, 1], expected input[1, 896, 7, 7] to have 960 channels, but got 896 channels instead
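
For what it's worth, my reading of the numbers in the error (just arithmetic, I have not confirmed this against the source):

# block1 lost 512 - 448 = 64 output channels in the first step
# block2's input should therefore lose 2 * 64 = 128 channels: 1024 -> 896
# but the pruned weight only lost 64 input channels: 1024 -> 960
# i.e. the duplicated tensor in torch.cat([x, x], dim=1) seems to be pruned only once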
VainF commented 1 year ago

Really appreciate the feedback. Will fix it!

VainF commented 1 year ago

Hi @WAvery4, this bug has been fixed in the latest commit. Thank you!

The original model:

Net(
  (block1): Sequential(
    (0): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): GELU(approximate=none)
    (3): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (block2): Sequential(
    (0): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1))
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)

The pruned model:

Net(
  (block1): Sequential(
    (0): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): GELU(approximate=none)
    (3): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (block2): Sequential(
    (0): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)
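
A quick shape check against the pruned printout above (a minimal sketch, reusing the dummy_input from the repro):

out1 = model.block1(dummy_input)         # [1, 256, 7, 7]
merged = torch.cat([out1, out1], dim=1)  # [1, 512, 7, 7], matches block2's new in_channels
out2 = model.block2(merged)              # [1, 256, 7, 7]
print(out1.shape, merged.shape, out2.shape)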
WAvery4 commented 1 year ago

No problem! Thank you for looking into it!

WAvery4 commented 1 year ago

Hi, I was testing out 1.1.6, and I am still getting the same issue.

Code:

import torch
import torch.nn as nn
import torch_pruning as tp

class TestModule(nn.Module):
    def __init__(self, in_dim):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(in_dim, in_dim, 1),
            nn.BatchNorm2d(in_dim),
            nn.GELU(),
            nn.Conv2d(in_dim, in_dim, 1),
            nn.BatchNorm2d(in_dim)
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(in_dim * 2, in_dim, 1),
            nn.BatchNorm2d(in_dim)
        )

    def forward(self, x):
        x = self.block1(x)
        x = torch.cat([x, x], dim=1)
        x = self.block2(x)
        return x

model = TestModule(512)

dummy_input = torch.randn(1, 512, 7, 7)

pruner = tp.pruner.MagnitudePruner(
    model,
    dummy_input,
    importance=tp.importance.MagnitudeImportance(p=2),
    iterative_steps=6,
    ch_sparsity=0.75,
    ignored_layers=None
)

for step in range(6):
    pruner.step()

    model = model.eval()
    model(dummy_input)

Expected Behavior: After the first pruning step, block2 should have 896 input channels and 448 output channels.

Actual Behavior:

RuntimeError: Given groups=1, weight of size [448, 960, 1, 1], expected input[1, 896, 7, 7] to have 960 channels, but got 896 channels instead

Version:

Package            Version
------------------ ------------
certifi            2022.12.7
charset-normalizer 3.1.0
colorama           0.4.6
coremltools        6.2
filelock           3.10.7
huggingface-hub    0.13.3
idna               3.4
mpmath             1.3.0
numpy              1.24.2
packaging          23.0
Pillow             9.5.0
pip                23.0.1
protobuf           3.20.3
PyYAML             6.0
requests           2.28.2
setuptools         65.6.3
sympy              1.11.1
timm               0.6.13
torch              1.12.1+cu116
torch-pruning      1.1.6
torchaudio         0.12.1+cu116
torchvision        0.13.1+cu116
tqdm               4.65.0
typing_extensions  4.5.0
urllib3            1.26.15
wheel              0.38.4
wincertstore       0.2
WAvery4 commented 1 year ago

It looks like the change was reverted in response to #147 in the commit linked here. The version update says the concat bug was fixed, though that might have been in reference to #147's issue, so I will leave this issue open for now.

VainF commented 1 year ago

Yes, we removed this patch in the latest commit due to some issues it caused during YOLO pruning. We are currently working on finding a more appropriate solution to address the issue.

WAvery4 commented 1 year ago

I was reading through dependency.py and wanted to test whether this bug also occurs when x is not simply duplicated. I have found that it does: the bug also appears when torch.cat([x, x_i], dim=1) is called, where x_i is derived from x through an auxiliary path. Not sure if this helps.
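
A minimal sketch of the kind of forward pass I mean (module names here are illustrative, not the exact code I tested):

class TestModuleAux(nn.Module):
    def __init__(self, in_dim):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(in_dim, in_dim, 1),
            nn.BatchNorm2d(in_dim)
        )
        # auxiliary path that derives x_i from x
        self.aux = nn.Sequential(
            nn.Conv2d(in_dim, in_dim, 1),
            nn.BatchNorm2d(in_dim)
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(in_dim * 2, in_dim, 1),
            nn.BatchNorm2d(in_dim)
        )

    def forward(self, x):
        x = self.block1(x)
        x_i = self.aux(x)                # x_i is derived from x via the auxiliary path
        x = torch.cat([x, x_i], dim=1)   # this concat also triggers the channel mismatch
        return self.block2(x)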

WAvery4 commented 1 year ago

Have there been any updates regarding this issue? I am very interested in using this library, but cannot use it until this bug is resolved. Thanks!

lixinghe1999 commented 1 year ago

Have there been any updates regarding this issue? I am very interested in using this library, but cannot use it until this bug is resolved. Thanks!

I actually face a similar issue when I concatenate two tensors. However, the multi-input-multi-output example also uses concat, so I guess the problem is in the pruner. I am still working on it and look forward to others' insights.

lacie-life commented 3 months ago

Have there been any updates regarding this issue? Thank you