k2-fsa / icefall

https://k2-fsa.github.io/icefall/
Apache License 2.0

Mismatch between custom LabelSmoothing and PyTorch's label smoothing #107

Open csukuangfj opened 2 years ago

csukuangfj commented 2 years ago

The current LabelSmoothing in icefall, https://github.com/k2-fsa/icefall/blob/810b193dcc3ad3f7a65bc3def63493711c9a084e/egs/librispeech/ASR/conformer_ctc/transformer.py#L797, is based on ESPnet, which in turn seems to be based on The Annotated Transformer.

As @janvainer pointed out in #106, PyTorch >= 1.10 has built-in label smoothing through the label_smoothing argument of torch.nn.CrossEntropyLoss; see https://pytorch.org/docs/master/generated/torch.nn.CrossEntropyLoss.html#torch.nn.CrossEntropyLoss
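A minimal usage sketch of the built-in option, assuming PyTorch >= 1.10 is installed; the batch size, number of classes, and smoothing value 0.1 are arbitrary choices for illustration:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_classes = 5
logits = torch.randn(3, num_classes)  # (N, C) unnormalized scores
target = torch.tensor([0, 2, 4])      # (N,) ground-truth class indices

# label_smoothing is a constructor argument of CrossEntropyLoss (PyTorch >= 1.10).
criterion = nn.CrossEntropyLoss(label_smoothing=0.1, reduction="none")
loss = criterion(logits, target)      # one loss value per example, shape (N,)
print(loss)
```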

I just compared the implementation from PyTorch with the one in icefall and identified the following differences.

PyTorch's implementation follows the one described in the paper Attention Is All You Need.

Label smoothing was proposed in the paper Rethinking the Inception Architecture for Computer Vision, which gives the following formula:

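$$q'(k) = (1 - \epsilon)\,\delta_{k,y} + \epsilon\, u(k) = (1 - \epsilon)\,\delta_{k,y} + \frac{\epsilon}{K}$$

where $K$ is the number of classes, $y$ is the ground-truth label, $\delta_{k,y}$ is 1 if $k = y$ and 0 otherwise, and $u(k) = 1/K$ is the uniform distribution.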
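The smoothed target distribution from the paper can be written out directly. This is a pure-Python sketch for a single example; the function name smoothed_targets is hypothetical:

```python
from typing import List


def smoothed_targets(y: int, num_classes: int, epsilon: float) -> List[float]:
    """q'(k) = (1 - epsilon) * delta_{k,y} + epsilon / K, with u(k) = 1/K."""
    return [
        (1.0 - epsilon) * (1.0 if k == y else 0.0) + epsilon / num_classes
        for k in range(num_classes)
    ]


q = smoothed_targets(y=2, num_classes=4, epsilon=0.1)
print(q)
```

With K = 4 and epsilon = 0.1, every class receives 0.1 / 4 = 0.025 and the target class additionally receives 0.9, for 0.925 in total.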

First difference

icefall uses K - 1 instead of K, i.e., the smoothing mass is spread over the K - 1 non-target classes only. See the code from icefall at https://github.com/k2-fsa/icefall/blob/810b193dcc3ad3f7a65bc3def63493711c9a084e/egs/librispeech/ASR/conformer_ctc/transformer.py#L853
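The effect of the K - 1 denominator can be seen in a small pure-Python sketch. It assumes, based on the fix proposed in (1), that icefall fills the non-target entries with smoothing / (K - 1) and then overwrites the target entry with 1 - smoothing; both function names are hypothetical:

```python
from typing import List


def paper_targets(y: int, num_classes: int, epsilon: float) -> List[float]:
    # Uniform mass epsilon / K on every class, plus 1 - epsilon on the target.
    return [
        (1.0 - epsilon if k == y else 0.0) + epsilon / num_classes
        for k in range(num_classes)
    ]


def icefall_style_targets(y: int, num_classes: int, epsilon: float) -> List[float]:
    # Assumed icefall behaviour: epsilon is shared by the K - 1 non-target
    # classes, and the target class gets exactly 1 - epsilon.
    off = epsilon / (num_classes - 1)
    return [1.0 - epsilon if k == y else off for k in range(num_classes)]


K, eps = 4, 0.1
p = paper_targets(0, K, eps)
ic = icefall_style_targets(0, K, eps)
print(p)
print(ic)
```

Both are valid distributions (each sums to 1); they differ in where the mass goes. With K = 4 and epsilon = 0.1, the paper's target entry is 0.9 + 0.1 / 4 = 0.925 and each other entry is 0.025, while the icefall-style target entry is 0.9 and each other entry is 0.1 / 3.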

Second difference

Rethinking the Inception Architecture for Computer Vision uses cross-entropy to compute the loss. See the formula from the paper, given below:

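$$H(q', p) = -\sum_{k=1}^{K} \log p(k)\, q'(k) = (1 - \epsilon)\, H(q, p) + \epsilon\, H(u, p)$$

where $q$ is the one-hot target distribution and $p$ is the model's predicted distribution.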
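The decomposition in the paper follows from the linearity of cross-entropy in its first argument. A pure-Python check on one example with arbitrary logits (helper names are illustrative):

```python
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    # Numerically stable log p(k) = z_k - log(sum_j exp(z_j)).
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_z for z in logits]


def cross_entropy(q: List[float], log_p: List[float]) -> float:
    # H(q, p) = -sum_k q(k) log p(k)
    return -sum(qk * lpk for qk, lpk in zip(q, log_p))


logits = [2.0, 0.5, -1.0, 0.3]  # arbitrary example scores
y, eps, K = 0, 0.1, len(logits)
log_p = log_softmax(logits)

one_hot = [1.0 if k == y else 0.0 for k in range(K)]               # q
uniform = [1.0 / K] * K                                            # u
q_smooth = [(1.0 - eps) * one_hot[k] + eps / K for k in range(K)]  # q'

lhs = cross_entropy(q_smooth, log_p)
rhs = (1.0 - eps) * cross_entropy(one_hot, log_p) + eps * cross_entropy(uniform, log_p)
print(lhs, rhs)
```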

but icefall uses the following formula, i.e., the KL divergence:

$$D_{\mathrm{KL}}(q' \,\|\, p) = \sum_{k=1}^{K} q'(k) \log \frac{q'(k)}{p(k)}$$
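For the same target distribution q', the two losses are related by D_KL(q' || p) = H(q', p) - H(q'), where H(q') is the entropy of the smoothed targets. H(q') does not depend on the model output, so the two losses have the same gradients with respect to the logits and differ only in the reported loss value. A pure-Python check of that identity (helper names are illustrative):

```python
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_z for z in logits]


def cross_entropy(q: List[float], log_p: List[float]) -> float:
    # H(q', p) = -sum_k q'(k) log p(k)
    return -sum(qk * lpk for qk, lpk in zip(q, log_p))


def entropy(q: List[float]) -> float:
    # H(q') = -sum_k q'(k) log q'(k), treating 0 * log 0 as 0
    return -sum(qk * math.log(qk) for qk in q if qk > 0)


def kl_div(q: List[float], log_p: List[float]) -> float:
    # D_KL(q' || p) = sum_k q'(k) (log q'(k) - log p(k)), skipping q'(k) = 0
    return sum(qk * (math.log(qk) - lpk) for qk, lpk in zip(q, log_p) if qk > 0)


logits = [2.0, 0.5, -1.0, 0.3]
y, eps, K = 0, 0.1, len(logits)
q_smooth = [(1.0 - eps if k == y else 0.0) + eps / K for k in range(K)]
log_p = log_softmax(logits)

kl = kl_div(q_smooth, log_p)
ce = cross_entropy(q_smooth, log_p)
print(kl, ce, ce - entropy(q_smooth))
```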


To match PyTorch's implementation (which is also the one used in the original Transformer paper), we need to make the following changes:

(1) Change https://github.com/k2-fsa/icefall/blob/810b193dcc3ad3f7a65bc3def63493711c9a084e/egs/librispeech/ASR/conformer_ctc/transformer.py#L853 to

    true_dist.fill_(self.smoothing / self.size)

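Also, the target positions in true_dist must keep the self.smoothing / self.size written by fill_ and get self.confidence added on top of it, so that they end up with self.confidence + self.smoothing / self.size.

That is, change https://github.com/k2-fsa/icefall/blob/810b193dcc3ad3f7a65bc3def63493711c9a084e/egs/librispeech/ASR/conformer_ctc/transformer.py#L857 to use scatter_add_: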

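    # Add self.confidence to the self.smoothing / self.size already stored at
    # each target position instead of overwriting it.
    true_dist.scatter_add_(
        1,
        target.unsqueeze(1),
        true_dist.new_full((target.size(0), 1), self.confidence),
    )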
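Why scatter_add_ rather than overwriting: after fill_(smoothing / size), writing confidence into the target entries drops their smoothing / size share, so each row sums to 1 - smoothing / size instead of 1. A small sketch with arbitrary values, where plain variables stand in for the module attributes self.smoothing, self.size and self.confidence:

```python
import torch

smoothing, size = 0.1, 4       # stand-ins for self.smoothing, self.size
confidence = 1.0 - smoothing   # stand-in for self.confidence
target = torch.tensor([2, 0])  # (N,) class indices

true_dist = torch.empty(target.size(0), size)
true_dist.fill_(smoothing / size)

# Overwrite: each target entry becomes exactly `confidence`.
overwritten = true_dist.clone()
overwritten.scatter_(1, target.unsqueeze(1), confidence)

# Add: each target entry becomes `confidence + smoothing / size`.
added = true_dist.clone()
added.scatter_add_(
    1, target.unsqueeze(1), true_dist.new_full((target.size(0), 1), confidence)
)

print(overwritten.sum(dim=1))  # each row ≈ 1 - smoothing / size
print(added.sum(dim=1))        # each row ≈ 1
```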

(2) Change https://github.com/k2-fsa/icefall/blob/810b193dcc3ad3f7a65bc3def63493711c9a084e/egs/librispeech/ASR/conformer_ctc/transformer.py#L858 to

    label_smoothing_loss = -1 * (torch.log_softmax(x, dim=1) * true_dist).sum(dim=1)
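Putting changes (1) and (2) together, the per-example loss can be compared against PyTorch's built-in. This is a sketch under stated assumptions: label_smoothing_loss is a hypothetical standalone function, not icefall's module, and it omits the padding/ignore handling and length normalization of the real LabelSmoothing class; it requires PyTorch >= 1.10 for the label_smoothing argument of F.cross_entropy.

```python
import torch
import torch.nn.functional as F


def label_smoothing_loss(
    x: torch.Tensor, target: torch.Tensor, smoothing: float
) -> torch.Tensor:
    """Hypothetical standalone version of the proposed computation.

    x: (N, C) logits; target: (N,) class indices. Returns per-example losses.
    """
    size = x.size(1)
    confidence = 1.0 - smoothing
    with torch.no_grad():
        true_dist = torch.empty_like(x)
        true_dist.fill_(smoothing / size)  # change (1): divide by K, not K - 1
        true_dist.scatter_add_(
            1, target.unsqueeze(1), true_dist.new_full((target.size(0), 1), confidence)
        )
    # Change (2): cross-entropy against true_dist instead of KL divergence.
    return -1 * (torch.log_softmax(x, dim=1) * true_dist).sum(dim=1)


torch.manual_seed(0)
x = torch.randn(4, 6)
target = torch.tensor([0, 3, 5, 1])

ours = label_smoothing_loss(x, target, smoothing=0.1)
ref = F.cross_entropy(x, target, label_smoothing=0.1, reduction="none")
print(torch.allclose(ours, ref, atol=1e-6))
```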

@danpovey Do you think we should make the above changes?

danpovey commented 2 years ago

Yes, I agree; it would be simpler and easier to understand.

janvainer commented 2 years ago

Thanks for doing the comparison!