k2-fsa / snowfall

Moved to https://github.com/k2-fsa/icefall
Apache License 2.0

Random sampling with concatenation vs bucketing #108

Open pzelasko opened 3 years ago

pzelasko commented 3 years ago

FYI

I ran the transformer training with mmi + attention using different batch sampling settings to check which sampler gives a better WER. The results are from the model averaged over the last 5 checkpoints (epochs 15-20); the trends are similar when using the single best checkpoint.

Default: SingleCutSampler with CutConcatenate, warmup 25k, max-frames 50000:
2021-02-22 10:17:51,487 INFO [mmi_att_transformer_decode.py:300] %WER 8.15% [4285 / 52576, 674 ins, 377 del, 3234 sub ]

BucketingSampler, warmup 25k, max-frames 50000:
2021-02-19 09:43:08,320 INFO [mmi_att_transformer_decode.py:300] %WER 8.30% [4362 / 52576, 710 ins, 365 del, 3287 sub ]

BucketingSampler, warmup 10k, max-frames 70000:
2021-02-18 22:57:43,271 INFO [mmi_att_transformer_decode.py:300] %WER 8.31% [4370 / 52576, 697 ins, 362 del, 3311 sub ]

BucketingSampler, warmup 1k, max-frames 70000:
2021-02-19 16:17:52,686 INFO [mmi_att_transformer_decode.py:300] %WER 8.32% [4372 / 52576, 709 ins, 376 del, 3287 sub ]

It seems that bucketing gives a small degradation, but it lets us train faster since we can set max-frames to a larger number. I think the degradation could be because cut concatenation works like a regularizer for the network, i.e. it forces the network not to "pay attention" to the utterances that do not matter for the current recognition.
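Why bucketing permits a larger max-frames is not spelled out above. One plausible factor (an assumption, not something measured in this thread) is padding: a batch is padded to its longest cut, so mixing short and long cuts wastes memory and compute. The toy sketch below groups hypothetical frame counts into duration buckets and compares padded frames to real frames. It is not Lhotse's BucketingSampler (which also shuffles within buckets), and its budget counts padded frames, which may differ from how Lhotse counts max_frames.

```python
import random
from typing import List


def padded_frames(batch: List[int]) -> int:
    # Every cut in a batch is padded to the longest one.
    return max(batch) * len(batch) if batch else 0


def make_batches(lengths: List[int], max_frames: int) -> List[List[int]]:
    # Greedily add cuts until the padded batch size would exceed max_frames.
    batches: List[List[int]] = []
    current: List[int] = []
    for n in lengths:
        if current and padded_frames(current + [n]) > max_frames:
            batches.append(current)
            current = []
        current.append(n)
    if current:
        batches.append(current)
    return batches


def buckets(lengths: List[int], num_buckets: int) -> List[List[int]]:
    # Sort by length and split into buckets of cuts with similar durations.
    ordered = sorted(lengths)
    size = -(-len(ordered) // num_buckets)  # ceiling division
    return [ordered[i:i + size] for i in range(0, len(ordered), size)]


def padding_ratio(batches: List[List[int]]) -> float:
    # Padded frames divided by real frames (1.0 means no padding at all).
    real = sum(sum(b) for b in batches)
    return sum(padded_frames(b) for b in batches) / real


random.seed(0)
lengths = [random.randint(100, 1600) for _ in range(1000)]  # hypothetical frame counts
random_batches = make_batches(lengths, max_frames=50000)
bucket_batches = [b for bucket in buckets(lengths, 10) for b in make_batches(bucket, 50000)]
print(f"random: {padding_ratio(random_batches):.2f}, bucketed: {padding_ratio(bucket_batches):.2f}")
```

With less padding per batch, the same memory budget fits more real frames, which is one way a larger max-frames could become affordable.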

I wanted to check whether combining bucketing with concatenation would help, but I encountered an error inside the transformer: the encoder mask and encoder activation shapes differ by 1 (e.g. 253 and 254). I don't know yet whether that is due to an issue in Lhotse or in the transformer code, and I don't know when I'll have enough time to debug it, so I wanted to share these findings in case somebody wants to pick it up or wonders which sampling worked better...

danpovey commented 3 years ago

Thanks!!


zhu-han commented 3 years ago

I found a bug in how start_frame is computed after subsampling in the transformer code, which could lead to a wrong mask. Fixed in https://github.com/k2-fsa/snowfall/pull/109.

pzelasko commented 3 years ago

Cool! In that case I'll re-attempt this.

pzelasko commented 3 years ago

@zhu-han I tried again with your fix, but I'm still getting the following error:

  File "./mmi_att_transformer_train.py", line 104, in get_objf
    nnet_output, encoder_memory, memory_mask = model(feature, supervision_segments)
  File "/home/hltcoe/pzelasko/miniconda3/envs/k2env/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/exp/pzelasko/snowfall/snowfall/models/transformer.py", line 92, in forward
    encoder_memory, memory_mask = self.encode(x, supervision)
  File "/exp/pzelasko/snowfall/snowfall/models/transformer.py", line 114, in encode
    x = self.encoder(x, src_key_padding_mask=mask)  # (T, B, F)
  File "/home/hltcoe/pzelasko/miniconda3/envs/k2env/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/hltcoe/pzelasko/miniconda3/envs/k2env/lib/python3.7/site-packages/torch/nn/modules/transformer.py", line 181, in forward
    output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
  File "/home/hltcoe/pzelasko/miniconda3/envs/k2env/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/exp/pzelasko/snowfall/snowfall/models/transformer.py", line 230, in forward
    key_padding_mask=src_key_padding_mask)[0]
  File "/home/hltcoe/pzelasko/miniconda3/envs/k2env/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/hltcoe/pzelasko/miniconda3/envs/k2env/lib/python3.7/site-packages/torch/nn/modules/activation.py", line 985, in forward
    attn_mask=attn_mask)
  File "/home/hltcoe/pzelasko/miniconda3/envs/k2env/lib/python3.7/site-packages/torch/nn/functional.py", line 4283, in multi_head_attention_forward
    assert key_padding_mask.size(1) == src_len
AssertionError

In this particular case, key_padding_mask.shape = torch.Size([20, 930]) and src_len = 931.

It should be reproducible in mmi_att_transformer_train.py if you change the line:

transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))]

to

transforms = [CutConcatenate(duration_factor=2), CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))]

and run it with: python mmi_att_transformer_train.py --bucketing_sampler true

zhu-han commented 3 years ago

Currently, with concatenation, the mask length is computed by subsampling the lengths of the two sentences separately and then summing them. However, this can differ from subsampling the length of the concatenated sentence directly. https://github.com/k2-fsa/snowfall/pull/112 should fix it, but I'm not sure if there are other issues, because I encountered another error:
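To illustrate the mismatch, the sketch below assumes the encoder front end is a 4x Conv2d subsampling built from two kernel-3, stride-2 convolutions without padding, whose output length is ((T - 1) // 2 - 1) // 2. The exact formula in snowfall's transformer is an assumption here, and any gap inserted between concatenated cuts is ignored. Because of the floor divisions, subsampling two parts and adding the results can come out one frame shorter than subsampling the concatenated length.

```python
from typing import List


def subsampled_len(num_frames: int) -> int:
    # Assumed output length of a 4x Conv2d subsampling made of two
    # kernel-3, stride-2 convolutions (no padding).
    return ((num_frames - 1) // 2 - 1) // 2


def mask_len_sum_of_parts(part_lengths: List[int]) -> int:
    # Subsample each utterance separately, then add the results.
    return sum(subsampled_len(n) for n in part_lengths)


def mask_len_concatenated(part_lengths: List[int]) -> int:
    # Subsample the single concatenated input the encoder actually sees.
    return subsampled_len(sum(part_lengths))


parts = [100, 100]
print(mask_len_sum_of_parts(parts))  # 48
print(mask_len_concatenated(parts))  # 49
```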

Traceback (most recent call last):
  File "./mmi_att_transformer_train.py", line 619, in <module>
    main()
  File "./mmi_att_transformer_train.py", line 564, in main
    global_batch_idx_train=global_batch_idx_train,
  File "./mmi_att_transformer_train.py", line 301, in train_one_epoch
    optimizer=optimizer
  File "./mmi_att_transformer_train.py", line 142, in get_objf
    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
  File "/opt/miniconda3/lib/python3.7/site-packages/k2/dense_fsa_vec.py", line 62, in __init__
    assert duration > 0
AssertionError

This is because some sentences are so short that they have zero length after subsampling. I think it's because some sentences were mistakenly truncated in https://github.com/lhotse-speech/lhotse/blob/master/lhotse/utils.py#L406

Specifically, max_frames in the function supervision_to_frames (i.e., cut.num_frames) was not updated to the new length when two sentences were concatenated. But I haven't found out exactly how it occurred.
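Under an assumed 4x Conv2d subsampling with output length ((T - 1) // 2 - 1) // 2, any supervision shorter than 7 frames maps to zero output frames, which is the case k2.DenseFsaVec rejects with `assert duration > 0`. The hypothetical drop_too_short helper below is only a defensive guard that hides the symptom (in a real batch the matching transcripts would have to be dropped too); the comment above traces the actual cause to truncation in Lhotse.

```python
from typing import List, Tuple

Segment = Tuple[int, int, int]  # (sequence_idx, start_frame, num_frames)


def subsampled_len(num_frames: int) -> int:
    # Assumed 4x Conv2d subsampling (two kernel-3, stride-2 convolutions).
    return ((num_frames - 1) // 2 - 1) // 2


def min_input_frames() -> int:
    # Smallest input length that still yields at least one output frame.
    t = 1
    while subsampled_len(t) < 1:
        t += 1
    return t


def drop_too_short(segments: List[Segment]) -> List[Segment]:
    # Hypothetical guard: keep only supervisions that survive subsampling,
    # since a zero duration triggers `assert duration > 0` in k2.DenseFsaVec.
    return [seg for seg in segments if subsampled_len(seg[2]) > 0]


segments = [(0, 0, 500), (1, 0, 6), (2, 0, 7)]
print(min_input_frames())        # 7
print(drop_too_short(segments))  # [(0, 0, 500), (2, 0, 7)]
```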

pzelasko commented 3 years ago

OK I will look into it. Thanks!

pzelasko commented 3 years ago

You were right about that issue in Lhotse - the musan mixing code sometimes truncated too much of the original utterance. I fixed it (gonna merge as soon as the tests pass), but I'm still getting mismatched padding masks and sequence lengths:

  File "/home/hltcoe/pzelasko/miniconda3/envs/k2env/lib/python3.7/site-packages/torch/nn/functional.py", line 4283, in multi_head_attention_forward
    assert key_padding_mask.size(1) == src_len, f'{key_padding_mask.shape} == {src_len}'
AssertionError: torch.Size([18, 884]) == 932

I verified that the num_frames truncation here (https://github.com/lhotse-speech/lhotse/blob/master/lhotse/utils.py#L406) removes at most one frame for the batch where this happened, so there must be something else that's not right. The same thing happens both with and without the bucketing sampler now. I can't see any bug on Lhotse's side at this time, but I can't rule it out either. Can you check it on the transformer side again?

PS. Notably, the issue only occurs when you set CutConcatenate(duration_factor=2); if it is set to 1 (the default), it seems to work fine (both with and without bucketing). I suspect the core issue is the presence of multiple supervisions in the first cut.
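A simplified sketch of why duration_factor=2 can put several supervisions into the first cut. It assumes (my reading of the parameter, not Lhotse's actual code) that the concatenated duration is capped at duration_factor times the longest cut in the batch and that shorter cuts are packed greedily onto longer ones. The durations are hypothetical and the silence gap between cuts is ignored. With a factor of 1 the longest cut has no room left, so the first cut keeps a single supervision; with a factor of 2 other cuts get appended to it.

```python
from typing import List


def concatenate(durations: List[float], duration_factor: float = 1.0) -> List[List[float]]:
    # Each inner list is one concatenated cut; its items are the original
    # cuts (and hence supervisions) it contains, in order.
    limit = duration_factor * max(durations)
    packed: List[List[float]] = []
    for d in sorted(durations, reverse=True):
        for group in packed:
            if sum(group) + d <= limit:
                group.append(d)  # first-fit: append to the first cut with room
                break
        else:
            packed.append([d])  # no room anywhere: start a new cut
    return packed


durations = [15.0, 10.0, 6.0, 4.0, 3.0]
print(concatenate(durations, duration_factor=1))  # [[15.0], [10.0, 4.0], [6.0, 3.0]]
print(concatenate(durations, duration_factor=2))  # [[15.0, 10.0, 4.0], [6.0, 3.0]]
```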

pzelasko commented 3 years ago

To make things easier: I confirmed that the issue does not arise in the LSTM recipe (mmi_bigram_train.py), regardless of the duration_factor setting.

zhu-han commented 3 years ago

Fixed a bug in https://github.com/k2-fsa/snowfall/pull/115. Now I can run with CutConcatenate(duration_factor=2) for a few batches, but I get another mask mismatch error:

  File "/opt/miniconda3/lib/python3.7/site-packages/torch/nn/functional.py", line 4283, in multi_head_attention_forward
    assert key_padding_mask.size(1) == src_len, "{} == {}".format(key_padding_mask.size(1), src_len)
AssertionError: 232 == 233

This was due to a shape mismatch between the supervisions and the features. In this case,

import torch

# supervisions: the batch's supervision dict produced by the Lhotse dataset
ori_supervision_segments = torch.stack(
    (supervisions['sequence_idx'],
     supervisions['start_frame'],
     supervisions['num_frames']), 1).to(torch.int32)
print(int((ori_supervision_segments[:, 1] + ori_supervision_segments[:, 2]).max()))

gives 934, but feature.shape[1] is 935.

I wonder whether this is expected or a potential bug in Lhotse. If it is expected, I could further change the transformer code to avoid this error.

pzelasko commented 3 years ago

In general, it shouldn't happen in Lhotse, at least not for LibriSpeech data. Let me check with your fix and maybe we'll be able to get to the bottom of it now.

pzelasko commented 3 years ago

Yeah it seems to me that Lhotse could be off by one frame; but I think it still makes sense to adjust the transformer code to handle the scenario when features.shape[2] is greater than max(start_frames + num_frames). It will be useful when we move on to conversational/contextual datasets, where cuts will span more audio than just the speech segment.

zhu-han commented 3 years ago

I modified the code to take the maximum mask length from the features. Now it should be OK to run with CutConcatenate(duration_factor=2).
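A minimal pure-Python sketch of deriving the mask width from the features rather than from the supervisions. It assumes a 4x subsampling with output length ((T - 1) // 2 - 1) // 2, and key_padding_mask is a hypothetical helper, not the code in the PR. Under this formula, 934 and 935 input frames map to 232 and 233 output frames, consistent with the 232 == 233 assertion reported earlier; taking the width from the 935-frame feature tensor keeps it equal to src_len and simply marks the extra frame as padding.

```python
from typing import List, Tuple

Segment = Tuple[int, int, int]  # (sequence_idx, start_frame, num_frames)


def subsampled_len(num_frames: int) -> int:
    # Assumed 4x Conv2d subsampling (two kernel-3, stride-2 convolutions).
    return ((num_frames - 1) // 2 - 1) // 2


def key_padding_mask(num_feature_frames: int, segments: List[Segment]) -> List[List[bool]]:
    # True marks padded positions, following the key_padding_mask convention
    # of torch.nn.MultiheadAttention.
    # The width comes from the feature length, so it always equals the
    # encoder's src_len even if a cut has frames beyond its supervisions.
    max_len = subsampled_len(num_feature_frames)
    num_seqs = max(seq_idx for seq_idx, _, _ in segments) + 1
    ends = [0] * num_seqs
    for seq_idx, start, num in segments:
        # With concatenation a sequence can hold several supervisions;
        # its valid length ends at the last one.
        ends[seq_idx] = max(ends[seq_idx], start + num)
    lengths = [subsampled_len(end) for end in ends]
    return [[t >= n for t in range(max_len)] for n in lengths]


segments = [(0, 0, 934), (1, 0, 500)]
mask = key_padding_mask(935, segments)
print(len(mask[0]))                # 233, the encoder's src_len
print(sum(mask[0]), sum(mask[1]))  # 1 109
```

This design also covers the conversational case mentioned above, where a cut spans more audio than its supervisions: the trailing unsupervised frames are masked instead of causing a shape mismatch.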

pzelasko commented 3 years ago

I now get a new best result with bucketing + concatenation; this is the average over the last 5 epochs:

2021-03-02 09:00:17,479 INFO [mmi_att_transformer_decode.py:300] %WER 7.90% [4151 / 52576, 620 ins, 365 del, 3166 sub ]

We should compare it to the baseline transformer results after @zhu-han re-runs them though.

danpovey commented 3 years ago

great!!


zhu-han commented 3 years ago

The baseline mmi transformer result is:

2021-03-02 13:24:41,550 INFO [mmi_att_transformer_decode.py:300] %WER 7.79% [4098 / 52576, 603 ins, 342 del, 3153 sub ]

danpovey commented 3 years ago

When did you get that number? We've usually been getting a little over 8%.


pzelasko commented 3 years ago

These are after the recent fixes we did both in Lhotse and transformer code.

pzelasko commented 3 years ago

BTW, it occurred to me that we might close the gap between bucketing and no bucketing with multi-GPU training (once it's fixed), as each GPU will likely sample a bucket with different cut lengths, so each model update will use a variety of cut durations.

danpovey commented 3 years ago

MM yes, maybe.


danpovey commented 3 years ago

@pzelasko do you want to make a PR for this change, or perhaps did you already make one?

Incidentally, in our MMI+attention setup I managed to get basically the same numbers as the current script by reducing warm-step to 5000 and the number of epochs from 20 to 15, which also speeds up training. But I'm not confident in this; it would be nice if someone else could double-check.
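For reference when comparing warm-step values, a sketch of the Noam learning-rate schedule from the original Transformer paper. It assumes the recipe's warm-step parametrises a schedule of this shape (an assumption; d_model=256 and factor=1.0 are placeholders, not the recipe's settings). The peak is reached at step == warmup and scales as warmup ** -0.5, so going from 25000 to 5000 both ramps up sooner and raises the peak by roughly 2.2x.

```python
def noam_lr(step: int, d_model: int, warmup: int, factor: float = 1.0) -> float:
    # Linear warmup for `warmup` steps, then decay proportional to step ** -0.5.
    step = max(step, 1)  # avoid 0 ** -0.5 at the very first step
    return factor * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)


for warmup in (25000, 5000):
    peak = noam_lr(warmup, d_model=256, warmup=warmup)
    print(f"warmup={warmup}: peak lr {peak:.2e} at step {warmup}")
```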

pzelasko commented 3 years ago

I haven’t yet — I’ll submit a PR soon.