tunz / transformer-pytorch

Transformer implementation in PyTorch.
https://tunz.kr/post/4
MIT License

Normalized Cross Entropy and Label Smoothing #5

Closed: tonyhqanguyen closed this issue 5 years ago

tonyhqanguyen commented 5 years ago

Hi, thank you for sharing the implementation!

I was just wondering if you could explain the loss computation where you use confidence and label smoothing. I know this is also done in the tensor2tensor repo, but I have a hard time reading and understanding the concept from that repo as well. I was reading up on Normalized Cross Entropy here, but it seems like neither the formula you use here nor the one in tensor2tensor really matches the formula discussed in that article. Could you elaborate on the implementation of that formula?

Also, since we're taking into account both the correct and incorrect labels, isn't this different from normal cross entropy? With standard cross entropy, where the true label has probability 1 and every other label has probability 0, the loss depends only on the predicted probability of the true label; it doesn't matter how the remaining probability mass is distributed over the incorrect labels. With the NCE formula they provide, however, we also have to take the incorrect labels into account, right?
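To illustrate what I mean, here's a tiny toy example of my own (not from the repo), just contrasting a one-hot target with a smoothed one:

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, 0.1]])
log_prob = F.log_softmax(logits, dim=1)

# Plain one-hot target: only the true class's log-probability enters the loss.
one_hot = torch.tensor([[1.0, 0.0, 0.0]])
print(-(one_hot * log_prob).sum(dim=1))

# Smoothed target: the incorrect classes' log-probabilities contribute as well.
smoothed = torch.tensor([[0.9, 0.05, 0.05]])
print(-(smoothed * log_prob).sum(dim=1))

For reference, here is the get_loss implementation from the repo that I'm asking about: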

import math

import torch
import torch.nn.functional as F


def get_loss(pred, ans, vocab_size, label_smoothing, pad):
    # pred: (batch_size * seq_len, vocab_size) raw logits,
    # ans: (batch_size * seq_len,) target token ids.
    # Took this "normalizing" constant from tensor2tensor. We subtract it for
    # readability. This makes no difference to learning.
    confidence = 1.0 - label_smoothing
    low_confidence = (1.0 - confidence) / float(vocab_size - 1)
    normalizing = -(
        confidence * math.log(confidence) + float(vocab_size - 1) *
        low_confidence * math.log(low_confidence + 1e-20))

    # Smoothed target distribution: `confidence` on the answer token,
    # `low_confidence` spread over the rest of the vocabulary.
    one_hot = torch.zeros_like(pred).scatter_(1, ans.unsqueeze(1), 1)
    one_hot = one_hot * confidence + (1 - one_hot) * low_confidence
    log_prob = F.log_softmax(pred, dim=1)

    # Cross entropy against the smoothed targets, ignoring padded positions.
    xent = -(one_hot * log_prob).sum(dim=1)
    xent = xent.masked_select(ans != pad)
    loss = (xent - normalizing).mean()
    return loss
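And here is a minimal sketch of how I'm calling it (the shapes and toy values here are my own assumptions, just to show how I understand the interface):

batch_size, seq_len, vocab_size, pad_id = 2, 5, 100, 0  # toy values, not my real config
pred = torch.randn(batch_size * seq_len, vocab_size)         # raw logits from the model
ans = torch.randint(1, vocab_size, (batch_size * seq_len,))  # target token ids (no pads here)
loss = get_loss(pred, ans, vocab_size, label_smoothing=0.1, pad=pad_id)
print(loss.item())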

I'm training a chatbot with around 150k words in my vocabulary. In the first few iterations each log_softmax entry is about -11, so the sum over each sentence position is around 3000 when I do xent = -(one_hot * log_prob).sum(dim=1), and the average loss is around 3000 when I take the mean over all predictions. Does this sound reasonable? A loss of 3000 seems to be through the roof.

Thanks in advance.

tunz commented 5 years ago

Hi! Since I just reimplemented the algorithm in PyTorch, it would be better to ask the tensor2tensor authors if you want a definitive answer. But let me try to explain what I understood at the time of writing the code.

First, this is not "normalized cross entropy". The tensor2tensor code names this cross entropy function smoothing_cross_entropy. I also skimmed the NCE paper, and it does not look related to that paper; the code just uses a "normalizing" constant for readability.

AFAIK, label smoothing comes from the intuition that training data might have wrong labels. Large training datasets usually contain quite a lot of misclassified examples. So it gives some small confidence value even to the incorrect labels, so that the model does not ignore the actual label of mislabeled examples in the training dataset.

Then, a small side effect is that the cross entropy loss value becomes larger. Even when the model is 100% accurate, the loss is not zero because of the label smoothing. So we just subtract the "normalizing" constant from the cross entropy value; then the loss approaches zero as the model becomes accurate. This does not affect back-propagation at all, it just makes it easier to tell while debugging whether the loss is stuck or converging toward an optimum.
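Here is a minimal numerical sketch of that point (my own illustration, not from the repo): the normalizing constant is exactly the cross entropy that a perfect model, one that outputs the smoothed distribution itself, would still pay against the smoothed targets, so subtracting it moves the loss floor to roughly zero.

import math

vocab_size = 10
label_smoothing = 0.1
confidence = 1.0 - label_smoothing
low_confidence = label_smoothing / float(vocab_size - 1)

# Cross entropy of the smoothed target against itself, i.e. its entropy.
# This matches the `normalizing` constant in get_loss (up to the 1e-20 epsilon).
xent_at_optimum = -(confidence * math.log(confidence)
                    + (vocab_size - 1) * low_confidence * math.log(low_confidence))

normalizing = xent_at_optimum
print(xent_at_optimum - normalizing)  # 0.0 -> the reported loss bottoms out near zero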

tonyhqanguyen commented 5 years ago

Ah I see. Thank you very much!

tonyhqanguyen commented 5 years ago

Hi, sorry to bother you again; I just want to make sure of a couple of things.

The argument pred should have shape (batch size * max sequence length, vocab size) and the argument ans should have shape (batch size * max sequence length,), right? And should pred be the raw logits computed by the Transformer model?

The problem I'm having is that when I compute log_softmax of the logits, the values are quite negative, around -11, so when I sum each row of log probabilities with .sum(dim=1) I get around 3000 for the first few iterations. You said we subtract the normalizing constant from the cross entropy value, but while debugging I see that the normalizing constant is tiny compared to the cross entropy: the constant is < 1, while the cross entropy is around 3000.

tunz commented 5 years ago

Yes, that's right.

I'm not sure what's happening there. It could just be normal, as long as the loss still converges. But one odd thing: the normalizing constant should be around 12 if label_smoothing is 0.1 and the vocab size is 150000.

label_smoothing=0.1
vocab_size=150000
confidence = 1.0 - label_smoothing
normalizing = -(confidence * math.log(confidence) + float(vocab_size - 1) * low_confidence * math.log(low_confidence + 1e-20))
# 12.01313558123328
tonyhqanguyen commented 5 years ago

label_smoothing=0.1
vocab_size=150000
confidence = 1.0 - label_smoothing
low_confidence = (1.0 - confidence) / float(vocab_size - 1)
normalizing = -(confidence * math.log(confidence) + float(vocab_size - 1) * low_confidence * math.log(low_confidence + 1e-20))

>>> normalizing
1.516921364030397

This is what I get ...

tunz commented 5 years ago

Ah, I made a mistake. You're right, that value is correct.

tonyhqanguyen commented 5 years ago

Ah, OK, thanks. I'm not sure why the loss goes through the roof (~3000) when I use label smoothing, but when I don't, it fluctuates at around 5. Let me know if you have any insight into what I could try.

tunz commented 5 years ago

What is the range (min/max) of your logit/softmax values across the vocabulary?

You said the loss is around 5 when label_smoothing is zero, which means

-1*log_softmax(answer_logit) == 5

and with label smoothing it's around 3000:

-(0.9*log_softmax(answer_logit) + sum((1/150000.0) * log_softmax(logit) for logit in other_logits)) == 3000

then,

sum((1/150000.0) * log_softmax(logit) for logit in other_logits) == -3000 + 0.9*5

If I assume that all logit values have the same value,

log_softmax(logit) == -3000 + 0.9*5

But if all the logits had the same value, each softmax value would be around 1/150000, so log_softmax(logit) would have to be around -12. That doesn't add up in this case.

So my guess is that your loss is close to 3000 because some of your logit values are much smaller than the others. Try changing the initialization of the embedding layer and see how it goes. I still think this high loss may not be a big problem if it converges, or you can also reduce the label smoothing constant.
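To make that sanity check concrete, here is a rough numerical sketch (my own illustration, not from the repo):

import torch
import torch.nn.functional as F

vocab_size = 150000
confidence = 0.9
low_confidence = 0.1 / (vocab_size - 1)

# Case 1: all logits equal. log_softmax is about -log(vocab_size) ~= -11.9 everywhere,
# so the smoothed cross entropy is about 11.9, nowhere near 3000.
uniform = torch.zeros(1, vocab_size)
lp = F.log_softmax(uniform, dim=1)
print(-(confidence * lp[0, 0] + low_confidence * lp[0, 1:].sum()).item())  # ~11.9

# Case 2: a large chunk of logits far below the rest (e.g. badly initialized embeddings).
# Their log_softmax values are hugely negative, and even weighted by low_confidence
# they inflate the smoothed loss well past the ~12 baseline.
skewed = torch.zeros(1, vocab_size)
skewed[0, 1:75000] = -5000.0
lp = F.log_softmax(skewed, dim=1)
print(-(confidence * lp[0, 0] + low_confidence * lp[0, 1:].sum()).item())  # a few hundred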

tonyhqanguyen commented 5 years ago

Hmm... I'm not sure what I just changed, but the loss seems pretty reasonable now; it's coming down to about 4.6, so hopefully it's improving. Thank you so much!