TyXe-BDL / TyXe

MIT License

Turning a UNet into a Bayesian UNet #25

Open taborzbislaw opened 1 year ago

taborzbislaw commented 1 year ago

Hi,

I am trying to use your library to turn a UNet into a Bayesian UNet. The code is pasted below; in this implementation, the UNet works as a pixel-to-pixel translator for 3D data. The code follows your regression example, since I am also doing regression, but on higher-dimensional data.

When I run the code, I get a runtime error: ValueError: Expected parameter scale (Tensor of shape (4, 1, 32, 32, 16)) of distribution Normal(loc: torch.Size([4, 1, 32, 32, 16]), scale: torch.Size([4, 1, 32, 32, 16])) to satisfy the constraint GreaterThan(lower_bound=0.0), but found invalid values...

I suspect the problem is a poor choice of prior and/or guide. I would appreciate any suggestions for getting the model to learn.

Regards, Zbisław

The code:

from functools import partial

import torch
import torch.nn as nn
import torch.utils.data as data

import pyro
import pyro.distributions as dist

import tyxe

def double_convolution(in_channels, out_channels):
    """
    In the original paper implementation, the convolution operations
    were not padded, but we pad them here because we need the output
    size to be the same as the input size.
    """
    conv_op = nn.Sequential(
        nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(inplace=True)
    )
    return conv_op

class UNet(nn.Module):
    def __init__(self, num_classes):
        super(UNet, self).__init__()

        self.max_pool3d = nn.MaxPool3d(kernel_size=2, stride=2)

        # contracting path
        # each convolution is applied twice
        self.down_convolution_1 = double_convolution(3, 64)
        self.down_convolution_2 = double_convolution(64, 128)
        self.down_convolution_3 = double_convolution(128, 256)
        self.down_convolution_4 = double_convolution(256, 512)
        self.down_convolution_5 = double_convolution(512, 1024)

        # expanding path
        self.up_transpose_1 = nn.ConvTranspose3d(
            in_channels=1024, out_channels=512,
            kernel_size=2,
            stride=2)
        # below, `in_channels` again becomes 1024 as we are concatenating
        self.up_convolution_1 = double_convolution(1024, 512)
        self.up_transpose_2 = nn.ConvTranspose3d(
            in_channels=512, out_channels=256,
            kernel_size=2,
            stride=2)
        self.up_convolution_2 = double_convolution(512, 256)
        self.up_transpose_3 = nn.ConvTranspose3d(
            in_channels=256, out_channels=128,
            kernel_size=2,
            stride=2)
        self.up_convolution_3 = double_convolution(256, 128)
        self.up_transpose_4 = nn.ConvTranspose3d(
            in_channels=128, out_channels=64,
            kernel_size=2,
            stride=2)
        self.up_convolution_4 = double_convolution(128, 64)

        # output => set `out_channels` to the number of classes
        self.out = nn.Conv3d(
            in_channels=64, out_channels=num_classes,
            kernel_size=1
        )

    def forward(self, x):
        down_1 = self.down_convolution_1(x)
        down_2 = self.max_pool3d(down_1)
        down_3 = self.down_convolution_2(down_2)
        down_4 = self.max_pool3d(down_3)
        down_5 = self.down_convolution_3(down_4)
        down_6 = self.max_pool3d(down_5)
        down_7 = self.down_convolution_4(down_6)
        # deepest level disabled: down_convolution_5, up_transpose_1 and
        # up_convolution_1 are defined above but not used
        #down_8 = self.max_pool3d(down_7)
        #down_9 = self.down_convolution_5(down_8)

        #up_1 = self.up_transpose_1(down_9)
        #x = self.up_convolution_1(torch.cat([down_7, up_1], 1))

        #up_2 = self.up_transpose_2(x)
        up_2 = self.up_transpose_2(down_7)
        x = self.up_convolution_2(torch.cat([down_5, up_2], 1))

        up_3 = self.up_transpose_3(x)
        x = self.up_convolution_3(torch.cat([down_3, up_3], 1))

        up_4 = self.up_transpose_4(x)
        x = self.up_convolution_4(torch.cat([down_1, up_4], 1))

        out = self.out(x)
        return out

################################################################################

if __name__ == '__main__':

    pyro.set_rng_seed(42)

    # synthetic data: the target is the channel mean of the input plus small noise
    x = torch.randn((4, 3, 32, 32, 32)) * 0.5 + 0.5
    y = torch.mean(x, 1, keepdim=True) + torch.randn((4, 1, 32, 32, 32)) * 0.01

    for _ in range(10):
        x2 = torch.randn((4, 3, 32, 32, 32)) * 0.5 + 0.5
        x = torch.cat([x, x2])

        y2 = torch.mean(x2, 1, keepdim=True) + torch.randn((4, 1, 32, 32, 32)) * 0.01
        y = torch.cat([y, y2])

    batchSize = 4
    dataset = data.TensorDataset(x, y)
    loader = data.DataLoader(dataset, batch_size=batchSize)

    net = UNet(1)
    prior = tyxe.priors.IIDPrior(dist.Normal(0, 1))
    obs_model = tyxe.likelihoods.HeteroskedasticGaussian((4, 1, 32, 32, 32))
    guide = partial(tyxe.guides.AutoNormal, init_scale=0.01)
    bnn = tyxe.VariationalBNN(net, prior, obs_model, guide)

    pyro.clear_param_store()
    optim = pyro.optim.Adam({"lr": 1e-4})
    elbos = []

    def callback(bnn, i, e):
        elbos.append(e)

    with tyxe.poutine.local_reparameterization():
        bnn.fit(loader, optim, 10000, callback)

LarsBentsen commented 1 year ago

Hi, I'm also using the HeteroskedasticGaussian likelihood. I'm not entirely sure this will solve your issue, but the HeteroskedasticGaussian class in "likelihoods.py" notes that it expects the predictions to be 2d, so that they contain both the mean and the std (i.e. the aleatoric uncertainty?). For your net, it might help to change the output layer to:

self.out = nn.Conv3d(
    in_channels=64, out_channels=num_classes * 2,
    kernel_size=1
)

However, I'm not sure this will fully solve your issue, and I apologise if this is a very trivial suggestion...

taborzbislaw commented 1 year ago

Hi,

Thank you for the suggestion, but it does not work. I get the same exploding tensors.

hpplyt commented 1 year ago

Sorry for the delayed response; I completely missed this issue during NeurIPS.

Does the error occur right away or after a few training iterations? And have you by any chance checked if you are getting NaNs or negative values for the scales of the Normal distribution? Either would fail the constraint check iirc.
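A minimal sketch of the check suggested above, written with plain Python floats so it runs without PyTorch. On the real tensors, `torch.isnan(scale).any()` and `(scale <= 0).any()` answer the same questions. The helper `summarize_scales` is hypothetical. The constraint in the message is `GreaterThan(lower_bound=0.0)`, i.e. a strict comparison, so an exact zero fails it just like a negative value; NaN fails too, because every comparison with NaN is false.

```python
import math


def summarize_scales(values):
    """Count entries that would fail a strict `scale > 0` check (hypothetical helper).

    Plain Python floats stand in for the scale tensor of the Normal.
    """
    summary = {"nan": 0, "negative": 0, "zero": 0, "positive": 0, "max_abs": 0.0}
    for v in values:
        if math.isnan(v):
            summary["nan"] += 1  # NaN > 0 is False, so NaN is always invalid
            continue
        if v < 0:
            summary["negative"] += 1
        elif v == 0:
            summary["zero"] += 1
        else:
            summary["positive"] += 1
        summary["max_abs"] = max(summary["max_abs"], abs(v))
    # every non-positive entry, including NaN, fails the strict check
    summary["invalid"] = summary["nan"] + summary["negative"] + summary["zero"]
    return summary


# a few of the scale values printed in the error message in this thread
reported = [8.8535e16, 0.0, 3.9392e16, 1.2329e16, 5.3384e15, 0.0]
print(summarize_scales(reported))
print(float("nan") > 0)  # False: NaN fails a strict greater-than check
```

Tracking the largest magnitude alongside the invalid counts helps tell "a few bad entries" apart from "the whole output has blown up".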

taborzbislaw commented 1 year ago

Hi.

The error occurs right at the start of training. I get something like this:

ValueError: Expected parameter scale (Tensor of shape (4, 2, 32, 32, 16)) of distribution Normal(loc: torch.Size([4, 2, 32, 32, 16]), scale: torch.Size([4, 2, 32, 32, 16])) to satisfy the constraint GreaterThan(lower_bound=0.0), but found invalid values: tensor([[[[[8.8535e+16, 0.0000e+00, 3.9392e+16, ..., 1.2329e+16, 5.3384e+15, 0.0000e+00], [0.0000e+00, 0.0000e+00, 4.3650e+16, ..., 5.4064e+16, 8.3643e+16, 0.0000e+00], [5.4920e+16, 8.3632e+16, 2.0063e+17, ..., 0.0000e+00, 3.9833e+16, 0.0000e+00], ...,
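The mix of exact zeros and values around 1e16 in this message is what you would expect if the raw network outputs are enormous in magnitude. Assuming the likelihood maps its raw scale output to a positive number with softplus or exp (check `likelihoods.py` for what TyXe actually uses), a very negative raw output underflows to exactly 0.0 and then fails the `> 0` constraint, while a very positive one passes through almost unchanged. The sketch below shows this in double precision with the standard library; float32 tensors underflow much sooner, with exp reaching 0 at roughly -104.

```python
import math


def softplus(x):
    """Numerically stable softplus: log(1 + exp(x)) = max(x, 0) + log1p(exp(-|x|))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


# a moderately negative raw output gives a small but strictly positive scale
print(softplus(-20.0))  # about 2.06e-09, still > 0

# a hugely negative raw output underflows to exactly 0.0 with softplus and with exp
print(softplus(-800.0), math.exp(-800.0))

# a hugely positive raw output passes through almost unchanged
print(softplus(8.8535e16))
```

The stable form matters: evaluating `math.exp(8.8535e16)` directly would raise an `OverflowError` instead of returning a value.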

I suspect something may be wrong with my model. I simply copied your regression example and replaced the model with a UNet, which is a bit more complicated than the regression network. The complete code is in the attached archive; it generates its own data, so it can simply be run.
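The two error messages also say something about shapes. With one output channel the Normal has shape (4, 1, 32, 32, 16); with two output channels it has shape (4, 2, 32, 32, 16). In both cases the last axis (32) is halved while the channel axis is untouched, which suggests the likelihood splits the prediction tensor into mean and scale halves along its last dimension rather than along the channel dimension. The sketch below only does the shape bookkeeping for that reading; `split_last_axis` and `channels_last` are hypothetical helpers, not TyXe functions.

```python
def split_last_axis(shape):
    """Shapes of the (loc, scale) halves if a tensor of `shape` is split in two along its last axis."""
    *lead, last = shape
    if last % 2:
        raise ValueError(f"last axis has odd size {last}; it cannot be split into equal halves")
    half = tuple(lead) + (last // 2,)
    return half, half


def channels_last(shape):
    """Shape after moving axis 1 (channels) to the end, like permute(0, 2, 3, 4, 1) on a 5-D tensor."""
    return (shape[0],) + tuple(shape[2:]) + (shape[1],)


# UNet(1), channels-first output: the split halves the last spatial axis
print(split_last_axis((4, 1, 32, 32, 32)))  # ((4, 1, 32, 32, 16), (4, 1, 32, 32, 16))

# num_classes * 2 output channels: same split, the channels are never separated
print(split_last_axis((4, 2, 32, 32, 32)))  # ((4, 2, 32, 32, 16), (4, 2, 32, 32, 16))

# two channels moved last: the split now separates the mean and scale channels
print(split_last_axis(channels_last((4, 2, 32, 32, 32))))  # ((4, 32, 32, 32, 1), (4, 32, 32, 32, 1))
```

Under this reading, one option would be a wrapper whose forward returns the two-channel output permuted to channels-last (e.g. with `Tensor.permute`), with the targets `y` given the matching layout, shape (N, 32, 32, 32, 1). That only fixes the layout; it would not by itself explain the extreme values in the scales.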

UNet.py.zip

hpplyt commented 1 year ago

OK, interesting. My guess is that this is an initialisation issue, since you are getting fairly extreme values for the standard deviations. Setting the guide's init_scale argument to something around the standard deviation used for initialising the weights of the deterministic version of your network might fix things (or just use 1e-5 or so). You might also need to initialise the means of your guide to the deterministic weights, as in the ResNet example.
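To put a number on "the standard deviation used for initialising the weights of the deterministic version": PyTorch's default `nn.Conv3d` initialisation calls `kaiming_uniform_(weight, a=math.sqrt(5))`, which amounts to drawing from U(-1/sqrt(fan_in), 1/sqrt(fan_in)) with fan_in = in_channels * kernel_size**3, i.e. a standard deviation of 1/sqrt(3 * fan_in). The sketch below computes that for the encoder convolutions of the UNet above and compares the rough growth of the activation scale when the weights instead have unit standard deviation, as they would if they started on the scale of the N(0, 1) prior. It uses the usual He-initialisation argument (a conv + ReLU layer with fan-in n scales the variance by about n * Var(w) / 2) and ignores biases, padding, the first layer's non-ReLU input and the skip connections, so it is a back-of-the-envelope estimate, not a prediction of the exact values in the error.

```python
import math


def default_conv3d_std(in_channels, kernel_size):
    """Std of PyTorch's default Conv3d weight init, U(-1/sqrt(fan_in), 1/sqrt(fan_in))."""
    fan_in = in_channels * kernel_size ** 3
    return 1.0 / math.sqrt(3.0 * fan_in)


def relu_variance_gain(in_channels, kernel_size, weight_std):
    """Rough factor by which a conv + ReLU layer scales the activation variance: n * Var(w) / 2."""
    fan_in = in_channels * kernel_size ** 3
    return fan_in * weight_std ** 2 / 2.0


# (in_channels, kernel_size) of the encoder convolutions used in UNet.forward above
encoder = [(3, 3), (64, 3), (64, 3), (128, 3), (128, 3), (256, 3), (256, 3), (512, 3)]

log10_growth_unit = 0.0     # weights with std 1, on the scale of the N(0, 1) prior
log10_growth_default = 0.0  # weights at the deterministic default init
for c_in, k in encoder:
    std = default_conv3d_std(c_in, k)
    # std grows by sqrt(variance gain), hence the factor 0.5 on the log
    log10_growth_unit += 0.5 * math.log10(relu_variance_gain(c_in, k, 1.0))
    log10_growth_default += 0.5 * math.log10(relu_variance_gain(c_in, k, std))
    print(f"in_channels={c_in:4d}  default init std={std:.4f}")

print(f"activation std growth over the encoder, std-1 weights: ~1e{log10_growth_unit:.0f}")
print(f"activation std growth over the encoder, default init:  ~1e{log10_growth_default:.0f}")
```

The per-layer default standard deviations printed here are one way to pick an init_scale of the right order; with unit-scale weights, the estimate shows the activation scale growing by many orders of magnitude before the decoder is even reached.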