SonyCSLParis / pesto-full

Full models and training code for PESTO
GNU Lesser General Public License v3.0

`ShiftCrossEntropy` passing probabilities to `nn.CrossEntropyLoss` instead of logits #4

Open manipopopo opened 4 months ago

manipopopo commented 4 months ago

`ShiftCrossEntropy` currently uses `nn.CrossEntropyLoss` as its backend, which expects its input to be unnormalized logits. However, `ShiftCrossEntropy` appears to pass input probabilities and target probabilities to the backend instead. This may lead to a deviation from the expected behavior described in equation (7) of the paper.

https://github.com/SonyCSLParis/pesto-full/blob/229f78bd96986bdead3402331488a904b632f9cd/src/losses/entropy.py#L49
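A minimal sketch of the concern (not the PESTO code itself, and the tensor names here are hypothetical): `nn.CrossEntropyLoss` / `F.cross_entropy` applies `log_softmax` to its first argument internally, so when it receives probabilities it effectively renormalizes them with a second softmax, which is not the cross-entropy between the target and the predicted distribution.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(2, 5)                  # hypothetical model outputs
probs = logits.softmax(dim=-1)              # probabilities after the model's softmax
target = torch.randn(2, 5).softmax(dim=-1)  # a soft (probabilistic) target

# Intended cross-entropy between target and predicted distribution:
intended = -(target * probs.log()).sum(dim=-1).mean()

# What nn.CrossEntropyLoss computes when handed probabilities instead of logits,
# i.e. -(target * log_softmax(probs)).sum(...) -- a softmax applied on top of probabilities:
actual = F.cross_entropy(probs, target)

print(intended.item(), actual.item())  # the two values differ
```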

suncerock commented 1 month ago

I am having the same issue here.

In my opinion, KL divergence should have the same effect as the cross-entropy loss: since the target is detached in the code, the two losses differ only by the entropy of the target, which is constant with respect to the model. However, replacing the cross-entropy loss with KL divergence makes the model fail to converge.

The reason might be numerical issues in PyTorch, the misuse of nn.CrossEntropyLoss mentioned above, or other factors...
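For reference, a small self-contained check of the equivalence claimed above (a sketch with made-up tensors, not the PESTO training code): with a detached soft target q, KL(q || p) = CE(q, p) - H(q), and since H(q) does not depend on the model parameters, the two losses produce identical gradients up to numerical error.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(2, 5, requires_grad=True)
target = torch.randn(2, 5).softmax(dim=-1).detach()  # detached soft target q

log_p = F.log_softmax(logits, dim=-1)

ce = -(target * log_p).sum(dim=-1).mean()                # cross-entropy CE(q, p)
kl = F.kl_div(log_p, target, reduction="batchmean")      # KL(q || p)
entropy_q = -(target * target.log()).sum(dim=-1).mean()  # H(q), constant w.r.t. logits

print(torch.allclose(kl, ce - entropy_q, atol=1e-6))     # True: they differ only by H(q)

grad_ce, = torch.autograd.grad(ce, logits, retain_graph=True)
grad_kl, = torch.autograd.grad(kl, logits)
print(torch.allclose(grad_ce, grad_kl, atol=1e-6))       # True: identical gradients
```

If this equivalence holds numerically, the divergence in training behavior would point more toward how the inputs are normalized before reaching the loss than toward the choice of loss itself.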