piEsposito / blitz-bayesian-deep-learning

A simple and extensible library to create Bayesian Neural Network layers on PyTorch.
GNU General Public License v3.0

High loss when using bayesian lstm instead of standard lstm #100

Open amroghoneim opened 2 years ago

amroghoneim commented 2 years ago

I am trying to implement a model using the Bayesian LSTM layer. I already have a model that relies on a standard LSTM and gets good results on a classification task. When I use the Bayesian layer, the loss becomes very high and the accuracy doesn't converge much. I tried changing the model's hyperparameters (especially the prior variables and posterior_rho), but that didn't help much. I also added sharpen=True for loss sharpening, but nothing changed.

The model:

# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
##### Bayesian version #####
from layers.lstm_bayesian_layer import BayesianLSTM
from blitz.utils import variational_estimator
from layers.linear_bayesian_layer import BayesianLinear

from layers.attention import Attention, NoQueryAttention
from layers.squeeze_embedding import SqueezeEmbedding

@variational_estimator
class LSTM_BAYES_RNN(nn.Module):
    def __init__(self, embedding_matrix, opt):
        super(LSTM_BAYES_RNN, self).__init__()
        self.lstm = BayesianLSTM(opt.embed_dim*2, opt.hidden_dim, bias=True, freeze = False,
                prior_sigma_1 = 5,
                prior_sigma_2 = 5,
                posterior_rho_init=1,
                sharpen=True)
                #  prior_pi = 1,
                #  posterior_mu_init = 0,
                #  posterior_rho_init = -6.0,
        self.opt = opt
        self.embed = nn.Embedding.from_pretrained(torch.tensor(embedding_matrix, dtype=torch.float))
        self.squeeze_embedding = SqueezeEmbedding()
        # self.dense = BayesianLinear(opt.hidden_dim, opt.polarities_dim, bias=True, freeze=False,
        #                             prior_sigma_1=10, prior_sigma_2=10, posterior_rho_init=5)
        self.attention = NoQueryAttention(opt.hidden_dim+opt.embed_dim, score_function='bi_linear')

        self.dense = nn.Linear(opt.hidden_dim, opt.polarities_dim)

    def forward(self, inputs):
        text_indices, aspect_indices = inputs[0], inputs[1]
        x_len = torch.sum(text_indices != 0, dim=-1)
        x_len_max = torch.max(x_len)
        aspect_len = torch.sum(aspect_indices != 0, dim=-1).float()

        x = self.embed(text_indices)
        x = self.squeeze_embedding(x, x_len)
        aspect = self.embed(aspect_indices)
        aspect_pool = torch.div(torch.sum(aspect, dim=1), aspect_len.unsqueeze(1))  # average aspect embeddings over the aspect length
        aspect = aspect_pool.unsqueeze(1).expand(-1, x_len_max, -1)  # broadcast the pooled aspect over the sequence
        x = torch.cat((aspect, x), dim=-1)  # concatenate -> opt.embed_dim * 2 features per timestep

        h, (_, _) = self.lstm(x)  # Bayesian LSTM hidden states for each timestep
        ha = torch.cat((h, aspect), dim=-1)
        _, score = self.attention(ha)
        output = torch.squeeze(torch.bmm(score, h), dim=1)  # attention-weighted sum of hidden states
        out = self.dense(output)
        return out

In the training loop I have:

                # bayesian loss calculation 
                pi_weight = minibatch_weight(batch_idx=i_batch, num_batches=self.opt.batch_size)

                loss = self.model.sample_elbo(
                        inputs=inputs,
                        labels=targets,
                        criterion=nn.CrossEntropyLoss(),
                        sample_nbr=10,
                        # complexity_cost_weight=1/len(self.trainset))
                        complexity_cost_weight = pi_weight)

                ##################

                loss.backward()
                optimizer.step()

                # take 3 outputs per example
                outputs = torch.stack([self.model(inputs) for i in range(3)])
                preds = torch.mean(outputs, axis=0)

What's the problem here?

amroghoneim commented 2 years ago

This is what the loss looks like inside the sample_elbo method, over multiple samples and many epochs:

PERFORMANCE LOSS: 1.3859792947769165 PERFORMANCE + KL LOSS: 446.3642578125 PERFORMANCE LOSS: 447.8734436035156 PERFORMANCE + KL LOSS: 892.888916015625 PERFORMANCE LOSS: 894.4247436523438 PERFORMANCE + KL LOSS: 1340.0321044921875 PERFORMANCE LOSS: 1341.6724853515625 PERFORMANCE + KL LOSS: 1786.862548828125 PERFORMANCE LOSS: 1788.3409423828125 PERFORMANCE + KL LOSS: 2233.776611328125 PERFORMANCE LOSS: 2235.102783203125 PERFORMANCE + KL LOSS: 2680.375244140625 PERFORMANCE LOSS: 2681.85400390625 PERFORMANCE + KL LOSS: 3127.18310546875 PERFORMANCE LOSS: 3128.93798828125 PERFORMANCE + KL LOSS: 3574.3134765625 PERFORMANCE LOSS: 3576.058349609375 PERFORMANCE + KL LOSS: 4021.171630859375 PERFORMANCE LOSS: 4023.035888671875 PERFORMANCE + KL LOSS: 4468.0283203125 PERFORMANCE LOSS: 1.2040766477584839 PERFORMANCE + KL LOSS: 446.34429931640625 PERFORMANCE LOSS: 447.7460632324219 PERFORMANCE + KL LOSS: 892.8301391601562 PERFORMANCE LOSS: 894.1790161132812 PERFORMANCE + KL LOSS: 1339.896240234375 PERFORMANCE LOSS: 1341.173583984375 PERFORMANCE + KL LOSS: 1786.486083984375 PERFORMANCE LOSS: 1787.7799072265625 PERFORMANCE + KL LOSS: 2232.6181640625 PERFORMANCE LOSS: 2234.0400390625 PERFORMANCE + KL LOSS: 2679.76123046875 PERFORMANCE LOSS: 2680.916259765625 PERFORMANCE + KL LOSS: 3126.09130859375 PERFORMANCE LOSS: 3127.1962890625 PERFORMANCE + KL LOSS: 3572.398681640625 PERFORMANCE LOSS: 3573.755126953125 PERFORMANCE + KL LOSS: 4019.00634765625 PERFORMANCE LOSS: 4020.251708984375 PERFORMANCE + KL LOSS: 4465.36572265625 PERFORMANCE LOSS: 1.2470617294311523 PERFORMANCE + KL LOSS: 446.4764099121094 PERFORMANCE LOSS: 447.49908447265625 PERFORMANCE + KL LOSS: 892.6083984375 PERFORMANCE LOSS: 893.86865234375 PERFORMANCE + KL LOSS: 1338.82666015625 PERFORMANCE LOSS: 1340.1925048828125 PERFORMANCE + KL LOSS: 1785.356201171875 PERFORMANCE LOSS: 1786.67041015625 PERFORMANCE + KL LOSS: 2231.896484375 PERFORMANCE LOSS: 2233.141845703125 PERFORMANCE + KL LOSS: 2678.32861328125 PERFORMANCE LOSS: 2679.552734375 PERFORMANCE + KL LOSS: 3124.926513671875 PERFORMANCE LOSS: 3126.08837890625 PERFORMANCE + KL LOSS: 3571.53564453125 PERFORMANCE LOSS: 3572.876953125 PERFORMANCE + KL LOSS: 4017.953857421875 PERFORMANCE LOSS: 4019.236572265625 PERFORMANCE + KL LOSS: 4464.65234375 PERFORMANCE LOSS: 1.2529895305633545 PERFORMANCE + KL LOSS: 446.598876953125 PERFORMANCE LOSS: 447.7634582519531 PERFORMANCE + KL LOSS: 893.0445556640625 PERFORMANCE LOSS: 894.239990234375 PERFORMANCE + KL LOSS: 1339.1597900390625 PERFORMANCE LOSS: 1340.3812255859375 PERFORMANCE + KL LOSS: 1785.3369140625 PERFORMANCE LOSS: 1786.525634765625 PERFORMANCE + KL LOSS: 2231.73193359375 PERFORMANCE LOSS: 2232.955078125 PERFORMANCE + KL LOSS: 2678.22802734375 PERFORMANCE LOSS: 2679.33935546875 PERFORMANCE + KL LOSS: 3124.440185546875 PERFORMANCE LOSS: 3125.568359375 PERFORMANCE + KL LOSS: 3570.59130859375 PERFORMANCE LOSS: 3571.839111328125 PERFORMANCE + KL LOSS: 4017.313720703125 PERFORMANCE LOSS: 4018.48046875 PERFORMANCE + KL LOSS: 4463.61572265625 PERFORMANCE LOSS: 1.3086638450622559 PERFORMANCE + KL LOSS: 446.3980712890625 PERFORMANCE LOSS: 447.6965026855469 PERFORMANCE + KL LOSS: 893.017578125 PERFORMANCE LOSS: 894.0802612304688 PERFORMANCE + KL LOSS: 1339.434814453125 PERFORMANCE LOSS: 1340.6083984375 PERFORMANCE + KL LOSS: 1786.0252685546875 PERFORMANCE LOSS: 1787.18115234375 PERFORMANCE + KL LOSS: 2232.149169921875 PERFORMANCE LOSS: 2233.2109375 PERFORMANCE + KL LOSS: 2678.623779296875 PERFORMANCE LOSS: 2679.775390625 PERFORMANCE + KL LOSS: 3125.24609375 
PERFORMANCE LOSS: 3126.44775390625 PERFORMANCE + KL LOSS: 3571.93017578125 PERFORMANCE LOSS: 3573.00390625 PERFORMANCE + KL LOSS: 4017.955078125 PERFORMANCE LOSS: 4019.065185546875 PERFORMANCE + KL LOSS: 4464.07568359375 PERFORMANCE LOSS: 0.9791557788848877 PERFORMANCE + KL LOSS: 445.9003601074219 PERFORMANCE LOSS: 446.8746032714844 PERFORMANCE + KL LOSS: 892.022705078125 PERFORMANCE LOSS: 892.9429321289062 PERFORMANCE + KL LOSS: 1338.173095703125 PERFORMANCE LOSS: 1339.390380859375 PERFORMANCE + KL LOSS: 1784.44775390625 PERFORMANCE LOSS: 1785.507568359375 PERFORMANCE + KL LOSS: 2230.6630859375 PERFORMANCE LOSS: 2231.741455078125 PERFORMANCE + KL LOSS: 2676.9111328125 PERFORMANCE LOSS: 2677.945556640625 PERFORMANCE + KL LOSS: 3122.991943359375 PERFORMANCE LOSS: 3123.878173828125 PERFORMANCE + KL LOSS: 3569.328857421875 PERFORMANCE LOSS: 3570.427978515625 PERFORMANCE + KL LOSS: 4015.494873046875 PERFORMANCE LOSS: 4016.546142578125 PERFORMANCE + KL LOSS: 4461.78369140625 PERFORMANCE LOSS: 1.034562349319458 PERFORMANCE + KL LOSS: 446.2547912597656 PERFORMANCE LOSS: 447.161865234375 PERFORMANCE + KL LOSS: 892.19677734375 PERFORMANCE LOSS: 893.2605590820312 PERFORMANCE + KL LOSS: 1338.6982421875 PERFORMANCE LOSS: 1339.7855224609375 PERFORMANCE + KL LOSS: 1785.085205078125 PERFORMANCE LOSS: 1786.3265380859375 PERFORMANCE + KL LOSS: 2231.59912109375 PERFORMANCE LOSS: 2232.564453125 PERFORMANCE + KL LOSS: 2677.362548828125 PERFORMANCE LOSS: 2678.36328125 PERFORMANCE + KL LOSS: 3123.44775390625 PERFORMANCE LOSS: 3124.4931640625 PERFORMANCE + KL LOSS: 3569.76220703125 PERFORMANCE LOSS: 3570.87060546875 PERFORMANCE + KL LOSS: 4015.784912109375 PERFORMANCE LOSS: 4016.75537109375 PERFORMANCE + KL LOSS: 4461.736328125 PERFORMANCE LOSS: 1.1334476470947266 PERFORMANCE + KL LOSS: 446.2291259765625 PERFORMANCE LOSS: 447.3734436035156 PERFORMANCE + KL LOSS: 892.83544921875 PERFORMANCE LOSS: 893.9830322265625 PERFORMANCE + KL LOSS: 1339.30078125 PERFORMANCE LOSS: 1340.4144287109375 PERFORMANCE + KL LOSS: 1785.761962890625 PERFORMANCE LOSS: 1787.0693359375 PERFORMANCE + KL LOSS: 2232.11376953125 PERFORMANCE LOSS: 2233.1796875 PERFORMANCE + KL LOSS: 2678.37451171875 PERFORMANCE LOSS: 2679.556396484375 PERFORMANCE + KL LOSS: 3124.7421875 PERFORMANCE LOSS: 3125.9501953125 PERFORMANCE + KL LOSS: 3571.496826171875 PERFORMANCE LOSS: 3572.64404296875 PERFORMANCE + KL LOSS: 4017.846923828125 PERFORMANCE LOSS: 4018.96044921875 PERFORMANCE + KL LOSS: 4464.28173828125 PERFORMANCE LOSS: 1.0037894248962402 PERFORMANCE + KL LOSS: 445.89324951171875 PERFORMANCE LOSS: 446.8221435546875 PERFORMANCE + KL LOSS: 891.6876220703125 PERFORMANCE LOSS: 892.6782836914062 PERFORMANCE + KL LOSS: 1337.7032470703125 PERFORMANCE LOSS: 1338.6993408203125 PERFORMANCE + KL LOSS: 1783.60986328125 PERFORMANCE LOSS: 1784.5438232421875 PERFORMANCE + KL LOSS: 2229.855712890625 PERFORMANCE LOSS: 2230.88720703125 PERFORMANCE + KL LOSS: 2675.54638671875 PERFORMANCE LOSS: 2676.505615234375 PERFORMANCE + KL LOSS: 3121.80419921875 PERFORMANCE LOSS: 3122.758056640625 PERFORMANCE + KL LOSS: 3567.829345703125 PERFORMANCE LOSS: 3568.846923828125 PERFORMANCE + KL LOSS: 4013.456298828125 PERFORMANCE LOSS: 4014.486328125 PERFORMANCE + KL LOSS: 4459.689453125 PERFORMANCE LOSS: 1.1406904458999634 PERFORMANCE + KL LOSS: 446.54901123046875 PERFORMANCE LOSS: 447.63385009765625 PERFORMANCE + KL LOSS: 892.46337890625 PERFORMANCE LOSS: 893.5294189453125 PERFORMANCE + KL LOSS: 1338.5889892578125 PERFORMANCE LOSS: 1339.5230712890625 PERFORMANCE + KL LOSS: 
1784.5792236328125 PERFORMANCE LOSS: 1785.79296875 PERFORMANCE + KL LOSS: 2230.992919921875 PERFORMANCE LOSS: 2232.022216796875 PERFORMANCE + KL LOSS: 2677.153076171875 PERFORMANCE LOSS: 2678.14990234375 PERFORMANCE + KL LOSS: 3122.89013671875 PERFORMANCE LOSS: 3123.91357421875 PERFORMANCE + KL LOSS: 3569.23095703125 PERFORMANCE LOSS: 3570.322021484375 PERFORMANCE + KL LOSS: 4015.341796875 PERFORMANCE LOSS: 4016.3662109375 PERFORMANCE + KL LOSS: 4461.22021484375 loss: 446.3445, acc: 0.4250 PERFORMANCE LOSS: 1.185144066810608 PERFORMANCE + KL LOSS: 446.47686767578125 PERFORMANCE LOSS: 447.7961730957031 PERFORMANCE + KL LOSS: 893.0511474609375 PERFORMANCE LOSS: 894.4295043945312 PERFORMANCE + KL LOSS: 1339.2666015625 PERFORMANCE LOSS: 1340.646240234375 PERFORMANCE + KL LOSS: 1785.5106201171875 PERFORMANCE LOSS: 1786.9635009765625 PERFORMANCE + KL LOSS: 2231.67919921875 PERFORMANCE LOSS: 2233.009521484375 PERFORMANCE + KL LOSS: 2678.06396484375 PERFORMANCE LOSS: 2679.30859375 PERFORMANCE + KL LOSS: 3124.505859375 PERFORMANCE LOSS: 3125.79931640625 PERFORMANCE + KL LOSS: 3570.76025390625 PERFORMANCE LOSS: 3572.018310546875 PERFORMANCE + KL LOSS: 4017.48828125 PERFORMANCE LOSS: 4018.740966796875 PERFORMANCE + KL LOSS: 4463.671875 PERFORMANCE LOSS: 1.2589664459228516 PERFORMANCE + KL LOSS: 446.436767578125 PERFORMANCE LOSS: 447.7485046386719 PERFORMANCE + KL LOSS: 892.6570434570312 PERFORMANCE LOSS: 893.7003784179688 PERFORMANCE + KL LOSS: 1338.519287109375 PERFORMANCE LOSS: 1339.8040771484375 PERFORMANCE + KL LOSS: 1784.7545166015625 PERFORMANCE LOSS: 1785.898193359375 PERFORMANCE + KL LOSS: 2231.010498046875 PERFORMANCE LOSS: 2232.1064453125 PERFORMANCE + KL LOSS: 2677.271240234375 PERFORMANCE LOSS: 2678.56689453125 PERFORMANCE + KL LOSS: 3123.6142578125 PERFORMANCE LOSS: 3124.657470703125 PERFORMANCE + KL LOSS: 3569.591064453125 PERFORMANCE LOSS: 3570.9541015625 PERFORMANCE + KL LOSS: 4016.14990234375 PERFORMANCE LOSS: 4017.164306640625 PERFORMANCE + KL LOSS: 4462.57470703125 PERFORMANCE LOSS: 0.9719462990760803 PERFORMANCE + KL LOSS: 445.69403076171875 PERFORMANCE LOSS: 446.7107849121094 PERFORMANCE + KL LOSS: 891.2470092773438 PERFORMANCE LOSS: 892.1383056640625 PERFORMANCE + KL LOSS: 1337.050537109375 PERFORMANCE LOSS: 1337.9063720703125 PERFORMANCE + KL LOSS: 1782.8287353515625 PERFORMANCE LOSS: 1783.8055419921875 PERFORMANCE + KL LOSS: 2228.96337890625 PERFORMANCE LOSS: 2229.83935546875 PERFORMANCE + KL LOSS: 2674.52294921875 PERFORMANCE LOSS: 2675.642333984375 PERFORMANCE + KL LOSS: 3120.78564453125 PERFORMANCE LOSS: 3121.7060546875 PERFORMANCE + KL LOSS: 3566.986328125 PERFORMANCE LOSS: 3567.8515625 PERFORMANCE + KL LOSS: 4012.77001953125 PERFORMANCE LOSS: 4013.5654296875 PERFORMANCE + KL LOSS: 4458.6123046875 PERFORMANCE LOSS: 0.998976469039917 PERFORMANCE + KL LOSS: 446.0976867675781 PERFORMANCE LOSS: 447.0158386230469 PERFORMANCE + KL LOSS: 891.9049072265625 PERFORMANCE LOSS: 892.7996215820312 PERFORMANCE + KL LOSS: 1337.7274169921875 PERFORMANCE LOSS: 1338.74462890625 PERFORMANCE + KL LOSS: 1783.6373291015625 PERFORMANCE LOSS: 1784.5079345703125 PERFORMANCE + KL LOSS: 2229.2099609375 PERFORMANCE LOSS: 2230.24609375 PERFORMANCE + KL LOSS: 2675.1591796875 PERFORMANCE LOSS: 2676.1318359375 PERFORMANCE + KL LOSS: 3121.31396484375 PERFORMANCE LOSS: 3122.311279296875 PERFORMANCE + KL LOSS: 3567.53173828125 PERFORMANCE LOSS: 3568.599609375 PERFORMANCE + KL LOSS: 4013.51220703125 PERFORMANCE LOSS: 4014.64306640625 PERFORMANCE + KL LOSS: 4459.9609375 PERFORMANCE LOSS: 0.9898296594619751 
PERFORMANCE + KL LOSS: 446.1213073730469 PERFORMANCE LOSS: 447.1927185058594 PERFORMANCE + KL LOSS: 891.9871215820312 PERFORMANCE LOSS: 893.0206909179688 PERFORMANCE + KL LOSS: 1338.072998046875 PERFORMANCE LOSS: 1338.9912109375 PERFORMANCE + KL LOSS: 1784.236572265625 PERFORMANCE LOSS: 1785.2718505859375 PERFORMANCE + KL LOSS: 2229.968994140625 PERFORMANCE LOSS: 2230.87890625 PERFORMANCE + KL LOSS: 2676.033203125 PERFORMANCE LOSS: 2677.09619140625 PERFORMANCE + KL LOSS: 3121.7255859375 PERFORMANCE LOSS: 3122.779296875 PERFORMANCE + KL LOSS: 3567.864501953125 PERFORMANCE LOSS: 3568.97705078125 PERFORMANCE + KL LOSS: 4013.9189453125 PERFORMANCE LOSS: 4014.874267578125 PERFORMANCE + KL LOSS: 4460.056640625
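For reference, the cumulative pattern in the log matches the way sample_elbo accumulates the loss. Below is a rough sketch of what blitz's variational_estimator does there, assuming the prints above were added inside its sampling loop (paraphrased, not the exact library source):

def sample_elbo_sketch(model, inputs, labels, criterion, sample_nbr, complexity_cost_weight=1.0):
    # Sum fit loss + weighted KL over sample_nbr stochastic forward passes
    # (each pass re-samples the variational weights), then average.
    loss = 0
    for _ in range(sample_nbr):
        outputs = model(inputs)
        loss += criterion(outputs, labels)                          # running "PERFORMANCE LOSS"
        loss += model.nn_kl_divergence() * complexity_cost_weight   # running "PERFORMANCE + KL LOSS"
    return loss / sample_nbr

With sample_nbr=10, the final running total of roughly 4463 divided by 10 matches the reported loss of about 446, so the KL term of around 445 per sample is dominating the fit term of around 1.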

Philippe-Drolet commented 2 years ago

Having the same issue here...

swjtufjs commented 2 years ago

Hello, did you ever solve this? I had the same problem.

0wenwu commented 1 year ago

I have the same problem. Also, how do you output the uncertainty of the prediction result?

piEsposito commented 1 year ago

Hey, I'm sorry for the delay. Will try to take a look at it this week.

Maybe reducing the KL Divergence weight on the loss could help.
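For example, a minimal sketch against the training snippet above, scaling the complexity cost down (kl_scale and its value are only illustrative, not a blitz parameter):

kl_scale = 0.01  # hypothetical down-weighting factor for the KL term
loss = self.model.sample_elbo(
        inputs=inputs,
        labels=targets,
        criterion=nn.CrossEntropyLoss(),
        sample_nbr=10,
        complexity_cost_weight=pi_weight * kl_scale)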

@0wenwu to output the uncertainty, run multiple forward passes and check the variance; you can assume the predictions are normally distributed.
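Concretely, something like the sketch below, reusing the names from the training snippet above (n_samples and the softmax step are illustrative, not prescribed by the library):

n_samples = 20  # number of stochastic forward passes (illustrative)
probs = torch.stack([torch.softmax(self.model(inputs), dim=-1) for _ in range(n_samples)])
mean_probs = probs.mean(dim=0)   # predictive mean per class
std_probs = probs.std(dim=0)     # per-class spread across passes ~ predictive uncertainty
preds = mean_probs.argmax(dim=-1)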

0wenwu commented 1 year ago


Hey @piEsposito, I have solved the problem, thank you for your excellent work. What we need to do is read the stocks-blstm.ipynb notebook carefully. What do you think about the effect of sample_nbr on the prediction loss?