salesforce / awd-lstm-lm

LSTM and QRNN Language Model Toolkit for PyTorch
BSD 3-Clause "New" or "Revised" License

Weight drop code masking the same "raw" weight? #8

Closed jerrybai1995 closed 7 years ago

jerrybai1995 commented 7 years ago

Hey,

I was inspecting the weight drop (variant of dropconnect) code and I found it a bit confusing (https://github.com/salesforce/awd-lstm-lm/blob/master/weight_drop.py#L34):

for name_w in self.weights:
    raw_w = getattr(self.module, name_w + '_raw')
    w = None
    if self.variational:
        mask = torch.autograd.Variable(torch.ones(raw_w.size(0), 1))
        if raw_w.is_cuda: mask = mask.cuda()
        mask = torch.nn.functional.dropout(mask, p=self.dropout, training=True)
        w = mask.expand_as(raw_w) * raw_w
    else:
        w = torch.nn.functional.dropout(raw_w, p=self.dropout, training=self.training)
    setattr(self.module, name_w, w)

In every iteration, the raw_w you get from name_w + '_raw' is the same, isn't it? You only setattr the result to name_w (e.g. weight_hh_l0) at the end, so every time the dropout mask operates on the same raw weight matrix...

Or maybe I just overlooked something. Can someone help me understand this?

Thanks!

Smerity commented 7 years ago

Hey! ^_^ You're correct locally but missing a global point :)

Locally, the weight drop code you're looking at doesn't modify the raw weight raw_w.

Globally, however, the raw weight raw_w is being modified: it's updated by normal SGD.

In the training loop, we run the forward pass, which sets a new dropped-out weight and computes a loss. When we call loss.backward() we calculate gradients for all our weights, including the raw weight raw_w that was used to create the dropped-out weight w. Then we apply the update, usually an SGD step, which modifies raw_w.

As such, weight drop operates on the same raw weight raw_w each time, but the values in that weight change over the course of training.
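
Roughly, the mechanism looks like this (a toy, self-contained sketch of the idea, not the repo's main.py; raw_w stands in for weight_hh_l0_raw and the loss is just a placeholder):

import torch

# raw_w is the registered parameter; w is a temporary dropped-out version of it.
raw_w = torch.nn.Parameter(torch.randn(4, 4))
optimizer = torch.optim.SGD([raw_w], lr=0.1)

for step in range(3):
    optimizer.zero_grad()
    w = torch.nn.functional.dropout(raw_w, p=0.5, training=True)  # fresh mask every step
    loss = w.sum() ** 2            # placeholder for the language model loss
    loss.backward()                # the gradient flows through w into raw_w.grad
    optimizer.step()               # SGD modifies raw_w; w is recomputed on the next forward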

Hope that helps!

jerrybai1995 commented 7 years ago

Ah, missed that (not really familiar with PyTorch parameter setting). So just to double check: when doing the computation (e.g. Wx + b) in the forward direction, the weight used is w (masked), but when updating/learning, the weight that gets changed is raw_w, since it's the registered parameter?

Thanks for the prompt reply!

Smerity commented 7 years ago

Correct. The masked / dropped-out weight w is only a temporary result of an operation on the raw weight raw_w, so it's raw_w that gets modified when the backprop update is applied :)
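
If you want to sanity check that, something like this should show it (a sketch assuming weight_drop.py from this repo is importable; the printed names are what I'd expect rather than verified output):

import torch
from weight_drop import WeightDrop  # weight_drop.py in this repo

lstm = torch.nn.LSTM(4, 4)
wrapped = WeightDrop(lstm, ['weight_hh_l0'], dropout=0.5)

# Only the raw weight is a registered parameter, so it's what the optimizer updates;
# weight_hh_l0 itself is overwritten with a fresh dropped-out tensor on every forward.
for name, _ in wrapped.named_parameters():
    print(name)  # expect 'module.weight_hh_l0_raw' listed, but no 'module.weight_hh_l0'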