k2-fsa / icefall

https://k2-fsa.github.io/icefall/
Apache License 2.0

`fast_beam_search()` returns a non-differentiable lattice + MWER training #1168

Status: Open. desh2608 opened this issue 1 year ago.

desh2608 commented 1 year ago

In the fast_beam_search() method, the lattice is eventually generated at: https://github.com/k2-fsa/icefall/blob/ffe816e2a8314318a4ef6d5eaba34b62b842ba3f/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py#L546. We do not explicitly pass logprobs here, since these are stored on the arcs of the FSA during decoding.

If we look at the definition of this method in k2 here, we can see that the arc scores are made differentiable only if logprobs is explicitly passed to this function.

As a result, the generated lattice has non-differentiable arc scores.
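To illustrate why passing logprobs matters, here is a minimal plain-PyTorch sketch. It does not call k2, and the arc bookkeeping is a hypothetical simplification: each arc is assumed to record the flat index of the log-probability it consumed. Copying score values onto the arcs during decoding yields a tensor with no autograd history. Re-gathering the scores from logprobs with those indices keeps the graph, so a loss computed on the lattice can backpropagate into the model.

```python
import torch

# Hypothetical joiner output: 2 frames x 3 tokens of log-probabilities.
logits = torch.randn(2, 3, requires_grad=True)
logprobs = logits.log_softmax(dim=-1)

# Hypothetical arc bookkeeping: flat index into logprobs for each arc.
arc_index = torch.tensor([0, 4, 5])

# (a) Scores copied onto the arcs as plain values: no autograd history.
copied_scores = logprobs.detach().reshape(-1)[arc_index].clone()

# (b) Scores re-gathered from logprobs: differentiable w.r.t. logits.
gathered_scores = logprobs.reshape(-1)[arc_index]

print(copied_scores.requires_grad)
print(gathered_scores.requires_grad)

# A toy "lattice loss": gradients reach the logits only through (b).
gathered_scores.sum().backward()
print(logits.grad)
```

This is the same trade-off discussed in this thread. Keeping the graph costs memory, so it is only worth doing when the lattice feeds a training loss such as MWER.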

pkufool commented 1 year ago

The differentiable lattice is for the MWER loss (https://github.com/k2-fsa/k2/blob/42e92fdd4097adcfe9937b4d2df7736d227b8e85/k2/python/k2/mwer_loss.py). For decoding, I think there is no need to return a differentiable lattice.

See also this PR https://github.com/k2-fsa/k2/pull/1094

desh2608 commented 1 year ago

Thanks, I'm indeed trying to use it for MWER loss.

danpovey commented 1 year ago

@pkufool I am not 100% following here, but is it possible to add an option to fast_beam_search() to make it differentiable?

desh2608 commented 1 year ago

@danpovey It makes the lattice scores differentiable if you explicitly pass in the logprobs, which can be done similar to this code from @glynpu.

danpovey commented 1 year ago

OK, but I am wondering what is the downside of making fast_beam_search() just do this automatically if the input lattice has scores.requires_grad = True.

pkufool commented 1 year ago

> OK, but I am wondering what is the downside of making fast_beam_search() just do this automatically if the input lattice has scores.requires_grad = True.

Will try to see if we can simplify the usage.

desh2608 commented 1 year ago

I am trying to fine-tune an RNN-T model with MWER loss. I created an LG decoding graph and obtained lattices using fast_beam_search():

```python
lattice = fast_beam_search(
    model=self,
    decoding_graph=decoding_graph,
    encoder_out=encoder_out,
    encoder_out_lens=x_lens,
    beam=4,
    max_states=64,
    max_contexts=8,
    temperature=1.0,
    ilme_scale=0.0,
    allow_partial=True,
    blank_penalty=0.1,
    requires_grad=True,  # local addition, not an upstream icefall argument
)
```

(Note that I added a requires_grad argument to fast_beam_search which just ensures that all arc scores are tracked.)

I obtained ref_texts from the lexicon's symbol table as:

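```python
# word_table: the word-level k2.SymbolTable of the lexicon
oov_id = word_table["<unk>"]
y = []
for text in texts:
    word_ids = []
    for word in text.split():
        # Map out-of-vocabulary words to <unk>.
        if word in word_table:
            word_ids.append(word_table[word])
        else:
            word_ids.append(oov_id)
    y.append(word_ids)
```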

I then call the k2.mwer_loss with the lattice and ref_texts:

```python
import k2
import torch

# Disable AMP autocasting for the loss computation.
with torch.cuda.amp.autocast(enabled=False):
    mbr_loss = k2.mwer_loss(
        lattice=lattice,
        ref_texts=y,
        nbest_scale=0.5,
        num_paths=200,
        temperature=1.0,
        reduction="sum",
        use_double_scores=True,
    )
```
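Conceptually, MWER minimizes the expected number of word errors under the model's posterior over hypotheses. As we understand it, k2.mwer_loss samples num_paths paths from the lattice, renormalizes their scores over the resulting n-best list, and weights each hypothesis's word errors by that posterior. The pure-Python sketch below shows only the expected-risk step on a hypothetical, already-extracted n-best list. It has no autograd and skips the sampling. Dividing the scores by temperature is an assumption about how that argument is applied, so check the k2 source for the exact convention.

```python
import math
from typing import List, Sequence


def edit_distance(ref: Sequence[int], hyp: Sequence[int]) -> int:
    """Word-level Levenshtein distance (substitutions, insertions, deletions)."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        cur = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            cur[j] = min(
                prev[j] + 1,             # deletion of a reference word
                cur[j - 1] + 1,          # insertion of a hypothesis word
                prev[j - 1] + (r != h),  # substitution or match
            )
        prev = cur
    return prev[-1]


def expected_word_errors(
    ref: Sequence[int],
    hyps: List[Sequence[int]],
    path_scores: List[float],
    temperature: float = 1.0,
) -> float:
    """Posterior-weighted word errors over an n-best list (simplified MWER risk)."""
    scaled = [s / temperature for s in path_scores]
    m = max(scaled)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scaled]
    z = sum(weights)
    return sum(w / z * edit_distance(ref, h) for w, h in zip(weights, hyps))


# Hypothetical word-ID reference, n-best hypotheses and their total path scores.
ref = [1, 2, 3]
hyps = [[1, 2, 3], [1, 3], [4, 2, 3]]
scores = [-1.0, -2.0, -3.0]
print(expected_word_errors(ref, hyps, scores))
```

In the real loss, the path scores carry gradients. Minimizing this quantity shifts posterior mass toward hypotheses with fewer errors, which is why the lattice scores must be differentiable.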

However, the MWER loss seems to be increasing during training:

[Image: MWER training-loss curve (not preserved)]

Does the above strategy look okay or am I missing something?

danpovey commented 1 year ago

A loss that increases and then stays roughly constant is what you might expect if parameter noise were a problem (due to a too-high learning rate). Perhaps the learning rate is higher than the final learning rate of the base system? Also, the MWER loss is much noisier than the regular loss, so it might require a lower learning rate.

desh2608 commented 1 year ago

I realized that I was "continuing training" from the last checkpoint instead of initializing a new model's parameters from the pre-trained checkpoint, so the optimizer and scheduler states were being carried over. I fixed this and reduced the LR to 0.0004. I also set nbest_scale=0.1 and temperature=0.5 in mwer_loss to get more unique paths from the lattice (although I'm not sure that is useful).
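A minimal PyTorch sketch of that fix: load only the model weights from the pre-trained checkpoint and build a fresh optimizer, so no optimizer or scheduler state carries over. It assumes the checkpoint is a dict that stores the weights under a "model" key. icefall's checkpoints are saved this way as far as we know, but verify this for your own files. The tiny nn.Linear and torch.optim.Adam are stand-ins for the transducer and the recipe's actual optimizer.

```python
import os
import tempfile

import torch
from torch import nn


def init_from_pretrained(model: nn.Module, ckpt_path: str) -> nn.Module:
    """Copy pre-trained weights into `model`, ignoring optimizer/scheduler state."""
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    # Assumption: the weights are stored under the "model" key.
    model.load_state_dict(checkpoint["model"])
    return model


# Stand-in for the pre-trained RNN-T checkpoint (hypothetical contents).
pretrained = nn.Linear(4, 2)
ckpt_path = os.path.join(tempfile.mkdtemp(), "pretrained.pt")
torch.save({"model": pretrained.state_dict(), "optimizer": {}}, ckpt_path)

# Fine-tuning starts from fresh optimizer state with the lower LR.
model = init_from_pretrained(nn.Linear(4, 2), ckpt_path)
optimizer = torch.optim.Adam(model.parameters(), lr=4e-4)
```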

Now the training loss is going down, at least so far (I'm only a few hundred training steps in). Training is only about 3x slower than regular RNN-T training, which I suppose is not bad.

[Image: MWER training-loss curve after the fix (not preserved)]
desh2608 commented 1 year ago

I have been working on MWER training for transducers. The training loss improves, but the validation loss gets worse, and I find that the resulting WER is also much worse than the model I initialized with. Here are the train/val curves: https://tensorboard.dev/experiment/L8QIjliVSQm6kRCv1Cghmw/#scalars

Here are the decoding results on TED-LIUM dev using greedy search. The first model was trained with pruned_rnnt loss, and then I used this model to initialize the MWER training (second row).

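| Training | Ins (%) | Del (%) | Sub (%) | WER (%) |
|----------|---------|---------|---------|---------|
| rnnt     | 1.17    | 2.11    | 4.43    | 7.71    |
| mwer     | 0.69    | 10.42   | 7.21    | 18.33   |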
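The error columns follow the usual definition. Here $S$, $D$ and $I$ are the substitution, deletion and insertion counts, and $N$ is the number of reference words. The components sum to the reported WER up to rounding:

```latex
\begin{aligned}
\mathrm{WER} &= \frac{S + D + I}{N} \times 100\% \\
\text{rnnt:}\quad 1.17 + 2.11 + 4.43 &= 7.71 \\
\text{mwer:}\quad 0.69 + 10.42 + 7.21 &= 18.32 \approx 18.33
\end{aligned}
```

Most of the increase therefore comes from deletions (2.11 to 10.42).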

I used a base LR of 0.0004 for MWER training. For lattice generation I used beam=4, max_states=64, max_contexts=8 and blank_penalty=0.1. For the MWER loss computation I used num_paths=200, nbest_scale=0.1 and temperature=0.5. Which of these values should I change to avoid overfitting?
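One way to reason about nbest_scale and num_paths: k2's mwer_loss documentation describes nbest_scale as a scale applied to the lattice scores before k2.random_paths samples paths. Smaller values give more unique paths, at the risk of missing the best-scoring one. The toy sketch below replaces the lattice with a handful of hypothetical complete paths. It assumes path sampling probabilities proportional to exp(nbest_scale × score) and computes the expected number of distinct paths among num_paths independent draws, which is the sum over paths of 1 - (1 - p)^num_paths.

```python
import math
from typing import List


def path_probs(path_scores: List[float], nbest_scale: float) -> List[float]:
    """Sampling distribution over paths after scaling their scores."""
    scaled = [nbest_scale * s for s in path_scores]
    m = max(scaled)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scaled]
    z = sum(weights)
    return [w / z for w in weights]


def expected_unique_paths(probs: List[float], num_paths: int) -> float:
    """Expected number of distinct paths among `num_paths` i.i.d. draws."""
    return sum(1.0 - (1.0 - p) ** num_paths for p in probs)


# Hypothetical total log-scores of five competing paths.
scores = [-10.0, -15.0, -20.0, -30.0, -40.0]
for scale in (1.0, 0.5, 0.1):
    probs = path_probs(scores, scale)
    print(scale, round(expected_unique_paths(probs, num_paths=200), 2))
```

With a sharp distribution, most of the 200 draws repeat the same one or two paths. Lowering nbest_scale spreads the draws over more competitors, which gives the loss more contrastive hypotheses but also more low-scoring ones.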

(BTW, most sentences seem to have some deletion/substitution errors at the beginning.)
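To check whether the extra deletions really cluster at the start of utterances, one can align each hypothesis to its reference and record where the deleted reference words fall. Below is a minimal pure-Python sketch of a standard Levenshtein alignment with a backtrace. The example sentences are made up.

```python
from typing import List, Sequence, Tuple


def align(ref: Sequence[str], hyp: Sequence[str]) -> List[Tuple[str, int]]:
    """Return (op, ref_position) pairs; op is 'C', 'S', 'D' or 'I'.

    For insertions, ref_position is the index of the next reference word.
    """
    n, m = len(ref), len(hyp)
    # dp[i][j] = edit distance between ref[:i] and hyp[:j]
    dp = [[0] * (m + 1) for _ in range(n + 1)]
    for i in range(n + 1):
        dp[i][0] = i
    for j in range(m + 1):
        dp[0][j] = j
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            dp[i][j] = min(
                dp[i - 1][j] + 1,  # deletion
                dp[i][j - 1] + 1,  # insertion
                dp[i - 1][j - 1] + (ref[i - 1] != hyp[j - 1]),  # sub/match
            )
    # Backtrace from the bottom-right corner, preferring diagonal moves.
    ops: List[Tuple[str, int]] = []
    i, j = n, m
    while i > 0 or j > 0:
        if i > 0 and j > 0 and dp[i][j] == dp[i - 1][j - 1] + (ref[i - 1] != hyp[j - 1]):
            ops.append(("C" if ref[i - 1] == hyp[j - 1] else "S", i - 1))
            i, j = i - 1, j - 1
        elif i > 0 and dp[i][j] == dp[i - 1][j] + 1:
            ops.append(("D", i - 1))
            i -= 1
        else:
            ops.append(("I", i))
            j -= 1
    return ops[::-1]


ref = "the cat sat on the mat".split()
hyp = "sat on the mat".split()
ops = align(ref, hyp)
deleted_at = [pos for op, pos in ops if op == "D"]
print(deleted_at)  # [0, 1]
```

Aggregating deleted_at over a dev set, for example as a histogram of position divided by reference length, shows whether the deletions are concentrated at the beginning.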