k2-fsa / fast_rnnt

A torch implementation of a recursion which turns out to be useful for RNN-T.

Why T>=S constraint? #20

Closed BuaaAlban closed 1 year ago

BuaaAlban commented 1 year ago

(link to the assert T >= S check in get_rnnt_logprobs)

Why do we need this constraint? In a regular RNN-T the joiner normally emits many blank symbols, in which case T > S. But it is also possible that S > T, e.g., when we emit at least one non-blank symbol for each encoder frame (and more than one for some frames).

Actually I have run into this:

    File "/rnnt_related/rnnt-mlperf-training/model_rnnt.py", line 203, in fast_joint
      simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple(
    File "/anaconda3/envs/fast-rnnt/lib/python3.8/site-packages/fast_rnnt-1.2-py3.8-linux-x86_64.egg/fast_rnnt/rnnt_loss.py", line 282, in rnnt_loss_simple
      px, py = get_rnnt_logprobs(
    File "/anaconda3/envs/fast-rnnt/lib/python3.8/site-packages/fast_rnnt-1.2-py3.8-linux-x86_64.egg/fast_rnnt/rnnt_loss.py", line 149, in get_rnnt_logprobs
      assert T >= S, (T, S)
    AssertionError: (272, 274)
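
For reference, here is a minimal standalone sketch that reproduces the assertion, with shapes chosen so that S > T. The call follows the usage shown in the fast_rnnt README rather than my actual model code, so treat the exact arguments as assumptions:

    import torch
    import fast_rnnt

    # Shapes chosen so that S > T, similar to the (272, 274) case above.
    B, T, S, C = 1, 272, 274, 500

    am = torch.randn(B, T, C)        # encoder ("acoustic model") output
    lm = torch.randn(B, S + 1, C)    # prediction-network ("language model") output
    symbols = torch.randint(1, C, (B, S))

    boundary = torch.zeros(B, 4, dtype=torch.int64)
    boundary[:, 2] = S  # number of symbols per utterance
    boundary[:, 3] = T  # number of frames per utterance

    # Raises AssertionError: (272, 274) from the `assert T >= S` in get_rnnt_logprobs.
    simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple(
        lm=lm,
        am=am,
        symbols=symbols,
        termination_symbol=0,
        boundary=boundary,
        return_grad=True,
    )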

csukuangfj commented 1 year ago

In a regular rnnt

As you have mentioned, that is for regular RNN-T.


The version we are using is not regular. It has the same condition as CTC training, i.e., S <= T.

csukuangfj commented 1 year ago

Here is the paper about fast_rnnt:

https://arxiv.org/pdf/2206.13236.pdf

csukuangfj commented 1 year ago

Here is the code in icefall that filters out data not satisfying S <= T: https://github.com/k2-fsa/icefall/blob/f13cf61b05432a989e6a42c95b843a56639bcbde/egs/librispeech/ASR/pruned_transducer_stateless2/train.py#L958

        # In ./conformer.py, the conv module uses the following expression
        # for subsampling
        T = ((c.num_frames - 1) // 2 - 1) // 2
        tokens = sp.encode(c.supervisions[0].text, out_type=str)

        if T < len(tokens):
            logging.warning(
                f"Exclude cut with ID {c.id} from training. "
                f"Number of frames (before subsampling): {c.num_frames}. "
                f"Number of frames (after subsampling): {T}. "
                f"Text: {c.supervisions[0].text}. "
                f"Tokens: {tokens}. "
                f"Number of tokens: {len(tokens)}"
            )
            return False
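
In other words, the check just compares the post-subsampling frame count with the token count. A minimal self-contained version of the same test (hypothetical helper, not part of icefall) looks like this:

    def is_trainable(num_frames: int, num_tokens: int) -> bool:
        # Same subsampling expression as ./conformer.py: roughly a factor of 4.
        T = ((num_frames - 1) // 2 - 1) // 2
        return T >= num_tokens

    # 400 input frames survive subsampling as T = 99, enough for 90 tokens.
    assert is_trainable(num_frames=400, num_tokens=90)
    # A short utterance with a long transcript violates T >= S and is dropped.
    assert not is_trainable(num_frames=1090, num_tokens=274)
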
BuaaAlban commented 1 year ago

Thanks for your fast reply. I have tried to modify my code based on this example; I think mine is a normal transducer. I can filter the data as you said to make it work, but I wonder why we have this limitation (for optimization? I actually read your paper yesterday and didn't notice this condition, so I will double-check). Could I just comment out this assert to make the pruned loss work like the regular rnnt_loss (as in torchaudio or warp-transducer)?

desh2608 commented 1 year ago

@BuaaAlban as you noted, this constraint is indeed not required for the "regular" RNNT topology. Only if you train with the "modified" topology, where you are constrained to emit exactly 1 symbol per time frame, will this constraint be required. We have a PR here (https://github.com/k2-fsa/k2/pull/1149) to remove this constraint from k2. I will also make a similar PR for fast_rnnt.
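
For illustration, here is a rough sketch of how the two topologies are selected in k2 (assuming k2's rnnt_loss_simple and its rnnt_type argument; whether "regular" accepts S > T depends on the PR above being merged):

    import torch
    import k2

    B, T, S, C = 2, 20, 15, 100
    am = torch.randn(B, T, C)
    lm = torch.randn(B, S + 1, C)
    symbols = torch.randint(1, C, (B, S))

    boundary = torch.zeros(B, 4, dtype=torch.int64)
    boundary[:, 2] = S  # number of symbols
    boundary[:, 3] = T  # number of frames

    # "regular" is the standard RNN-T lattice; "modified" emits exactly one
    # symbol (blank or non-blank) per frame, which is what forces S <= T.
    for rnnt_type in ("regular", "modified"):
        loss = k2.rnnt_loss_simple(
            lm=lm,
            am=am,
            symbols=symbols,
            termination_symbol=0,
            boundary=boundary,
            rnnt_type=rnnt_type,
        )
        print(rnnt_type, loss)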

arkadyark commented 1 year ago

@desh2608 are you still planning to make this PR? This would be very useful for my work!

desh2608 commented 1 year ago

@arkadyark sorry I forgot to actually push the changes. BTW, I believe Dan fixed some OOM issues in the pruned transducer loss in k2, and those fixes haven't yet been merged into fast_rnnt. So you may want to make those changes yourself.

arkadyark commented 1 year ago

Thanks! Which changes are you referring to? Looking through recent changes to rnnt_loss.py I don't see anything there.

desh2608 commented 1 year ago

Thanks! Which changes are you referring to? Looking through recent changes to rnnt_loss.py I don't see anything there.

Check https://github.com/k2-fsa/k2/pull/1177 and https://github.com/k2-fsa/k2/pull/1183

danpovey commented 1 year ago

Ah yes. Arkady, it would be great if you could make a PR to fast_rnnt with those changes, I had forgotten about that. If not LMK, I'll ask someone here.

arkadyark commented 1 year ago

I would love to contribute those back, but unfortunately there's a fairly involved open-source contribution process at my organization that would take a while, so it'd probably be best to find someone else to do it.

However, I did test this out locally and re-ran the benchmarking at https://github.com/csukuangfj/transducer-loss-benchmarking. The results look really good: peak memory usage goes from 3820 all the way down to 1182 (!), and from 2647 to 835 when sorting utterances. Step time (on my hardware) went from 343k to 280k us.

Pretty cool! Always gotta be careful with those torch.gathers.

arkadyark commented 1 year ago

Hey @danpovey , just wanted to follow up - is anybody able to make those changes here?

danpovey commented 1 year ago

@pkufool could you please have a look at this?

pkufool commented 1 year ago

@danpovey Yifan has already made PRs here: https://github.com/danpovey/fast_rnnt/pull/26 and https://github.com/danpovey/fast_rnnt/pull/24 ; you can merge them.

pkufool commented 1 year ago

closed by #29