torch / torch7

http://torch.ch

nn.MaskedSelect, nn.Dropout #866

Open evcu opened 7 years ago

evcu commented 7 years ago

Hi,

I am an M.Sc. student, and for my final project I am implementing network pruning/compression from the paper "Learning both Weights and Connections for Efficient Neural Networks", using Torch7 and Torchnet. My plan is to first add a dropout-like layer with a fixed mask between the layers (to imitate pruning). Is there any module that does that? nn.MaskedSelect does select inputs, but it outputs a flat vector, and nn.Dropout samples a new mask on every forward pass. Am I missing something? Would such a module be worth implementing?

Kind regards
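For reference, a dropout-like layer with a fixed mask takes only a few lines of nn module code. The sketch below is hypothetical: the name nn.FixedMaskSketch is made up and is not part of nn. It assumes the input is either a 1D vector of size n or a batch with n columns, and that the mask is a 0/1 tensor with n entries. Note that it masks activations (whole units) rather than individual weights, so it only approximates the per-connection pruning described in the paper.

```lua
require 'nn'

-- Hypothetical module (not part of nn): multiplies its input by a fixed 0/1 mask.
-- Unlike nn.Dropout, the mask never changes and there is no rescaling.
local FixedMask, parent = torch.class('nn.FixedMaskSketch', 'nn.Module')

function FixedMask:__init(mask)
  parent.__init(self)
  self.mask = mask  -- 1D tensor of size n: 1 = keep, 0 = drop
end

-- Return the mask shaped like `input` (1D input, or batch x n input).
function FixedMask:_maskFor(input)
  if input:dim() == 1 then
    return self.mask
  end
  return self.mask:view(1, self.mask:size(1)):expandAs(input)
end

function FixedMask:updateOutput(input)
  self.output:resizeAs(input):copy(input):cmul(self:_maskFor(input))
  return self.output
end

function FixedMask:updateGradInput(input, gradOutput)
  -- gradients flow back only through the kept units
  self.gradInput:resizeAs(gradOutput):copy(gradOutput):cmul(self:_maskFor(input))
  return self.gradInput
end
```

Because the mask is stored as a tensor field, calling :type() or :cuda() on the module converts it together with the output buffers.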

nicholas-leonard commented 7 years ago

Yes that would be useful. Please implement and send PR to nn.

fmassa commented 7 years ago

Note that there is a very simple way of doing this pruning in torch, which is certainly simpler than adding an extra layer after each layer. Once you have the parameters via getParameters(), all you need to do is mask them after every parameter update, something like:

-- do sgd
optim.sgd(...)

-- apply pruning with mask that was
-- previously computed
if apply_mask then
  params[mask] = 0
end
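A fuller version of the snippet above might look like the sketch below. The toy network, random data, the optimizer settings and the magnitude threshold of 0.05 are all assumptions chosen for illustration. Note that getParameters() flattens biases as well, so this mask also prunes small biases. Restrict the mask to the weight entries if that matters for your experiment.

```lua
require 'nn'
require 'optim'

-- Toy network and random regression data (illustrative assumptions).
local net = nn.Sequential():add(nn.Linear(10, 20)):add(nn.ReLU()):add(nn.Linear(20, 2))
local criterion = nn.MSECriterion()
local input = torch.randn(16, 10)
local target = torch.randn(16, 2)

-- Flatten once; the module weights become views into `params`.
local params, gradParams = net:getParameters()

-- Magnitude pruning: mark entries whose |value| is below a hypothetical threshold.
local threshold = 0.05
local mask = torch.abs(params):lt(threshold)  -- ByteTensor, 1 = pruned
params[mask] = 0

local optimState = {learningRate = 0.01, momentum = 0.9, weightDecay = 1e-4}

local function feval(x)
  gradParams:zero()
  local output = net:forward(input)
  local loss = criterion:forward(output, target)
  net:backward(input, criterion:backward(output, target))
  return loss, gradParams
end

for iter = 1, 100 do
  optim.sgd(feval, params, optimState)
  -- Re-apply the mask after every update so the pruned entries stay exactly zero,
  -- whatever weight decay or momentum did to them during the step.
  params[mask] = 0
end
```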
evcu commented 7 years ago

@fmassa Thanks for the comment. I am doing something similar, but an important part of pruning is retraining. I am using torchnet, and as a quick solution I've implemented it by adding a hook function like:

engine.hooks.onBackward = function(state)
  -- self.masks maps a module index to a 0/1 mask shaped like that module's gradWeight
  for k, v in pairs(self.masks) do
    state.network:get(k).gradWeight:cmul(v)
  end
end

which simply zeroes the accumulated gradients of the pruned connections.
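The hook above expects a table that maps a module index in the network to a 0/1 mask of the same size and type as that module's weight. The sketch below shows one way to build such a table for an nn.Sequential. It assumes simple magnitude pruning with a single threshold, and the helper name buildMasks is hypothetical.

```lua
require 'nn'

-- Hypothetical helper: for every module that has a weight, keep the entries with
-- |w| >= threshold (mask = 1), prune the rest (mask = 0), and zero the pruned weights.
local function buildMasks(net, threshold)
  local masks = {}
  for k = 1, net:size() do
    local m = net:get(k)
    if m.weight then
      -- same tensor type as the weight, so that gradWeight:cmul(mask) works
      masks[k] = torch.abs(m.weight):ge(threshold):typeAs(m.weight)
      m.weight:cmul(masks[k])
    end
  end
  return masks
end

local net = nn.Sequential():add(nn.Linear(10, 20)):add(nn.ReLU()):add(nn.Linear(20, 2))
local masks = buildMasks(net, 0.05)
```

Modules without weights (such as nn.ReLU here) get no entry, so the pairs() loop in the hook simply skips them.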

I am not sure what the best way of implementing it is, since without sparse tensors it is only going to simulate the situation. But I think implementing the necessary forward and backward functions would be nice.

This might be my first ever contribution and I am excited about it, so I would appreciate any recommendations or help. I will probably send the PR today, thanks @nicholas-leonard

fmassa commented 7 years ago

@evcu the solution I proposed allows fine-tuning without any problem while keeping the zeroed weights at zero. I implemented this paper some months ago and it worked like a charm with only 3 extra lines of code. Note that even if you zero the gradients, your parameters will still change because of weight decay in the SGD optimization. That's why I think the easiest solution is something along the lines of what I mentioned. A simple (untested) solution using torchnet is:

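engine.hooks.onUpdate = function(state)
  -- assumes state.params (flattened parameters from getParameters()) and
  -- state.mask (ByteTensor, 1 = pruned) were stored on the state beforehand
  state.params[state.mask] = 0
end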

I'm not sure a dedicated layer is going to add much here, but it might be a matter of personal taste.
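The untested hook above can be fleshed out as in the sketch below. Instead of storing params and mask on the engine state, it captures them in the hook's closure, which avoids depending on extra state fields. The toy network, the random ListDataset (each item is already a mini-batch of 8 examples), the 0.05 threshold and the training settings are illustrative assumptions. getParameters() is called once, before training, so that `params` stays a valid view of the weights that the engine updates.

```lua
local tnt = require 'torchnet'
require 'nn'

-- Toy model and random data (illustrative assumptions).
local net = nn.Sequential():add(nn.Linear(10, 20)):add(nn.ReLU()):add(nn.Linear(20, 2))
local criterion = nn.MSECriterion()
local iterator = tnt.DatasetIterator{
  dataset = tnt.ListDataset{
    list = torch.range(1, 20):long(),
    load = function(idx)
      -- each item is already a mini-batch of 8 examples
      return {input = torch.randn(8, 10), target = torch.randn(8, 2)}
    end,
  },
}

-- Flatten once, compute the mask, and prune before training starts.
local params = net:getParameters()
local mask = torch.abs(params):lt(0.05)  -- hypothetical threshold; 1 = pruned
params[mask] = 0

local engine = tnt.SGDEngine()
engine.hooks.onUpdate = function(state)
  -- the parameter update has just been applied; force pruned entries back to zero
  params[mask] = 0
end

engine:train{
  network = net,
  iterator = iterator,
  criterion = criterion,
  lr = 0.01,
  maxepoch = 2,
}
```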

PS: maybe I misunderstood your message. Do you mean implementing dedicated functions for performing efficient convolutions with sparse weights?

evcu commented 7 years ago

For my project I am doing something quite similar. I set the weights of the pruned connections to zero, which simulates pruning in the forward pass. Then, when retraining, I need to reset the accumulated gradients of the pruned weights before the update. I used hooks.onBackward since it is called right after the backward call ('onUpdate' would be too late). I tested this and it worked.
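Whether zeroing the weights once and masking their gradients is enough depends on the update rule. If the update step goes through an optimizer with internal state, a pruned weight can still move. The toy sketch below illustrates this with optim.sgd and momentum. Momentum is chosen here purely for illustration; the concern raised earlier in the thread was weight decay. After one ordinary step the momentum buffer is nonzero, so a zeroed weight with a zero gradient still drifts on the next step. Re-applying the mask after the update keeps it pinned.

```lua
require 'torch'
require 'optim'

-- A single weight updated by SGD with momentum (illustrative values).
local w = torch.Tensor{0.5}
local state = {learningRate = 0.1, momentum = 0.9}

-- One ordinary step with gradient 1 fills the momentum buffer.
optim.sgd(function(x) return 0, torch.Tensor{1} end, w, state)

-- "Prune" the weight: zero it and feed a masked (all-zero) gradient.
w:zero()
optim.sgd(function(x) return 0, torch.Tensor{0} end, w, state)
local drifted = w[1]
print(drifted)  -- nonzero: the momentum buffer moved the pruned weight

-- Re-applying the mask after each update keeps the weight at zero.
local mask = torch.ByteTensor{1}  -- 1 = pruned
optim.sgd(function(x) return 0, torch.Tensor{0} end, w, state)
w[mask] = 0
print(w[1])  -- 0
```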

I agree that an additional layer wouldn't mean much without sparse tensors, but I believe it is something that will eventually be needed, and I think creating a new module for each sparse version might not be necessary. Therefore I've added setter/remover functions to the original modules nn.Linear and nn.SpatialConvolution, so that one can just call nn.Linear:setMask(mask) to prune the connections given by mask and keep retraining or pruning. Since this is my first pull request, I am not sure I did it right. I tested it as advised in an iTorch notebook. Should I also commit that? I also realized this is probably the wrong repo for this discussion. The pull request is #1073.
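The sketch below suggests what a setMask/removeMask pair could look like. It is hypothetical and is not the code from the pull request. It is written as a subclass named nn.MaskedLinearSketch (a made-up name) so that nn.Linear itself is left untouched. The mask is assumed to have the same shape as the weight, with 1 meaning keep and 0 meaning prune. Masking gradWeight inside accGradParameters means pruned connections receive no gradient during retraining, whatever code later applies the update.

```lua
require 'nn'

-- Hypothetical sketch of a setMask/removeMask API (NOT the PR's code).
local MaskedLinear, parent = torch.class('nn.MaskedLinearSketch', 'nn.Linear')

-- mask: tensor shaped like self.weight, 1 = keep, 0 = prune
function MaskedLinear:setMask(mask)
  self.mask = mask:typeAs(self.weight)
  self.weight:cmul(self.mask)  -- prune the connections immediately
  return self
end

function MaskedLinear:removeMask()
  self.mask = nil
  return self
end

function MaskedLinear:accGradParameters(input, gradOutput, scale)
  parent.accGradParameters(self, input, gradOutput, scale)
  if self.mask then
    -- pruned connections get no gradient during retraining
    self.gradWeight:cmul(self.mask)
  end
end
```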