adnan1306 opened this issue 1 year ago
In line 286 of main.py, it should be:
new_mask = np.where(abs(tensor) > percentile_value, 0, mask[step])
so that smaller weights are set to zero.
https://github.com/rahulvigneswaran/Lottery-Ticket-Hypothesis-in-Pytorch/blob/34a8c9678406a1c7dd0fec4c9f0d25d017be55fb/main.py#L286
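No, the code is correct; check the np.where docs. np.where(condition, x, y) returns x where condition is True and y elsewhere, so 0 is set wherever the absolute value of the tensor is smaller than the cutoff value. Flipping the comparison to > would zero the larger weights instead.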
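The np.where behaviour behind the reply can be checked with a small standalone sketch. The weight values and the 40th-percentile cutoff below are made up for illustration, and keep_large / keep_small are hypothetical names; only the np.where expressions mirror the two variants discussed in this issue. This is not the repository's pruning function.

```python
import numpy as np

# Hypothetical weights and an all-ones starting mask (illustrative values only).
tensor = np.array([0.05, -0.8, 0.3, -0.02, 1.2])
mask = np.ones_like(tensor)

# Hypothetical cutoff: the 40th percentile of the absolute weights.
percentile_value = np.percentile(np.abs(tensor), 40)

# np.where(cond, 0, mask): 0 where cond is True, the old mask value elsewhere.
# With "<", the small-magnitude weights are masked out (magnitude pruning).
keep_large = np.where(np.abs(tensor) < percentile_value, 0, mask)

# With ">", as proposed in the issue, the large-magnitude weights are masked out.
keep_small = np.where(np.abs(tensor) > percentile_value, 0, mask)

print(percentile_value)
print(keep_large)
print(keep_small)
```

Here the cutoff works out to about 0.2, so the "<" version zeroes the entries for 0.05 and -0.02 and keeps the rest, while the ">" version does the opposite and keeps only the two smallest weights.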