rahulvigneswaran / Lottery-Ticket-Hypothesis-in-Pytorch

This repository contains a PyTorch implementation of the paper "The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks" by Jonathan Frankle and Michael Carbin that can be easily adapted to any model/dataset.

Freeze pruned weights method not efficient #10

Open guoyuntu opened 3 years ago

guoyuntu commented 3 years ago

In 'main.py', lines 257-262, the author uses the following code to freeze the pruned weights:

    for name, p in model.named_parameters():
        if 'weight' in name:
            tensor = p.data.cpu().numpy()
            grad_tensor = p.grad.data.cpu().numpy()
            grad_tensor = np.where(tensor < EPS, 0, grad_tensor)
            p.grad.data = torch.from_numpy(grad_tensor).to(device)

which puts a heavy burden on CPU-to-GPU I/O. I would recommend performing the freezing operation directly on the GPU; the following code helps:

    for name, p in model.named_parameters():
        if 'weight' in name:
            tensor = p.data
            grad_tensor = p.grad
            # Zero the gradients of pruned (near-zero) weights without leaving the GPU
            grad_tensor = torch.where(tensor.abs() < EPS, torch.zeros_like(grad_tensor), grad_tensor)
            p.grad.data = grad_tensor
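
Another option that avoids even the per-step Python loop over parameters is to register a gradient hook on each weight tensor once, so the mask is applied on the GPU during backward. This is only a minimal sketch of that idea, not code from this repository; it assumes the same model and EPS as in main.py, that the pruned weights are already (near) zero when the hooks are registered, and the helper name freeze_pruned_gradients is made up:

    import torch

    def freeze_pruned_gradients(model, eps):
        # Register one backward hook per weight tensor so gradients of
        # pruned (near-zero) entries are zeroed on the GPU at every step.
        for name, p in model.named_parameters():
            if 'weight' in name:
                # 1 for surviving weights, 0 for pruned ones
                mask = (p.data.abs() >= eps).to(p.dtype)
                p.register_hook(lambda grad, mask=mask: grad * mask)

    # Call once after pruning, before the training loop:
    # freeze_pruned_gradients(model, EPS)

Note that register_hook returns a handle, so if pruning happens over several rounds the old hooks would need to be removed (handle.remove()) and re-registered with fresh masks.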
bainro commented 2 years ago

A batch size of 200 on MNIST + LeNet-5 went from 12 seconds per epoch to 6 with your changes. Thank you!