wenwei202 / caffe

Caffe for Sparse and Low-rank Deep Neural Networks
Other
375 stars 134 forks source link

A pytorch re-implementation of Structured Sparsity Learning #36

Open zjykzj opened 3 years ago

zjykzj commented 3 years ago

hi @wenwei202 , Thank you for your sharing. Based on your paper, I implemented ZJCV/SSL using pytorch, including train-prune-finetuing for VGGNet and ResNet

From the experimental results, we can see that the pruning effect is very good. Thank you again for SSL

wenwei202 commented 3 years ago

@zjykzj it's great that you are reimplementing it in pytorch after the old days of using caffe.

Bojue-Wang commented 1 year ago

hi @wenwei202 , Thank you for your sharing. Based on your paper, I implemented ZJCV/SSL using pytorch, including train-prune-finetuing for VGGNet and ResNet

  • For VGGNet, I realized filter/channel/filter_and_channel pruning;
  • For ResNet, I realized depth pruning.

From the experimental results, we can see that the pruning effect is very good. Thank you again for SSL

把代码里少的那个根号加上重新传一遍吧。这个根号项就是group的全部精髓。要是没有这个根号把一个卷积核装进一个篮子里,这个pytorch版的代码只相当于手动计算了一下没有开平方的L2范数的惩罚项。丢掉了SSL的核心。

P.S. 又踩一坑

zjykzj commented 1 year ago

@caikengxiaonengshou 能否具体解释一下开根号前后的差异性?另外有相关的试验结果木有?

Bojue-Wang commented 1 year ago

Lasso(在没有group的情况下,特指L1范数),岭回归是L2范数。但是Group Lasso,被分出的子分组必须对组内使用L2范数。假设有a, b, c, d, e, f六个参数,如果分组同时使用L1范数, abc分一组,def分一组,(|a| + |b| + |c|)+(|d| + |e| + |f|)= |a| + |b| + |c| + |d| + |e| + |f| =直接对六个参数使用L1范数,没有分组效果。因此Group Lasso必须使用L2范数进行分组。

分组同时使用L2范数,是先计算组内L2范数, 再对各组L2范数进行求和。比如我有10个卷积核,每个卷积核的参数分成一组。那么Group Lasso这部分Loss应该有10个根号,求导的时候每个根号下的,来自同一个卷积核的参数,共享同一个外层导数作为系数。正是这个系数,指导这一个卷积核内的参数是否逐渐共同逐渐向0靠近。从而进行结构化稀疏。

如果不加这个根号,那么按照你代码里的公式。假如我有十个卷积核。那么就是相当于将这10个卷积核里的所有参数的平方求和。如果加上一个根号套住这个和,就彻底变成了L2范数。

试验我是这样做的,我自己写了一个L2范数的函数A,不带根号的。然后torch.rand一个[2, 3, 2, 2]的张量。用L2范数的函数A, 和代码中的def group_lasso_by_filter_or_channel(), 打印出来的两个值是一样的。相当于没有分组。 image

@caikengxiaonengshou 能否具体解释一下开根号前后的差异性?另外有相关的试验结果木有?

zjykzj commented 1 year ago

@caikengxiaonengshou 嗯嗯 我先研究一下,相关的实验估计也得重做。欢迎PR