salesforce / awd-lstm-lm

LSTM and QRNN Language Model Toolkit for PyTorch
BSD 3-Clause "New" or "Revised" License

Variational WeightDrop not disabled at evaluation time #72

Open hglaude opened 6 years ago


Hi,

In weight_drop.py:

    def _setweights(self):
        for name_w in self.weights:
            raw_w = getattr(self.module, name_w + '_raw')
            w = None
            if self.variational:
                mask = torch.autograd.Variable(torch.ones(raw_w.size(0), 1))
                if raw_w.is_cuda: mask = mask.cuda()
                mask = torch.nn.functional.dropout(mask, p=self.dropout, training=True)
                w = mask.expand_as(raw_w) * raw_w
            else:
                w = torch.nn.functional.dropout(raw_w, p=self.dropout, training=self.training)
            setattr(self.module, name_w, w)
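
The line

    mask = torch.nn.functional.dropout(mask, p=self.dropout, training=True)

should be

    mask = torch.nn.functional.dropout(mask, p=self.dropout, training=self.training)

Otherwise the variational mask is still sampled after the model is switched to eval mode, so evaluation runs on randomly dropped weights.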
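A minimal sketch of the proposed fix on current PyTorch follows. It assumes a standalone helper, here called `variational_weight_drop` (hypothetical, not part of awd-lstm-lm), that receives the module's `training` flag as an argument in place of `self.training`. It also omits the `torch.autograd.Variable` wrapper, since plain tensors have carried autograd state directly since PyTorch 0.4.

```python
import torch
import torch.nn.functional as F


def variational_weight_drop(raw_w: torch.Tensor, p: float, training: bool) -> torch.Tensor:
    """Hypothetical helper: drop whole rows of raw_w with a single mask.

    Mirrors the variational branch of _setweights, but respects the
    training flag instead of hard-coding training=True.
    """
    # One mask entry per row; expand_as broadcasts it across the columns.
    mask = torch.ones(raw_w.size(0), 1, dtype=raw_w.dtype, device=raw_w.device)
    # In eval mode F.dropout returns the all-ones mask unchanged.
    mask = F.dropout(mask, p=p, training=training)
    return mask.expand_as(raw_w) * raw_w


if __name__ == "__main__":
    torch.manual_seed(0)
    raw_w = torch.randn(6, 3)

    # Eval: the weights pass through untouched.
    w_eval = variational_weight_drop(raw_w, p=0.5, training=False)
    print(torch.equal(w_eval, raw_w))  # True

    # Training: every row is either zeroed or scaled by 1 / (1 - p) = 2.
    w_train = variational_weight_drop(raw_w, p=0.5, training=True)
    print(w_train)
```

Because `F.dropout` scales kept entries by 1 / (1 - p), the expected value of the masked weights equals `raw_w`, which is why using the unmasked weights at evaluation time is the consistent choice.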

Best