k2-fsa / icefall

https://k2-fsa.github.io/icefall/
Apache License 2.0

Problem with valid loss functions #97

Open danpovey opened 2 years ago

danpovey commented 2 years ago

Since we merged a change to asr_datamodule.py (sorry, I don't have time to find the PR), our valid loss for the attention decoder has been very bad, which affects diagnostics (but not decoding).

The issue seems to be related to reordering ("indices") that is done in encode_supervisions(); the supervision for the attention decoder is taken from there, but it looks like we are not properly taking into account any reordering. I have verified that the "indices" variable in encode_supervisions() always seems to be in order for train data, but for some reason, not for valid. I won't be fixing this tonight, as it's late right now; we'll fix it tomorrow.
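
For illustration, here is a minimal sketch (not icefall's actual code; the names are made up) of the concern: whatever permutation "indices" describes has to be applied to every piece of per-utterance data the attention decoder consumes.

    import torch

    # Minimal sketch: per-utterance data must follow the same permutation
    # that the duration-based sort applies to the supervision segments.
    num_frames = torch.tensor([100, 150, 120], dtype=torch.int32)
    texts = ["utt_a", "utt_b", "utt_c"]

    indices = torch.argsort(num_frames, descending=True)    # tensor([1, 2, 0])
    reordered_texts = [texts[i] for i in indices.tolist()]  # ['utt_b', 'utt_c', 'utt_a']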

Making an issue in case anyone else notices the problem.

csukuangfj commented 2 years ago

Since we merged a change to asr_datamodule.py (sorry, I don't have time to find the PR)

It is https://github.com/k2-fsa/icefall/pull/73.

I have verified that the "indices" variable in encode_supervisions() always seems to be in order for train data, but for some reason, not for valid.

I don't think so. The "indices" for valid should also be in order (I just verified it with one batch). The reason is that lhotse sorts cuts by duration in descending order before returning them, and we are using torch.argsort to sort them by duration on CPU, which is a stable sort. So the sort actually does nothing, and the reordering in encode_supervisions is in fact a no-op.

Relevant code is given below: https://github.com/k2-fsa/icefall/blob/8cb7f712e413fffbcdfdd865be73d6ff43f0ce7a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py#L268-L271

https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L96

cuts = cuts.sort_by_duration(ascending=False)

https://github.com/k2-fsa/icefall/blob/8cb7f712e413fffbcdfdd865be73d6ff43f0ce7a/icefall/utils.py#L260-L265
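
As a quick sanity check of the "already sorted" assumption (a sketch, not icefall's code; it only holds exactly when the durations are distinct):

    import torch

    # If the cuts really arrive sorted by duration in descending order (and
    # the durations are distinct), argsort should give the identity permutation.
    num_frames = torch.tensor([150, 148, 143, 138], dtype=torch.int32)
    indices = torch.argsort(num_frames, descending=True)
    assert torch.equal(indices, torch.arange(len(num_frames)))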


the supervision for the attention decoder is taken from there, but it looks like we are not properly taking into account any reordering

I agree this is a potential bug if lhotse did not sort the returned cuts by duration.


I just compared the valid loss while changing shuffle. It turns out that the reason for the large valid loss is shuffle=False.

valid loss with shuffle=True

2021-10-29 08:17:40,751 INFO [train2.py:567] Computing validation loss
2021-10-29 08:18:26,353 INFO [train2.py:577] Epoch 50, validation: ctc_loss=0.0643, att_loss=0.08654, loss=0.07987, over 944034 frames.

valid loss with shuffle=False

2021-10-29 08:20:26,927 INFO [train2.py:567] Computing validation loss
2021-10-29 08:21:03,460 INFO [train2.py:577] Epoch 50, validation: ctc_loss=0.06419, att_loss=0.3728, loss=0.2802, over 944034 frames.

There are two additional differences between the valid sampler and the train sampler, bucket_method and num_buckets:

valid loss with shuffle=True, bucket_method=equal_duration

(A little better than bucket_method==equal_len)

2021-10-29 08:14:39,988 INFO [train2.py:567] Computing validation loss
2021-10-29 08:15:22,140 INFO [train2.py:577] Epoch 50, validation: ctc_loss=0.06426, att_loss=0.07458, loss=0.07148, over 944034 frames.

valid loss with shuffle=True, bucket_method=equal_duration, num_buckets=30

(A little worse than num_buckets==10)

2021-10-29 08:35:59,972 INFO [train2.py:567] Computing validation loss
2021-10-29 08:36:38,588 INFO [train2.py:577] Epoch 50, validation: ctc_loss=0.06439, att_loss=0.1334, loss=0.1127, over 943155 frames.
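
For reference, a sketch of the sampler configurations being compared above; cuts_valid stands for the validation CutSet, and the exact BucketingSampler signature depends on the lhotse version in use, so treat this as illustrative only:

    from lhotse.dataset import BucketingSampler

    # Illustrative only: the valid-sampler variants compared above differ in
    # shuffle, bucket_method and num_buckets.
    valid_sampler = BucketingSampler(
        cuts_valid,
        max_duration=200.0,
        shuffle=True,                    # vs. shuffle=False
        bucket_method="equal_duration",  # vs. "equal_len"
        num_buckets=10,                  # vs. 30
    )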

pzelasko commented 2 years ago

I read over Lhotse's code again but I can't spot any reordering-related issues. An easy way to ensure that cuts are sorted by duration is to specify return_cuts=True to the dataset's constructor and inspect them in the training loop (they are in batch["cuts"]).
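
For example (a sketch of that check; valid_dl stands for the validation DataLoader, and the exact key under which the cuts appear may differ across lhotse versions):

    # With return_cuts=True, inspect the cuts of each validation batch and
    # confirm they are sorted by duration in descending order.
    for batch in valid_dl:
        cuts = batch["cuts"]  # may instead be nested under batch["supervisions"]
        durations = [c.duration for c in cuts]
        assert durations == sorted(durations, reverse=True)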

If the loss grows with smaller num_buckets, it suggests that the issue arises from excessive padding. Maybe extra frames with low values (around -23 for log-energies) are shifting the normalization statistics (batchnorm/layernorm) too much?

pzelasko commented 2 years ago

… one solution would be to make the padding use the mean value of each feature bin instead of a low energy. I can look into it, but I'm not sure if I'll manage to do it this week.
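
A minimal sketch of that idea (not lhotse's actual padding code; the helper name is made up):

    import torch

    # Pad each utterance's features with its own per-bin mean instead of a
    # constant low log-energy, so that padded frames disturb batchnorm /
    # layernorm statistics less.
    def pad_with_bin_mean(feats: torch.Tensor, target_len: int) -> torch.Tensor:
        # feats: (num_frames, num_bins)
        pad_rows = target_len - feats.size(0)
        if pad_rows <= 0:
            return feats
        bin_mean = feats.mean(dim=0, keepdim=True)          # (1, num_bins)
        padding = bin_mean.expand(pad_rows, feats.size(1))  # (pad_rows, num_bins)
        return torch.cat([feats, padding], dim=0)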

csukuangfj commented 2 years ago

I suspect that the problem is due to padding.

The following shows how the valid att loss changes when I change --max-duration for the valid sampler. (Note: shuffle is set to False throughout; only max_duration is changed.)

max duration == 10

2021-10-29 10:25:54,562 INFO [train2.py:567] Computing validation loss
/ceph-fj/fangjun/open-source/lhotse-ali-ctc-new/lhotse/dataset/sampling/single_cut.py:237: UserWarning: The first cut drawn in batch
collection violates the max_frames, max_cuts, or max_duration constraints - we'll return it anyway. Consider increasing max_frames/max_cuts/max_duration.
  warnings.warn(
2021-10-29 10:29:12,491 INFO [train2.py:577] Epoch 50, validation: ctc_loss=0.06411, att_loss=0.03271, loss=0.04213, over 944034 frames.

max duration == 20

2021-10-29 10:33:12,160 INFO [train2.py:567] Computing validation loss
/ceph-fj/fangjun/open-source/lhotse-ali-ctc-new/lhotse/dataset/sampling/single_cut.py:237: UserWarning: The first cut drawn in batch
collection violates the max_frames, max_cuts, or max_duration constraints - we'll return it anyway. Consider increasing max_frames/max_cuts/max_duration.
  warnings.warn(
2021-10-29 10:35:16,893 INFO [train2.py:577] Epoch 50, validation: ctc_loss=0.06411, att_loss=0.03268, loss=0.04211, over 944034 frames.

max duration == 40

2021-10-29 10:38:09,161 INFO [train2.py:567] Computing validation loss
2021-10-29 10:39:17,043 INFO [train2.py:577] Epoch 50, validation: ctc_loss=0.06411, att_loss=0.04109, loss=0.048, over 944034 frames.

max duration == 60

2021-10-29 10:41:33,061 INFO [train2.py:567] Computing validation loss
2021-10-29 10:42:24,839 INFO [train2.py:577] Epoch 50, validation: ctc_loss=0.06411, att_loss=0.08188, loss=0.07655, over 944034 frames.

max duration == 80

2021-10-29 10:45:46,433 INFO [train2.py:567] Computing validation loss
2021-10-29 10:46:31,573 INFO [train2.py:577] Epoch 50, validation: ctc_loss=0.0641, att_loss=0.1337, loss=0.1128, over 944034 frames.

max duration == 120

2021-10-29 10:49:34,502 INFO [train2.py:567] Computing validation loss
2021-10-29 10:50:15,011 INFO [train2.py:577] Epoch 50, validation: ctc_loss=0.06412, att_loss=0.2358, loss=0.1843, over 944034 frames.

max duration == 180

2021-10-29 10:52:21,306 INFO [train2.py:567] Computing validation loss
2021-10-29 10:52:58,819 INFO [train2.py:577] Epoch 50, validation: ctc_loss=0.06414, att_loss=0.3501, loss=0.2643, over 944034 frames.

You can see that valid att_loss gets worse as we increase max duration.
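
To make the connection explicit: within a batch, every cut is padded up to the longest one, so a larger --max-duration (i.e. more cuts per batch) generally means a larger fraction of padded frames. A rough illustration with made-up durations:

    # Fraction of padded frames when cuts with the given durations (seconds)
    # are batched together and padded to the longest one.
    def padding_fraction(durations):
        longest = max(durations)
        return 1.0 - sum(durations) / (longest * len(durations))

    print(padding_fraction([10.0]))            # 0.0   -- a single cut, no padding
    print(padding_fraction([10.0, 8.0, 5.0]))  # ~0.23 -- more cuts, more padding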

danpovey commented 2 years ago

I've printed out the 'indices' and they have not been sorted for validation data, with shuffle=False at least (not sure about shuffle=True).

However, when I've printed the num_frames, they have always been in order, as far as I've seen. Will investigate more.

In the longer term, I think we should aim to modify the intersection code so that it supports inputs that are not sorted by duration, since this whole constraint is a bit ugly.
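
For context, the constraint referred to above is that the CTC intersection expects supervision_segments sorted by num_frames (column 2) in decreasing order, which is why encode_supervisions() sorts at all. A toy illustration with the usual icefall shapes (this is not the training code itself):

    import k2
    import torch

    # nnet_output: (N, T, C) log-probs; supervision_segments: (S, 3) int32
    # rows of [sequence_index, start_frame, num_frames], arranged so that
    # num_frames is decreasing.
    nnet_output = torch.randn(2, 10, 5).log_softmax(dim=-1)
    supervision_segments = torch.tensor(
        [[0, 0, 10], [1, 0, 7]], dtype=torch.int32  # 10 >= 7, already sorted
    )
    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)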

danpovey commented 2 years ago

I think I found the problem: argsort not being stable. Code:

    print("segments = ", supervision_segments[:,2])
    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    print("indices=", indices)

Some output (note that the two equal values 150, 150 at the start come back with indices 1, 0):

segments =  tensor([150, 150, 148, 143, 138, 135, 131, 124, 122, 121, 119, 115, 115, 113,
        109, 107, 106, 105, 101, 101, 100,  99,  98,  97,  97,  96,  95,  94,
         94,  93,  91,  90,  90,  87,  85,  81,  80,  76,  70,  70,  67,  64,
         63,  62,  61,  60,  57,  54,  52,  50,  45], dtype=torch.int32)
indices= tensor([ 1,  0,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50])

danpovey commented 2 years ago

It looks like torch.argsort calls torch.sort but likely does not specify stable=True; the default of torch.sort is stable=False.
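
A small repro of the difference (the output of the non-stable sort is not guaranteed, so the first result is just one possible outcome):

    import torch

    # With tied values, the default (non-stable) sort may swap equal elements;
    # stable=True preserves their original order.
    x = torch.tensor([150, 150, 148, 143], dtype=torch.int32)
    print(torch.argsort(x, descending=True))                    # e.g. tensor([1, 0, 2, 3])
    print(torch.sort(x, descending=True, stable=True).indices)  # tensor([0, 1, 2, 3])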