VainF / Torch-Pruning

[CVPR 2023] Towards Any Structural Pruning; LLMs / SAM / Diffusion / Transformers / YOLOv8 / CNNs
https://arxiv.org/abs/2301.12900
MIT License

Loss stops decreasing after pruning #379

Open liuxiangchao369 opened 4 months ago

liuxiangchao369 commented 4 months ago

The loss stops decreasing after pruning (after only one step). Is there anything wrong with my code?

import torch
import torch_pruning as tp

imp = tp.importance.MagnitudeImportance(p=2)

# Collect layers that must not be pruned.
ignored_layers = []
for m in net.modules():
    if isinstance(m, torch.nn.Linear) and m.out_features == 5:
        print(m)
        ignored_layers.append(m)  # DO NOT prune the final classifier!

iterative_steps = 100
example_inputs = torch.randn(1, 1, FACE_SHAPE[0], FACE_SHAPE[1])
pruner = tp.pruner.MagnitudePruner(
    model=net,
    example_inputs=example_inputs,
    pruning_ratio=0.01,
    importance=imp,
    iterative_steps=iterative_steps,
    ignored_layers=ignored_layers,
)
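Note that with pruning_ratio=0.01 spread over iterative_steps=100, each pruner.step() targets roughly 0.01% of the channels, which can round down to removing nothing at all on a small model (assuming the default linear schedule). A quick standalone sanity check, sketched below with tp.utils.count_ops_and_params from Torch-Pruning, is to confirm the parameter count actually drops after a step; note that it consumes one of the iterative steps.

# Sanity-check sketch: verify that a pruning step actually shrinks the model.
# Caveat: this consumes one of the 100 iterative steps.
base_macs, base_params = tp.utils.count_ops_and_params(net, example_inputs)
pruner.step()
macs, params = tp.utils.count_ops_and_params(net, example_inputs)
print(f"params: {base_params} -> {params}")  # no change => the step pruned nothing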

# ...
steps = iterative_steps  # remaining pruning steps (initialized here for completeness)
for epoch in range(max_epochs):
    self.train(model=net, optimizer=optimizer, data_loader=train_loader,
               device=device, epoch=epoch, pruner=pruner)
    # ...
    valid_loss, f1 = self.valid(model=net, valid_loader=valid_loader,
                                device=device, epoch=epoch)
    if f1 > 0.9 and steps > 0:
        pruner.step()
        steps -= 1
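One thing worth checking here (a hypothesis, not a confirmed diagnosis): structural pruning replaces module weights with new, smaller nn.Parameter objects, so an optimizer built before pruning still holds references to the old tensors and its updates stop reaching the pruned model. A common workaround is to rebuild the optimizer after each pruner.step(), sketched below with an illustrative Adam constructor; reuse whatever optimizer class and hyperparameters your script actually uses.

if f1 > 0.9 and steps > 0:
    pruner.step()
    # Rebuild the optimizer so it references the new, smaller parameters.
    # The optimizer class and hyperparameters here are illustrative.
    optimizer = torch.optim.Adam(net.parameters(), lr=optimizer.param_groups[0]["lr"])
    steps -= 1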
    def train(self, model, optimizer, data_loader, device, epoch, pruner):
        """Train the model for one epoch.

        :param model: the network being trained
        :param optimizer: optimizer updating the model parameters
        :param data_loader: training data loader
        :param device: device to run on
        :param epoch: current epoch index
        :type epoch: int
        :param pruner: Torch-Pruning pruner
        """
        criterion = CustomLoss()
        model.train()

        sum_loss = 0
        for batch_idx, (X, y) in enumerate(data_loader, 0):
            X = X.to(device)
            y = y.to(device)

            optimizer.zero_grad()
            output = model(X)
            loss = criterion(output, y)
            pruner.regularize(model, loss)  # see the note after this function
            loss.backward()
            optimizer.step()

            sum_loss += loss.item()
            self.global_step += 1
            self.trainWriter.add_scalar("loss", loss.item(), self.global_step)

        self.trainWriter.add_scalar("train loss", sum_loss / len(data_loader), epoch + 1)

        lr = optimizer.param_groups[0]['lr']
        self.trainWriter.add_scalar("lr", lr, epoch + 1)
tahayf commented 1 month ago

I don't see anything wrong in the pruning part. Could you please share the CustomLoss function?