saahiluppal / catr

Image Captioning Using Transformer
Apache License 2.0

Consideration of padding? #12

Closed · ohwi closed this issue 3 years ago

ohwi commented 3 years ago

Hi! I studied your code and have some questions.

It seems like the pad token is contributing to the loss, too.

https://github.com/saahiluppal/catr/blob/8a1f7704a9926ff4e785c3f8bc05910cc6f0f397/models/caption.py#L55

In my opinion, the code above needs to be something like:

```python
criterion = torch.nn.CrossEntropyLoss(ignore_index=config.pad_token_id)
```

Otherwise, the model must learn to predict the [PAD] token, too.
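
To make the difference concrete, here is a minimal sketch with toy tensors (the `pad_token_id`, shapes, and values are made up for illustration, not taken from this repo's config):

```python
import torch

# Toy setup: vocabulary of 5 tokens, pad_token_id assumed to be 0.
pad_token_id = 0
logits = torch.randn(2, 7, 5)                     # (batch, seq_len, vocab)
targets = torch.tensor([[3, 1, 4, 2, 0, 0, 0],    # trailing 0s are [PAD]
                        [2, 2, 1, 0, 0, 0, 0]])

# Current behaviour: every position, including [PAD], contributes to the loss.
loss_all = torch.nn.CrossEntropyLoss()(logits.transpose(1, 2), targets)

# Suggested behaviour: [PAD] positions are excluded from the loss entirely.
loss_no_pad = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)(
    logits.transpose(1, 2), targets)

print(loss_all.item(), loss_no_pad.item())        # the two values differ
```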

Also, I wonder why you used FrozenBatchNorm. Was a batch size of 32 not sufficient for stable learning?

Thank you!!

saahiluppal commented 3 years ago

Hey,

Consider this example throughout:

```
token #N  :           0     1        2       3        4         5      6       7      8      9
predicted : [START]  [A]  [man]  [holding]  [a]  [cellphone]  [in]   [rain]  [END]  [PAD]  [PAD]
actual    : [START]  [A]  [man]  [holding]  [a]  [cellphone]  [END]  [PAD]   [PAD]  [PAD]  [PAD]
```

token #N refers to the token position (note that [START] itself is not numbered)

Yes, we do need to consider the pad token in the loss as well.

I guess your actual concern is why we didn't discard tokens #8 and #9, since there the predicted and actual tokens are the same, i.e. [PAD].

Well, yes. The softmax inside cross-entropy makes sure that even when the predicted and actual tokens are the same, some (very small) loss is still generated, and it adds up to the total loss. Consider this as noise we inject into the model in the hope of making it more robust. Think of it this way: tokens #0, #1, #2, #3, and #4 also have the same predicted and actual values, yet we never discard them; we always include them in the loss.
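
To see that a matching prediction still contributes loss, here is a tiny sketch (the logit values are made up):

```python
import torch
import torch.nn.functional as F

# Even a confidently correct prediction carries a small nonzero loss,
# because softmax never puts probability exactly 1 on a single class.
logits = torch.tensor([[8.0, 0.0, 0.0]])  # model strongly predicts class 0
target = torch.tensor([0])                # target is also class 0

loss = F.cross_entropy(logits, target)
print(loss.item())  # ~0.00067: small, but it still adds to the total loss
```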

And there's no specific reason to use FrozenBatchNorm, but it did give us faster training, and I personally don't like to use BatchNorm because it can behave badly at inference time.
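
For reference, this is roughly what a frozen batch norm layer looks like; a minimal sketch in the spirit of the DETR-style `FrozenBatchNorm2d` (not copied verbatim from this repo):

```python
import torch
from torch import nn

class FrozenBatchNorm2d(nn.Module):
    """BatchNorm2d with fixed statistics and affine parameters.

    Everything is registered as a buffer, so nothing is updated during
    training and the layer behaves identically in train and eval mode.
    """
    def __init__(self, num_features, eps=1e-5):
        super().__init__()
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))
        self.eps = eps

    def forward(self, x):
        # Fold the frozen stats into a per-channel scale and bias,
        # broadcast over (N, C, H, W).
        scale = self.weight * (self.running_var + self.eps).rsqrt()
        bias = self.bias - self.running_mean * scale
        return x * scale.view(1, -1, 1, 1) + bias.view(1, -1, 1, 1)
```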

ohwi commented 3 years ago

Hi. Thank you for your careful reply!

I understand your point.

After reading your comment, I looked into how Hugging Face's seq2seq models are trained; both ignoring and not ignoring the pad token are supported: link
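
For example, the "ignoring" variant is usually set up by replacing pad positions in `labels` with -100, which is PyTorch's default `ignore_index` (the tensor values below are made up):

```python
import torch

pad_token_id = 0
labels = torch.tensor([[101, 7, 42, 102, 0, 0]])

# Positions set to -100 are skipped by torch.nn.CrossEntropyLoss,
# whose default ignore_index is -100.
labels = labels.masked_fill(labels == pad_token_id, -100)
print(labels)  # the two [PAD] positions become -100
```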

Since your model works well, not ignoring the pad token may be the better choice for image captioning.