baidu-research / warp-ctc

Fast parallel CTC.
Apache License 2.0

did not converge when changing from tf.nn.ctc_loss to warpctc #153

Closed ArtemisZGL closed 4 years ago

ArtemisZGL commented 4 years ago

I am using OpenSeq2Seq for an ASR project, and I changed tf.nn.ctc_loss to warpctc in the code, setting the blank_label parameter to the size of the vocabulary. After that, it trains without errors, but the training loss does not converge and the predictions are very strange.

[1,0]: Sample target: here is also seen an american indian a cowboy a merchant and an artisan an american flag is borne aloft while four west point cadets suggest training and leadership women relief workers of all kinds are seen
[1,0]: Sample prediction: e e e e e e te e e e e te e e e te e e e e e

When I use tf.nn.ctc_loss, the prediction looks like this instead (not so many blanks and the letter e):

[1,0]:*** Sample prediction: xklx wlwpzi rfrau wpd'pfo wyp lprpozefaoziwrwripztsowpgvtvhvlbmzlxeqvoratrzpwo'pt peplywzlbw'cpa'x

Here is the code I changed:

    # warp-ctc TensorFlow binding
    from warpctc_tensorflow import ctc

    # total_loss = tf.nn.ctc_loss(
    #     labels=dense_to_sparse(tgt_sequence, tgt_length),
    #     inputs=logits,
    #     sequence_length=src_length,
    #     ignore_longer_outputs_than_inputs=True,
    # )

    # warp-ctc takes the labels of all batch elements concatenated into
    # one flat int32 vector, plus each element's true label length.
    activations = logits
    flat_labels = tf.reshape(tgt_sequence, [-1])
    label_lengths = tgt_length

    total_loss = ctc(activations=activations,
                     flat_labels=flat_labels,
                     label_lengths=label_lengths,
                     input_lengths=src_length,
                     blank_label=28)

So, how can I fix this problem?

ArtemisZGL commented 4 years ago

I found the cause: the labels were wrong, because the padding in tgt_sequence was not cut off before flattening.
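
For anyone hitting the same issue: warp-ctc's flat_labels must contain only the real labels of each batch element concatenated together, with no padding tokens, so a plain tf.reshape over a padded tgt_sequence mixes the padding into the labels. Below is a minimal sketch of one way to strip the padding before flattening; it assumes tgt_sequence is a [batch, max_label_len] int32 tensor and tgt_length holds the true label lengths, as in the snippet above.

    # Build a boolean mask that is True only at real (non-padding) label
    # positions, then keep those positions as a flat 1-D vector.
    mask = tf.sequence_mask(tgt_length, maxlen=tf.shape(tgt_sequence)[1])
    flat_labels = tf.boolean_mask(tgt_sequence, mask)  # padding removed

    total_loss = ctc(activations=logits,
                     flat_labels=flat_labels,
                     label_lengths=tgt_length,
                     input_lengths=src_length,
                     blank_label=28)

tf.boolean_mask keeps elements in row-major order, so the result is each example's labels laid out back to back, which is exactly the layout warp-ctc expects for flat_labels.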