DingXiaoH / ResRep

ResRep: Lossless CNN Pruning via Decoupling Remembering and Forgetting (ICCV 2021)
MIT License

Does the mask not match the paper? #20

Open sdreamforchen opened 2 years ago

sdreamforchen commented 2 years ago

for compactor_param, mask in compactor_mask_dict.items():  #################### sets the compactor gradients separately, adding the lasso_grad gradient
    compactor_param.grad.data = mask * compactor_param.grad.data
    lasso_grad = compactor_param.data * ((compactor_param.data ** 2).sum(dim=(1, 2, 3), keepdim=True) ** (-0.5))  ########### this mask is multiplied onto the second term of the loss, which differs from the paper
    compactor_param.grad.data.add_(resrep_config.lasso_strength, lasso_grad)

if not if_accum_grad:
    if gradient_mask_tensor is not None:  ################## gradient_mask_tensor is always None here
        for name, param in net.named_parameters():
            if name in gradient_mask_tensor:
                param.grad = param.grad * gradient_mask_tensor[name]
    optimizer.step()  ############### each step only the second term gets masked
    optimizer.zero_grad()
acc, acc5 = torch_accuracy(pred, label, (1,5))
Lutong-Qin commented 2 years ago

Hi, what exactly do you mean? I don't quite follow. Looking at the code, I think it is written correctly.

yannqi commented 1 year ago

It does match the paper; there is no inconsistency. At first your comment threw me off as well, so I went back through the code. Pay close attention to the comments in the code below:

for compactor_param, mask in compactor_mask_dict.items():
        compactor_param.grad.data = mask * compactor_param.grad.data  # equ 14: first term of the formula, gradient of the loss w.r.t. the compactor params * mask
        lasso_grad = compactor_param.data * ((compactor_param.data ** 2).sum(dim=(1, 2, 3), keepdim=True) ** (-0.5))  # equ 14: second term of the formula, computing the group-lasso gradient
        compactor_param.grad.data.add_(resrep_config.lasso_strength, lasso_grad)  # equ 14: second term of the formula, group-lasso gradient * \lambda
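
(A side note, not from the repo: the positional add_(alpha, other) call above is the old, now-deprecated PyTorch signature and simply means grad += lasso_strength * lasso_grad. On a recent PyTorch the keyword form would be:

        compactor_param.grad.data.add_(lasso_grad, alpha=resrep_config.lasso_strength)  # grad += lambda * group-lasso gradient
)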

My personal feeling is that you have mixed up compactor_param.grad.data and compactor_param.data; note the presence of .grad. in the former.
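
To make the distinction concrete, here is a minimal self-contained sketch (my own toy example, not code from this repo; the 8-in/4-out shapes, the mask values and lasso_strength are made up) that replays the three lines above on a small 1x1-conv "compactor" and checks that the resulting gradient is mask * (task-loss gradient) + lambda * W / ||W||_2, with the norm taken per output channel:

import torch

lasso_strength = 1e-4                                            # placeholder for resrep_config.lasso_strength
compactor = torch.nn.Conv2d(8, 4, kernel_size=1, bias=False)     # toy compactor: 4 output channels
mask = torch.tensor([1., 0., 1., 0.]).reshape(-1, 1, 1, 1)       # channels 1 and 3 already "forgotten"

x = torch.randn(2, 8, 16, 16)
loss = compactor(x).pow(2).mean()                                # stand-in for the task loss
loss.backward()                                                  # weight.grad now holds the task-loss gradient

w = compactor.weight
task_grad = w.grad.data.clone()

# the same three steps as in the repo's loop
w.grad.data = mask * w.grad.data                                 # the mask multiplies the first term only
lasso_grad = w.data * ((w.data ** 2).sum(dim=(1, 2, 3), keepdim=True) ** (-0.5))
w.grad.data.add_(lasso_grad, alpha=lasso_strength)               # += lambda * group-lasso gradient (second term)

# check: grad == mask * dL/dW + lambda * W / ||W||_2, norm taken per output channel
expected = mask * task_grad + lasso_strength * w.data / torch.linalg.vector_norm(w.data, dim=(1, 2, 3), keepdim=True)
print(torch.allclose(w.grad.data, expected))                     # prints True

After loss.backward() the first term lives in compactor.weight.grad, while the group-lasso term is rebuilt every step from compactor.weight.data, so the mask only ever touches the task-loss gradient.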