manipopopo opened 4 months ago
I am having the same issue here.
In my opinion, KL divergence should have the same effect as the cross entropy loss: since the target is detached in the code, the two losses differ only by the entropy of the target. However, replacing the cross entropy loss with KL divergence makes the model fail to converge.
The reason might be numerical issues in PyTorch, or, as mentioned, the misuse of nn.CrossEntropyLoss, or other factors...
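For reference, here is a small standalone check (not the repository's code; the tensor names are just illustrative) of the claim that, with a detached soft target, cross entropy and KL divergence differ only by the target's entropy, so their gradients with respect to the prediction coincide:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10, requires_grad=True)
target = F.softmax(torch.randn(4, 10), dim=-1).detach()  # fixed soft target

log_probs = F.log_softmax(logits, dim=-1)

# Cross entropy with a soft target: -sum(target * log_probs), averaged over the batch
ce = -(target * log_probs).sum(dim=-1).mean()

# KL divergence: sum(target * (log(target) - log_probs)), averaged over the batch
kl = F.kl_div(log_probs, target, reduction="batchmean")

# The two losses differ by the target entropy, which is constant w.r.t. the logits...
entropy = -(target * target.log()).sum(dim=-1).mean()
print(torch.allclose(ce, kl + entropy))  # True

# ...so their gradients are identical.
g_ce, = torch.autograd.grad(ce, logits, retain_graph=True)
g_kl, = torch.autograd.grad(kl, logits)
print(torch.allclose(g_ce, g_kl))  # True
```

So in exact arithmetic the two losses should indeed train identically; any difference in convergence would have to come from how the inputs are fed to the loss (e.g. the logits-vs-probabilities issue below) rather than from the choice between cross entropy and KL divergence itself.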
The ShiftCrossEntropy loss currently utilizes nn.CrossEntropyLoss as its backend, which expects the input to be unnormalized logits. It appears that ShiftCrossEntropy passes input probabilities and target probabilities to the backend instead. This might lead to a deviation from the expected behavior described in equation (7) of the paper.

https://github.com/SonyCSLParis/pesto-full/blob/229f78bd96986bdead3402331488a904b632f9cd/src/losses/entropy.py#L49
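A minimal sketch of the concern, independent of the repository's code and assuming soft probability targets (supported by nn.CrossEntropyLoss since PyTorch 1.10): the loss applies log_softmax to its input internally, so feeding it probabilities instead of logits effectively applies softmax twice and changes both the loss value and its gradients.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
criterion = torch.nn.CrossEntropyLoss()

logits = torch.randn(4, 10)
probs = F.softmax(logits, dim=-1)                   # already-normalized probabilities
target = F.softmax(torch.randn(4, 10), dim=-1)      # soft (probabilistic) target

loss_from_logits = criterion(logits, target)  # intended usage: raw, unnormalized logits
loss_from_probs = criterion(probs, target)    # softmax is applied a second time internally

print(loss_from_logits.item(), loss_from_probs.item())  # the two values differ
```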