Open sdreamforchen opened 2 years ago
你好,请问你说的什么意思,我没大看明白,我看代码我觉得代码写的没问题呀
是相符合的,没有不符。一开始被你也带跑了,跟着又回顾了下代码。重点看下述代码的注释:
for compactor_param, mask in compactor_mask_dict.items():
compactor_param.grad.data = mask * compactor_param.grad.data # equ 14, 对应公式第一项,Compactor参数损失梯度* Mask
lasso_grad = compactor_param.data * ((compactor_param.data ** 2).sum(dim=(1, 2, 3), keepdim=True) ** (-0.5))equ 14, 对应公式第二项,Group Lasso 梯度的计算。
compactor_param.grad.data.add_(resrep_config.lasso_strength, lasso_grad) # equ 14,对应公式第二项, Group Lasso梯度*\lambda
个人感觉,你应该是把compactor_param.grad.data
和 compactor_param.data
区分混淆了,注意.grad.
的存在。
for compactor_param, mask in compactor_mask_dict.items():####################单独设置compactor的梯度信息,加上lasso_grad梯度 compactor_param.grad.data = mask compactor_param.grad.data lasso_grad = compactor_param.data ((compactor_param.data 2).sum(dim=(1, 2, 3), keepdim=True) (-0.5))###########这个mask是乘以的loss第二项,和论文不同 compactorparam.grad.data.add(resrep_config.lasso_strength, lasso_grad)