piEsposito / blitz-bayesian-deep-learning

A simple and extensible library to create Bayesian Neural Network layers on PyTorch.
GNU General Public License v3.0

Initializing net parameters #39

Closed LukasWFurtner closed 4 years ago

LukasWFurtner commented 4 years ago

Hi, is there a way to initialize the mus of a net that uses blitz layers with the parameters of its deterministic equivalent? To get a better understanding, I built a net with only one weight and one bias, so the weight can be interpreted as the slope and the bias as the intercept of a linear function. However, setting the mu of the bias and weight only seems to work for a frozen model.

(Image: "Setting blitz net mus" — plot of frozen and unfrozen predictions)

import torch
import torch.nn as nn
from blitz.modules import BayesianLinear
from blitz.utils import variational_estimator
import matplotlib.pyplot as plt
import numpy as np

@variational_estimator
class bayesian_Net(nn.Module):

    def __init__(self, freeze):
        super(bayesian_Net, self).__init__()
        self.bl1 = BayesianLinear(1, 1, bias=True, freeze=freeze)

    def forward(self, x):
        x = self.bl1(x)
        return x

unfrozen_net = bayesian_Net(freeze=False)
frozen_net = bayesian_Net(freeze=True)

slope = -0.5
intercept = 1

# Try to set the weight/bias mus to the deterministic slope and intercept
unfrozen_net.bl1.weight_mu = torch.nn.Parameter(torch.Tensor([[slope]]), requires_grad=True)
unfrozen_net.bl1.bias_mu = torch.nn.Parameter(torch.Tensor([intercept]), requires_grad=True)

frozen_net.bl1.weight_mu = torch.nn.Parameter(torch.Tensor([[slope]]), requires_grad=True)
frozen_net.bl1.bias_mu = torch.nn.Parameter(torch.Tensor([intercept]), requires_grad=True)

# Inputs 0..9 as a column vector
x = torch.FloatTensor(np.arange(10).reshape(-1, 1))

# Plot 100 stochastic predictions (red) against the frozen prediction (black)
for k in range(100):
    plt.plot(x, unfrozen_net(x).detach().numpy(), 'r')
    plt.plot(x, frozen_net(x).detach().numpy(), 'k')
plt.show()
piEsposito commented 4 years ago

Hello Lukas, if you have a deterministic network whose weights have the same shape as the variational layer you are using from blitz, you can, of course, set the blitz weight mus to those deterministic values using built-in torch methods.

You can also train a frozen model and then unfreeze it (using the mus as something akin to weight priors). In that case, you may want to set your rho init parameter as small as you can, to reduce the variance of your model and calibrate it during variational training.
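A minimal sketch of that workflow (it assumes blitz's posterior_rho_init constructor argument and toggles the layer's freeze attribute directly; the exact unfreezing API may differ):

import torch
import torch.nn as nn
from blitz.modules import BayesianLinear

# A very negative posterior_rho_init gives a tiny initial sigma
# (sigma = log(1 + exp(rho))), so early samples stay close to the mus
layer = BayesianLinear(1, 1, posterior_rho_init=-8.0, freeze=True)

# Copy deterministic weights into the mus via built-in torch methods;
# in-place .data.copy_ keeps every reference to the Parameters valid
det = nn.Linear(1, 1)
layer.weight_mu.data.copy_(det.weight.data)
layer.bias_mu.data.copy_(det.bias.data)

# ... train while frozen (deterministic behavior) ...

# Then unfreeze for variational training
layer.freeze = False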

LukasWFurtner commented 4 years ago

Thank you for the fast answer, but I still don't understand it. The frozen and unfrozen net were both assigned the same bias_mu and weight_mu, so in my opinion the prediction of the frozen net should approximately be the mean of the 100 different predictions from the unfrozen net. Yet all 100 predictions made with the unfrozen net after the weight and bias were assigned sampled a positive weight (otherwise the slope would not be positive in every case), even though the mean weight/slope of the unfrozen net was set to a negative value (-0.5). The net seems to do its calculations with the initial weight and bias. Changing posterior_rho_init did not change the variance of the predictions either.

![Setting blitz net mus](https://user-images.githubusercontent.com/63885140/83432156-c052d800-a438-11ea-8079-195aef580b4b.png)

piEsposito commented 4 years ago

Oh, I see what you're talking about. What might be happening is that your manual assignment changes the weights on the BayesianModule object, but not on the weight samplers inside it, so when you forward the unfrozen layer it keeps sampling around the old values.

(I implement the weight samplers as BayesianModules and have them share their weight tensors with the actual layer, so everything integrates easily with CUDA.)

When I change the weights with

a = BayesianLinear(1, 1)
a.weight_mu.data = torch.Tensor(1, 1).normal_()

rather than

a.weight_mu = nn.Parameter(torch.Tensor(1, 1).normal_())

it works.
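A minimal pure-PyTorch sketch of why the two assignments differ (hypothetical names, just to illustrate the reference semantics):

import torch
import torch.nn as nn

layer_mu = nn.Parameter(torch.zeros(1, 1))  # plays the role of weight_mu
sampler_mu = layer_mu                       # the sampler holds a reference to the same object

# Assigning through .data changes the one shared Parameter object: both names see it
layer_mu.data = torch.full((1, 1), -0.5)
print(sampler_mu)  # tensor([[-0.5000]]) -- the "sampler" sees the change

# Rebinding the attribute creates a brand-new Parameter; the old reference goes stale
layer_mu = nn.Parameter(torch.full((1, 1), 1.0))
print(sampler_mu)  # still tensor([[-0.5000]]) -- the "sampler" did not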

LukasWFurtner commented 4 years ago

It works! Thanks again for the fast response. I think you were totally right that the changes were not passed on to the weight and bias samplers. I just needed to re-initialize the samplers as well; this way it also works with nn.Parameter.

unfrozen_net.bl1.weight_sampler = GaussianVariational(unfrozen_net.bl1.weight_mu, unfrozen_net.bl1.weight_rho)
unfrozen_net.bl1.bias_sampler = GaussianVariational(unfrozen_net.bl1.bias_mu, unfrozen_net.bl1.bias_rho)

(Image: "Setting blitz net mus" — unfrozen predictions now scatter around the frozen line)
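A small helper could bundle the two steps, rebinding the mus and rebuilding the samplers (set_layer_mus is a hypothetical name, not part of blitz):

import torch
import torch.nn as nn
from blitz.modules.weight_sampler import GaussianVariational

def set_layer_mus(layer, weight, bias=None):
    # Rebind the mus, then rebuild the samplers so they reference
    # the new Parameters instead of the stale ones
    layer.weight_mu = nn.Parameter(torch.as_tensor(weight, dtype=torch.float32))
    layer.weight_sampler = GaussianVariational(layer.weight_mu, layer.weight_rho)
    if bias is not None:
        layer.bias_mu = nn.Parameter(torch.as_tensor(bias, dtype=torch.float32))
        layer.bias_sampler = GaussianVariational(layer.bias_mu, layer.bias_rho)

# e.g. set_layer_mus(unfrozen_net.bl1, [[-0.5]], [1.0])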

For completeness, the working code:

import torch
import torch.nn as nn
from blitz.modules import BayesianLinear
from blitz.modules.weight_sampler import GaussianVariational
from blitz.utils import variational_estimator
import matplotlib.pyplot as plt
import numpy as np

@variational_estimator
class bayesian_Net(nn.Module):

    def __init__(self, freeze):
        super(bayesian_Net, self).__init__()
        self.bl1 = BayesianLinear(1, 1, bias=True, freeze=freeze)

    def forward(self, x):
        x = self.bl1(x)
        return x

unfrozen_net = bayesian_Net(freeze=False)
frozen_net = bayesian_Net(freeze=True)

print('Before setting weight and bias manually:')
print('frozen_net.bl1.weight_mu:', frozen_net.bl1.weight_mu.item())
print('frozen_net.bl1.bias_mu:', frozen_net.bl1.bias_mu.item())
print('unfrozen_net.bl1.weight_mu:', unfrozen_net.bl1.weight_mu.item())
print('unfrozen_net.bl1.bias_mu:', unfrozen_net.bl1.bias_mu.item())

# Inputs 0..9 as a column vector
x = torch.FloatTensor(np.arange(10).reshape(-1, 1))

# Plot before setting the mus manually
plt.plot(x, unfrozen_net(x).detach().numpy(), 'r', label='unfrozen')
plt.plot(x, frozen_net(x).detach().numpy(), 'k', label='frozen')
for k in range(100):
    plt.plot(x, unfrozen_net(x).detach().numpy(), 'r')
    plt.plot(x, frozen_net(x).detach().numpy(), 'k')
plt.ylim((-1, 1))
plt.legend()
plt.show()

slope = -0.5
intercept = 1

# Rebind the mu Parameters to the deterministic slope and intercept
unfrozen_net.bl1.weight_mu = torch.nn.Parameter(torch.Tensor([[slope]]), requires_grad=True)
unfrozen_net.bl1.bias_mu = torch.nn.Parameter(torch.Tensor([intercept]), requires_grad=True)
frozen_net.bl1.weight_mu = torch.nn.Parameter(torch.Tensor([[slope]]), requires_grad=True)
frozen_net.bl1.bias_mu = torch.nn.Parameter(torch.Tensor([intercept]), requires_grad=True)

# Rebuild the samplers so they reference the new mu Parameters
frozen_net.bl1.weight_sampler = GaussianVariational(frozen_net.bl1.weight_mu, frozen_net.bl1.weight_rho)
frozen_net.bl1.bias_sampler = GaussianVariational(frozen_net.bl1.bias_mu, frozen_net.bl1.bias_rho)
unfrozen_net.bl1.weight_sampler = GaussianVariational(unfrozen_net.bl1.weight_mu, unfrozen_net.bl1.weight_rho)
unfrozen_net.bl1.bias_sampler = GaussianVariational(unfrozen_net.bl1.bias_mu, unfrozen_net.bl1.bias_rho)

print('After setting weight and bias manually:')
print('frozen_net.bl1.weight_mu:', frozen_net.bl1.weight_mu.item())
print('frozen_net.bl1.bias_mu:', frozen_net.bl1.bias_mu.item())
print('unfrozen_net.bl1.weight_mu:', unfrozen_net.bl1.weight_mu.item())
print('unfrozen_net.bl1.bias_mu:', unfrozen_net.bl1.bias_mu.item())

# Plot after setting the mus and rebuilding the samplers
plt.plot(x, unfrozen_net(x).detach().numpy(), 'r', label='unfrozen')
plt.plot(x, frozen_net(x).detach().numpy(), 'k', label='frozen')
for k in range(100):
    plt.plot(x, unfrozen_net(x).detach().numpy(), 'r')
    plt.plot(x, frozen_net(x).detach().numpy(), 'k')
plt.ylim((-1, 1))
plt.legend()
plt.show()
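As a sanity check (not part of the original thread), the mean of many unfrozen forward passes should now sit close to the frozen prediction:

# Average many stochastic forward passes; this should approximate the frozen output
samples = torch.stack([unfrozen_net(x) for _ in range(1000)])
print('mean unfrozen prediction:', samples.mean(dim=0).detach().squeeze())
print('frozen prediction:', frozen_net(x).detach().squeeze())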