HolyBayes / pytorch_ard

Pytorch implementation of Variational Dropout Sparsifies Deep Neural Networks
MIT License

Local Reparameterization Trick #7

Closed: MahdiGilany closed this issue 3 years ago

MahdiGilany commented 3 years ago

Hi,

As you know, the Sparse VD paper uses both the Local Reparameterization Trick and the Additive Noise Reparameterization to reduce the variance of the stochastic gradients. I noticed that, unlike the original Sparse VD code, this repository does not implement the Local Reparameterization Trick from the Variational Dropout paper. So I think the forward passes contain a mistake and need to change. Currently they read:

```python
def forward(self, input):
    if self.training:
        epsilon = self.weight.new(self.weight.shape).normal_()
        W = self.weight + epsilon * torch.exp(self.log_sigma2 / 2)
    else:
        W = self.weights_clipped
    return F.linear(input, W) + self.bias
```
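For reference, here is a sketch of the identity behind the proposed change, in the notation of the Sparse VD paper: θ is the weight mean (`self.weight`), σ² = exp(`log_sigma2`), a is the layer input, and i, j, m index inputs, outputs and examples in the mini-batch. The mapping to these attribute names is my reading of the code, not something the repository states. The forward pass above draws one weight matrix per call, so every example in the batch shares the same noise. With a fully factorized Gaussian posterior, each pre-activation b is itself Gaussian and can be sampled directly, with independent noise per example:

```latex
w_{ij} \sim \mathcal{N}\!\left(\theta_{ij},\, \sigma_{ij}^2\right), \qquad \sigma_{ij}^2 = \alpha_{ij}\,\theta_{ij}^2

b_{mj} = \sum_i a_{mi}\, w_{ij} \;\sim\; \mathcal{N}\!\left(\gamma_{mj},\, \delta_{mj}\right), \qquad
\gamma_{mj} = \sum_i a_{mi}\,\theta_{ij}, \qquad
\delta_{mj} = \sum_i a_{mi}^2\,\sigma_{ij}^2

b_{mj} = \gamma_{mj} + \sqrt{\delta_{mj}}\;\varepsilon_{mj}, \qquad \varepsilon_{mj} \sim \mathcal{N}(0, 1)

\log \alpha_{ij} = \log \sigma_{ij}^2 - \log \theta_{ij}^2
```

The variance δ uses σ² = exp(`log_sigma2`) directly, so it should not be squared again. The last line is the additive noise reparameterization: θ and log σ² are trained as free parameters, and log α is recovered only where it is needed (the KL term and pruning).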

Should change to:

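```python
# In __init__: weight and log_sigma2 are now both stored as (in_features, out_features)
self.weight = Parameter(torch.Tensor(in_features, out_features))  # I changed the shape of self.weight
self.log_sigma2 = Parameter(torch.Tensor(in_features, out_features))  # must match self.weight

def forward(self, input):
    if self.training:
        # Local reparameterization: sample the pre-activations, not the weights
        out_mu = input.matmul(self.weight)  # mean of the pre-activations
        var_w = torch.exp(self.log_sigma2)  # sigma^2 is already a variance; do not square it
        out_std = torch.sqrt(input.pow(2).matmul(var_w) + 1e-8)  # small constant keeps sqrt stable
        epsilon = torch.randn_like(out_std)  # independent noise for every example and unit
        output = out_mu + out_std * epsilon
    else:
        # Deterministic pass: the input must still be multiplied by the sparsified weights
        output = input.matmul(self.weights_clipped)
    return output + self.bias
```
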
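Putting the pieces together, below is a minimal self-contained sketch of a fully connected layer that uses the local reparameterization trick. The class name `LinearSVDOLocal`, the pruning threshold `threshold=3.0` on log α, the initial values (`log_sigma2 = -10`, weight std 0.02) and the clamp on log α are assumptions chosen for illustration, not taken from this repository; the KL regularizer is omitted.

```python
import torch
import torch.nn as nn


class LinearSVDOLocal(nn.Module):
    """Hypothetical Sparse VD linear layer using the local reparameterization trick.

    Weights are stored as (in_features, out_features) so that input.matmul(weight) works.
    """

    def __init__(self, in_features, out_features, threshold=3.0):
        super().__init__()
        self.threshold = threshold  # assumed log(alpha) cut-off used for pruning
        self.weight = nn.Parameter(torch.empty(in_features, out_features))
        # Assumed initialisation: small noise variance, sigma^2 = exp(-10)
        self.log_sigma2 = nn.Parameter(torch.full((in_features, out_features), -10.0))
        self.bias = nn.Parameter(torch.zeros(out_features))
        nn.init.normal_(self.weight, std=0.02)  # assumed weight initialisation

    @property
    def log_alpha(self):
        # Additive noise reparameterization: log(alpha) = log(sigma^2) - log(theta^2)
        log_alpha = self.log_sigma2 - torch.log(self.weight.pow(2) + 1e-8)
        return torch.clamp(log_alpha, -10.0, 10.0)  # assumed clamp range

    @property
    def weights_clipped(self):
        # Zero out weights whose log(alpha) is at or above the threshold
        keep = (self.log_alpha < self.threshold).to(self.weight.dtype)
        return self.weight * keep

    def forward(self, input):
        if self.training:
            mu = input.matmul(self.weight)  # gamma: mean of each pre-activation
            var = input.pow(2).matmul(torch.exp(self.log_sigma2))  # delta: its variance
            eps = torch.randn_like(mu)  # fresh noise per example and output unit
            return mu + torch.sqrt(var + 1e-8) * eps + self.bias
        # Inference: deterministic pass with the pruned weights
        return input.matmul(self.weights_clipped) + self.bias
```

With identical inputs in a batch, the training-mode outputs still differ from row to row because the noise is drawn per pre-activation; sampling a single `W` per call, as in the original forward pass, would make every row identical.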
HolyBayes commented 3 years ago

Thank you for your comment! Fixed.