kumar-shridhar / PyTorch-BayesianCNN

Bayesian Convolutional Neural Network with Variational Inference based on Bayes by Backprop in PyTorch.
MIT License
1.42k stars, 323 forks

Use Bayes for segmentation #65

Open · hannah-saber opened this issue 3 years ago

hannah-saber commented 3 years ago

Hi, after reading your paper, I think Bayesian inference can be used for segmentation, so I added it to a CNN by resampling the weight and bias in `conv3d`, but the results are bad. If possible, could you help me look at my code?

Yours, Hannah
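Resampling the weight and bias on every forward pass, as described above, is the reparameterization step of Bayes by Backprop. A minimal sketch, assuming the same softplus parameterization σ = log(1 + exp(ρ)) that the code below uses; the function name `sampled_conv3d`, the tensor sizes, and the initial values are illustrative only. The point it shows is that *both* tensors are drawn from their own μ and ρ, and that μ and ρ receive gradients through the sample.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
in_ch, out_ch, k = 1, 2, 3  # illustrative channel counts and kernel size

# Variational parameters: mean (mu) and softplus pre-activation (rho)
# for the weight and the bias. All four are leaf tensors that receive gradients.
w_mu = (0.1 * torch.randn(out_ch, in_ch, k, k, k)).requires_grad_()
w_rho = torch.full((out_ch, in_ch, k, k, k), -5.0, requires_grad=True)
b_mu = torch.zeros(out_ch, requires_grad=True)
b_rho = torch.full((out_ch,), -5.0, requires_grad=True)


def sampled_conv3d(x):
    # sigma = log(1 + exp(rho)) is always positive.
    w_sigma = F.softplus(w_rho)
    b_sigma = F.softplus(b_rho)
    # Reparameterization: w = mu + sigma * eps with eps ~ N(0, 1).
    # Fresh noise is drawn on every call, and gradients flow back to mu and rho.
    w = w_mu + w_sigma * torch.randn_like(w_mu)
    b = b_mu + b_sigma * torch.randn_like(b_mu)
    return F.conv3d(x, w, b, stride=1, padding=1)


x = torch.randn(1, in_ch, 8, 8, 8)  # (N, C, D, H, W)
y = sampled_conv3d(x)
y.sum().backward()
print(y.shape)  # torch.Size([1, 2, 8, 8, 8])
```

Because the noise is resampled on each call, two forward passes on the same input give slightly different outputs. That stochasticity is what the KL term and the optimizer act on.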

hannah-saber commented 3 years ago

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal


class bayesianConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=1,
                 dilation=1, bias=True, n=1024, **kwargs):
        super(bayesianConv, self).__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.use_bias = bias
        self.n = n  # number of training iterations, used in the running update below

        # Initialize mu and rho of the bias
        # (plain tensors, not nn.Parameter, so an optimizer will not update them)
        self.b_mu = torch.zeros(self.out_channels)
        self.b_rho = torch.zeros(self.out_channels)
        self.b = torch.zeros(self.out_channels)

        # Initialize mu and rho of the weight (3*3*3 kernel)
        self.w_mu = torch.zeros(size=(self.out_channels, self.in_channels, 3, 3, 3))
        self.w_rho = torch.zeros(size=(self.out_channels, self.in_channels, 3, 3, 3))
        self.w = torch.zeros(size=(self.out_channels, self.in_channels, 3, 3, 3))

    def forward(self, input):
        # Sample the bias with the reparameterization trick
        b_epsilon = Normal(0, 1).sample(self.b_mu.shape)
        self.b = self.b_mu + torch.log(1 + torch.exp(self.b_rho)) * b_epsilon

        # Recompute the mean and variance of the bias; n is the number of training iterations
        b_mu_1 = self.b_mu
        self.b_mu = (self.b_mu * (self.n - 1) + self.b) / self.n
        self.b_rho = ((self.n - 1) * self.b_rho + (self.n - 1) * pow(self.b_mu - b_mu_1, 2)
                      + pow(self.b - self.b_mu, 2)) / self.n

        # Feed the reparameterized bias into the convolution layer.
        # Note: self.w is never sampled from w_mu/w_rho, so it stays all zeros.
        x = F.conv3d(input=input, weight=self.w, bias=self.b, stride=self.stride,
                     padding=self.padding, dilation=self.dilation)
        loss = self.kl_loss()

        return x, loss

    def kl_loss(self):
        # loss_kl is the poster's own helper; its definition is not shown in the issue.
        loss = loss_kl()(self.w, self.w_mu, self.w_rho, self.b, self.b_mu, self.b_rho)
        return loss
```
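For comparison, here is a minimal sketch of a Bayes-by-Backprop 3-D convolution as it is commonly set up. μ and ρ are registered as `nn.Parameter`s so the optimizer learns them by gradient descent instead of a hand-written running update. Both weight and bias are resampled on every forward pass, and the KL term is computed in closed form against a fixed Gaussian prior N(0, σ_p²). The class name `BBBConv3d`, the prior σ_p = 0.1, the ρ initialization of −5, the KL scaling by an assumed training-set size, and the 10-pass Monte Carlo average are all illustrative assumptions, not code from this repository.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BBBConv3d(nn.Module):
    """Hypothetical Bayes-by-Backprop Conv3d with a factorized Gaussian posterior."""

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
                 padding=1, dilation=1, prior_sigma=0.1):
        super().__init__()
        self.stride, self.padding, self.dilation = stride, padding, dilation
        self.prior_sigma = prior_sigma
        shape = (out_channels, in_channels) + (kernel_size,) * 3
        # Learnable variational parameters; registering them lets optimizers update them.
        self.w_mu = nn.Parameter(torch.empty(shape).normal_(0.0, 0.1))
        self.w_rho = nn.Parameter(torch.full(shape, -5.0))
        self.b_mu = nn.Parameter(torch.zeros(out_channels))
        self.b_rho = nn.Parameter(torch.full((out_channels,), -5.0))

    def _kl(self, mu, sigma):
        # Closed-form KL( N(mu, sigma^2) || N(0, prior_sigma^2) ), summed over elements.
        p = self.prior_sigma
        return (torch.log(p / sigma) + (sigma ** 2 + mu ** 2) / (2 * p ** 2) - 0.5).sum()

    def forward(self, x):
        w_sigma = F.softplus(self.w_rho)  # sigma = log(1 + exp(rho)) > 0
        b_sigma = F.softplus(self.b_rho)
        # Reparameterization trick: fresh noise every call, gradients reach mu and rho.
        w = self.w_mu + w_sigma * torch.randn_like(w_sigma)
        b = self.b_mu + b_sigma * torch.randn_like(b_sigma)
        out = F.conv3d(x, w, b, stride=self.stride, padding=self.padding,
                       dilation=self.dilation)
        kl = self._kl(self.w_mu, w_sigma) + self._kl(self.b_mu, b_sigma)
        return out, kl


# Usage on a toy 2-class voxel segmentation batch.
torch.manual_seed(0)
layer = BBBConv3d(1, 2)
x = torch.randn(4, 1, 8, 8, 8)              # (N, C, D, H, W)
target = torch.randint(0, 2, (4, 8, 8, 8))  # per-voxel class labels

logits, kl = layer(x)
num_train = 100  # assumed training-set size used to scale the KL term
loss = F.cross_entropy(logits, target) + kl / num_train
loss.backward()

# At test time, average softmax outputs over several stochastic passes;
# the variance across passes can serve as a per-voxel uncertainty map.
with torch.no_grad():
    probs = torch.stack([F.softmax(layer(x)[0], dim=1) for _ in range(10)])
mean_probs = probs.mean(dim=0)  # (N, classes, D, H, W)
uncertainty = probs.var(dim=0)  # per-voxel, per-class variance
```

The main design difference from the code above is that nothing in `forward` overwrites μ or ρ by hand. Their values change only through `loss.backward()` and an optimizer step, which is how Bayes by Backprop fits the variational posterior.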