IntelLabs / bayesian-torch

A library for Bayesian neural network layers and uncertainty estimation in Deep Learning extending the core of PyTorch
BSD 3-Clause "New" or "Revised" License

KL divergence not changing during training #19

Closed · gkwt closed this 1 year ago

gkwt commented 1 year ago

Hello,

I am trying to make a single layer BNN using the LinearReparameterization layer. I am unable to get it to give reasonable uncertainty estimates, so I started monitoring the KL term from the layers and noticed that it is not changing at all for each epoch. Even when I scale up the KL term in the loss, it remains unchanged.

I am not sure if this is a bug, or if I am not doing the training correctly.

My model

import torch
import torch.nn as nn
import torch.nn.functional as F

from bayesian_torch.layers import LinearReparameterization


class BNN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim=1):
        super().__init__()
        # Bayesian first layer; its forward pass returns (output, kl)
        self.layer1 = LinearReparameterization(input_dim, hidden_dim)
        # deterministic output layer
        self.layerf = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        kl_sum = 0
        x, kl = self.layer1(x)
        kl_sum += kl
        x = F.relu(x)
        x = self.layerf(x)
        return x, kl_sum

and my training loop

model = BNN(X_train.shape[-1], 100).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.MSELoss()

for epoch in pbar:
    running_kld_loss = 0
    running_mse_loss = 0
    running_loss = 0
    for datapoints, labels in dataloader_train:
        optimizer.zero_grad()

        output, kl = model(datapoints)
        kl = get_kl_loss(model)

        # calculate loss with kl term for Bayesian layers
        mse_loss = criterion(output, labels)
        loss = mse_loss + kl * kld_beta / batch_size

        loss.backward()
        optimizer.step()

        running_mse_loss += mse_loss.detach().numpy()
        running_kld_loss += kl.detach().numpy()
        running_loss += loss.detach().numpy()

    status.update({
        'Epoch': epoch,
        'loss': running_loss / len(dataloader_train),
        'kl': running_kld_loss / len(dataloader_train),
        'mse': running_mse_loss / len(dataloader_train)
    })

When I print the KL loss, it starts at ~5.0 and does not decrease at all.
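
A quick sanity check for this symptom is to confirm that the variational parameters actually receive gradients after a backward pass. The sketch below reuses model, criterion, and one (datapoints, labels) batch from the snippets above; none of these names are part of the bayesian-torch API, and the loss weighting is only illustrative.

# Diagnostic sketch: check that the Bayesian layer's parameters get gradients.
output, kl = model(datapoints)
loss = criterion(output, labels) + kl / len(datapoints)
loss.backward()

for name, param in model.named_parameters():
    grad_norm = None if param.grad is None else param.grad.norm().item()
    print(f"{name}: grad norm = {grad_norm}")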

ranganathkrishnan commented 1 year ago

Hi @gkwt,

You already seem to be getting the KL value from the model. Can you try commenting out the get_kl_loss call, as below?

      output, kl = model(datapoints)
      #kl = get_kl_loss(model)
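
In context, the inner loop with that change would look like the sketch below; kld_beta and batch_size are the hyperparameters from the original snippet, and the only point is to use the KL returned by the model's forward pass directly.

for datapoints, labels in dataloader_train:
    optimizer.zero_grad()

    # use the KL returned by the model's forward pass
    output, kl = model(datapoints)

    mse_loss = criterion(output, labels)
    loss = mse_loss + kl * kld_beta / batch_size

    loss.backward()
    optimizer.step()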
gkwt commented 1 year ago

The problem persists even without the get_kl_loss function; the values are the same as before, and backpropagation still does not change the KL value.

gkwt commented 1 year ago

I have also tried this with LinearFlipout; the KL still seems unaffected by the optimizer. After initializing the model, I also added

for param in model.parameters():
    param.requires_grad = True

to unfreeze the layers, but it had no effect on training; the KL remains constant.
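
For reference, trying LinearFlipout only requires changing the layer constructor; the sketch below assumes it returns the same (output, kl) tuple as LinearReparameterization, and the class name BNNFlipout is purely illustrative.

import torch.nn as nn
import torch.nn.functional as F

from bayesian_torch.layers import LinearFlipout


class BNNFlipout(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim=1):
        super().__init__()
        # Flipout-based Bayesian layer; forward also returns (output, kl)
        self.layer1 = LinearFlipout(input_dim, hidden_dim)
        self.layerf = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x, kl = self.layer1(x)
        x = F.relu(x)
        x = self.layerf(x)
        return x, kl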

gkwt commented 1 year ago

Hi @ranganathkrishnan,

There was a bug in my training loop. I was overwriting the model, and so the KL was not changing. Sorry for the confusion!
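
Since the exact bug is not shown, the sketch below only illustrates the instantiate-once pattern that avoids this failure mode: the model and optimizer are created a single time before the epoch loop, so the optimizer keeps updating the same parameters whose KL is being monitored. num_epochs, kld_beta, and batch_size are placeholders carried over from the earlier snippets.

import torch

# Build the model and optimizer once, outside the training loop.
model = BNN(X_train.shape[-1], 100).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.MSELoss()

for epoch in range(num_epochs):
    for datapoints, labels in dataloader_train:
        optimizer.zero_grad()
        output, kl = model(datapoints)
        loss = criterion(output, labels) + kl * kld_beta / batch_size
        loss.backward()
        optimizer.step()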