ChenMnZ opened this issue 3 years ago
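I solved the problem with the following code:

```python
# Walk the candidate indices in sort order and stop once `total_regrowth`
# positions that are currently pruned (mask == 0) have been seen.
temp = new_mask.flatten()
i = 0
for index, m in enumerate(idx):
    if not temp[m]:  # position m is inactive, so growing it adds a weight
        i += 1
        if i == total_regrowth:
            break
# Slice up to and including `index`; idx[:index] would drop the last
# inactive position found and grow one weight too few.
new_mask.data.view(-1)[idx[:index + 1]] = 1.0
```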
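The selection logic of the loop above can be checked without PyTorch. The sketch below is a framework-free approximation, not the sparse momentum code: plain Python lists stand in for the flattened `new_mask` and momentum tensors, and `grow_inactive` is a hypothetical helper name. It visits positions in order of |score| and activates only positions whose mask is 0, counting grown weights directly, so exactly `total_regrowth` weights are added whenever enough pruned positions exist (and none when `total_regrowth` is 0).

```python
def grow_inactive(mask, scores, total_regrowth, descending=True):
    """Return a copy of `mask` with up to `total_regrowth` pruned positions
    (mask == 0) switched on, visiting positions in order of |score|."""
    # Index order analogous to torch.sort(torch.abs(grad).flatten(), descending=...)
    order = sorted(range(len(scores)), key=lambda k: abs(scores[k]), reverse=descending)
    new_mask = list(mask)
    grown = 0
    for k in order:
        if grown == total_regrowth:
            break
        if new_mask[k] == 0:  # skip weights that are already active
            new_mask[k] = 1
            grown += 1
    return new_mask, grown


# Mask after pruning: 3 active weights; momentum of active weights is zeroed.
mask = [1, 0, 1, 0, 0, 1]
scores = [0.0, 0.5, 0.0, 0.2, 0.0, 0.0]
print(grow_inactive(mask, scores, 2, descending=False))  # ([1, 0, 1, 1, 1, 1], 2)
```

In PyTorch the same selection can likely be done without a Python loop by filtering the sorted indices with a boolean mask, e.g. `cand = idx[new_mask.flatten()[idx] == 0][:total_regrowth]` followed by `new_mask.data.view(-1)[cand] = 1.0`; treat this as an untested suggestion.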
Thanks for posting the solution to the problem! I don't quite understand yet what was going wrong. Is the code you provided a general improvement that does the same thing, or is it only useful for your ablation experiments?
Hi, thank you for your great work. Today I wanted to run an ablation experiment on your work. I only modified the `momentum_growth` function, changing `y, idx = torch.sort(torch.abs(grad).flatten(), descending=True)` to `y, idx = torch.sort(torch.abs(grad).flatten(), descending=False)`.
I ran the experiment with the command: `python main.py --growth momentum --prune magnitude --redistribution momentum --prune-rate 0.2 --density 0.1 --data cifar10 --model vgg-c`
I found that the final density drops to 0.073, below the 0.1 set by `--density`. Reading the source code, I found that the `momentum_growth` function cannot grow enough weights because it does not check whether the mask was 0 before growing a weight. You handle this problem with `adjusted_growth`. I wonder why this method works in your original function but not in my ablation experiment.
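A plausible explanation, assuming `momentum_growth` multiplies the momentum by `(new_mask == 0)` before sorting (an assumption about the repository code, not shown in this thread): every active weight then has a score of exactly 0. With `descending=True` those zeros sort last, so the top `total_regrowth` indices are inactive positions and no explicit check is needed; with `descending=False` they sort first, so growth mostly re-selects weights that are already active. The framework-free sketch below uses plain lists in place of tensors, and `naive_growth` is a hypothetical name that mimics the unchecked growth step.

```python
def naive_growth(mask, momentum, total_regrowth, descending):
    """Grow by taking the first `total_regrowth` sorted indices, with no
    check of whether a position is already active (mimics the issue)."""
    # Assumption: momentum of active weights is zeroed, i.e. grad * (mask == 0)
    scores = [g if m == 0 else 0.0 for m, g in zip(mask, momentum)]
    order = sorted(range(len(scores)), key=lambda k: abs(scores[k]), reverse=descending)
    new_mask = list(mask)
    for k in order[:total_regrowth]:
        new_mask[k] = 1  # a position that is already 1 is simply set to 1 again
    return new_mask


# 3 of 6 weights survive pruning; 2 should be regrown (target: 5 active).
mask = [1, 0, 1, 0, 0, 1]
momentum = [0.9, 0.5, 0.7, 0.2, 0.1, 0.3]
print(sum(naive_growth(mask, momentum, 2, descending=True)))   # 5
print(sum(naive_growth(mask, momentum, 2, descending=False)))  # 3
```

In the ascending case no pruned weight is replaced, so density plausibly drifts further below the target with every prune/regrow cycle. Skipping already-active positions, as the loop fix does, makes the growth count independent of the sort direction.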