Zj-BinXia / SSL

This project is the official implementation of 'Structured Sparsity Learning for Efficient Video Super-Resolution', CVPR 2023

code confusion #4

Closed zxd-cqu closed 1 year ago

zxd-cqu commented 1 year ago

Hi, sorry to bother you again. In `_apply_reg(self)` in `basicsr/pruner/SSL_pruner.py`, my understanding is that this function adds extra gradient terms to the scaling factors to push them toward sparsity. Should the line `m.act_scale_pre.grad += reg_pre[:, 0].view(1, -1, 1, 1) * m.act_scale` be changed to `m.act_scale_pre.grad += reg_pre[:, 0].view(1, -1, 1, 1) * m.act_scale_pre`? My reasoning is that the penalty gradient has to be proportional to `m.act_scale_pre` itself for the regularization to actually sparsify `m.act_scale_pre`.

    def _apply_reg(self):
        for name, m in self.model.named_modules():
            if name in self.layers and self.pr[name] > 0:
                reg = self.reg[name]  # [N, C] per-channel penalty strengths
                # sparsity regularization: gradient of 0.5 * reg * act_scale**2
                m.act_scale.grad += reg[:, 0].view(1, -1, 1, 1) * m.act_scale

                if hasattr(m, 'act_scale_pre'):
                    reg_pre = self.reg_pre[name]
                    # change this or not: should the last factor be m.act_scale_pre?
                    m.act_scale_pre.grad += reg_pre[:, 0].view(1, -1, 1, 1) * m.act_scale
                # bias = False if isinstance(m.bias, type(None)) else True
                # if bias:
                #     m.bias.grad += reg[:, 0] * m.bias
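
A minimal sketch (toy values, not the repo's actual tensors) of why the final factor matters: adding `reg * p` to `p.grad` is exactly the gradient of the L2-style penalty `0.5 * reg * p**2`, so the added term only drives a parameter toward zero when it multiplies that same parameter.

    # Toy demonstration: grad += reg * p is the gradient of 0.5 * reg * p**2,
    # which shrinks p toward zero under gradient descent.
    import torch

    p = torch.tensor([0.8, -0.3, 0.05], requires_grad=True)  # stand-in for act_scale_pre
    reg = torch.tensor([0.1, 0.1, 0.1])                       # per-channel penalty strength

    loss = 0.5 * (reg * p ** 2).sum()  # L2-style sparsity penalty
    loss.backward()
    print(p.grad)  # tensor([0.0800, -0.0300, 0.0050]) == reg * p

    # Equivalent manual update, as _apply_reg does:
    #   p.grad += reg * p   # correct: penalizes p with its own magnitude
    #   p.grad += reg * q   # wrong:   the penalty no longer shrinks p itself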
Zj-BinXia commented 1 year ago

Yes, you can make that modification.
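
For reference, assuming the fix confirmed above, the modified branch would read (only the final factor changes):

    if hasattr(m, 'act_scale_pre'):
        reg_pre = self.reg_pre[name]
        # penalize act_scale_pre with its own magnitude so it is sparsified
        m.act_scale_pre.grad += reg_pre[:, 0].view(1, -1, 1, 1) * m.act_scale_pre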

zxd-cqu commented 1 year ago

thanks!