k2-fsa / icefall

https://k2-fsa.github.io/icefall/
Apache License 2.0

--concatenate-cuts #1426

Closed joazoa closed 6 months ago

joazoa commented 8 months ago

Hello!

Can somebody help me with the --concatenate-cuts option?

When I enable it, I always get this error:

  File "zipformer/model.py", line 324, in forward
    assert x.size(0) == x_lens.size(0) == y.dim0, (x.shape, x_lens.shape, y.dim0)
AssertionError: (torch.Size([39, 1540, 80]), torch.Size([40]), 40)

I'm also using MUSAN and SpecAugment, as well as on-the-fly loading of the files.

What can I do to get the concatenation to work?

pzelasko commented 8 months ago

You’d probably need to add cuts = cuts.merge_supervisions() in K2SpeechRecognitionDataset right after the concatenation transform is executed, so the number of cuts and supervisions is equal.
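
For illustration, a minimal sketch of that idea (assuming lhotse's CutSet / CutConcatenate API; the exact place in K2SpeechRecognitionDataset may differ):

    # Sketch: run merge_supervisions right after CutConcatenate, so each
    # concatenated cut ends up with a single supervision and the number of
    # cuts matches the number of supervisions again.
    from lhotse import CutSet
    from lhotse.dataset import CutConcatenate

    concat = CutConcatenate(duration_factor=1.0, gap=1.0)

    def apply_transforms(cuts: CutSet) -> CutSet:
        cuts = concat(cuts)               # pack short cuts together with gaps
        cuts = cuts.merge_supervisions()  # collapse per-utterance supervisions into one
        return cuts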

daniel-dona commented 2 months ago

I think this is still a problem, as it's implemented the same way in all of the asr_datamodule files. I tried using merge_supervisions(), but I'm not sure if this is the proper way:

        if self.args.concatenate_cuts:
            logging.info(
                f"Using cut concatenation with duration factor "
                f"{self.args.duration_factor} and gap {self.args.gap}."
            )
            # Cut concatenation should be the first transform in the list,
            # so that if we e.g. mix noise in, it will fill the gaps between
            # different utterances.
            transforms = [
                CutConcatenate(
                    duration_factor=self.args.duration_factor, gap=self.args.gap
                ),
                # To be applied as a transformation (?)
                lambda cuts: cuts.merge_supervisions()
            ] + transforms

With --duration-factor 1 I see no error, but if it's >1 I get a similar error:

  File "/icefall/egs/commonvoice/ASR/zipformer/model.py", line 323, in forward
    assert x.size(0) == x_lens.size(0) == y.dim0, (x.shape, x_lens.shape, y.dim0)
AssertionError: (torch.Size([57, 2270, 80]), torch.Size([66]), 66)

But it's not happening all the time; for example, it works with 1.5 but not with 2.0.

So the idea of --duration-factor is that if the longest audio segment in the batch is, say, 20 s, it'll try to mix cuts up to 40 s, right? Or is this affected by --max-duration?
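
For reference, a simplified sketch of the packing rule being assumed here (an illustration, not lhotse's actual CutConcatenate code): the longest cut in the batch times duration_factor sets a duration budget, and cuts are concatenated with gap seconds in between as long as they fit.

    # Simplified model of the assumed packing rule; durations in seconds.
    def pack_durations(durations, duration_factor=1.0, gap=1.0):
        budget = max(durations) * duration_factor  # longest cut sets the budget
        packed, current = [], 0.0
        for d in sorted(durations, reverse=True):
            if current > 0 and current + gap + d <= budget:
                current += gap + d  # append to the current concatenated cut
            else:
                if current > 0:
                    packed.append(current)
                current = d         # start a new concatenated cut
        packed.append(current)
        return packed

    # With duration_factor=2.0 the 20 s cut sets a 40 s budget:
    # [20, 15, 4, 3] -> [20 + 1 + 15, 4 + 1 + 3] = [36.0, 8.0]
    print(pack_durations([20, 15, 4, 3], duration_factor=2.0, gap=1.0))

As far as I understand, --max-duration constrains the sampler's total batch duration separately, before these transforms run.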

I'm a bit lost with this :face_with_diagonal_mouth:

BTW, all of this is to test whether it can fix some of the problems with Zipformer deletions: https://github.com/k2-fsa/icefall/issues/1465

joazoa commented 2 months ago

@daniel-dona I tried concatenation a while ago, but it didn't resolve the problem with the deletions. (I don't have the patches anymore)

daniel-dona commented 2 months ago

> @daniel-dona I tried concatenation a while ago, but it didn't resolve the problem with the deletions. (I don't have the patches anymore)

For me it seems to be making a difference, at least at sentence beginnings or after silence or background noise, but maybe this changes depending on the dataset. More testing is needed...

danpovey commented 2 months ago

I think the cut concatenation implemented in asr_datamodule.py may be left over from a time when we were using CTC only, with a different CTC implementation. This is something that probably Fangjun or Piotr would need to take a look at. There may be some way in Lhotse to do this at a manifest level outside of the training script (not sure of this though).


csukuangfj commented 2 months ago

> I think the cut concatenation implemented in asr_datamodule.py may be left over from a time when we were using CTC only

Yes, you are right. It is for tdnn-lstm-ctc, where there is no attention.

The current code does not support cut concatenation for models with attention, e.g., Zipformer, Transformer, Conformer, etc.

The main concern is that attention uses global information. If cut concatenation were used, different utterances would attend to each other, even with some padding between the utterances in the concatenation; we have never run experiments to verify how large the effect is.
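
To make that concern concrete, here is an illustrative sketch (not icefall code) of the kind of block-diagonal attention mask that would be needed so that frames of different concatenated utterances cannot attend to each other:

    import torch

    # Illustrative only: True = position is masked out. Each concatenated
    # utterance may only attend within its own block; without such a mask,
    # global attention sees across utterance boundaries (and the padding).
    def block_diagonal_mask(segment_lengths):
        total = sum(segment_lengths)
        mask = torch.ones(total, total, dtype=torch.bool)
        offset = 0
        for length in segment_lengths:
            mask[offset:offset + length, offset:offset + length] = False
            offset += length
        return mask

    # Two utterances of 3 and 2 frames concatenated into one 5-frame sequence:
    # frames 0-2 can only see 0-2, frames 3-4 can only see 3-4.
    print(block_diagonal_mask([3, 2]))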

pzelasko commented 2 months ago

@daniel-dona it's possible you need to also add fill_supervisions after merge_supervisions, but I'm not sure. Generally the aim of all this is to end up with each example having one supervision that spans the full cut, as the current training code is not ready for supervisions pointing to a subset of a cut. If you're running into these assertion errors, check the CutSet that was used to create the mini-batch before and after transforms to ensure that every cut has a single supervision that covers the full duration.
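
A sketch of that check (method names from lhotse; chaining fill_supervisions after merge_supervisions here follows the suggestion above and is not a verified recipe):

    # Sketch: normalize each cut to exactly one supervision spanning its
    # full duration, then assert that property before batching.
    def normalize_and_check(cuts):
        cuts = cuts.merge_supervisions()  # one supervision per cut
        cuts = cuts.fill_supervisions()   # stretch it to cover the whole cut
        for cut in cuts:
            assert len(cut.supervisions) == 1, cut.id
            sup = cut.supervisions[0]
            assert sup.start == 0 and abs(sup.duration - cut.duration) < 1e-2, cut.id
        return cuts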

As a side note, it may be interesting that flash attention has "varlen" kernels for "packed sequences", so technically it may be possible to run training with batch_size=1 and everything concatenated into a single sequence with appropriate masking. https://github.com/Dao-AILab/flash-attention/blob/320fb59487658f033f56711efd3d61b7c7a6f8f3/flash_attn/flash_attn_interface.py#L845-L870
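
For example, a rough sketch of the packed-sequence idea with the varlen interface (assumes a CUDA GPU, fp16 tensors, and the flash-attn package; the utterance lengths here are made up):

    import torch
    from flash_attn import flash_attn_varlen_func

    # Three utterances of 100, 70 and 30 frames packed into one sequence of
    # 200 frames; cu_seqlens marks the boundaries so attention never crosses
    # an utterance boundary, with no padding needed.
    nheads, headdim = 8, 64
    lengths = [100, 70, 30]
    total = sum(lengths)

    q = torch.randn(total, nheads, headdim, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    cu_seqlens = torch.tensor([0, 100, 170, 200], device="cuda", dtype=torch.int32)

    out = flash_attn_varlen_func(
        q, k, v,
        cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max(lengths), max_seqlen_k=max(lengths),
    )
    print(out.shape)  # torch.Size([200, 8, 64])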