1ytic / warp-rnnt

CUDA-Warp RNN-Transducer
MIT License

question about the rnnt loss arguments #18

Open songtaoshi opened 3 years ago

songtaoshi commented 3 years ago
        log_probs (torch.FloatTensor): Input tensor with shape (N, T, U, V)
            where N is the minibatch size, T is the maximum number of
            input frames, U is the maximum number of output labels and V is
            the vocabulary of labels (including the blank).
        labels (torch.IntTensor): Tensor with shape (N, U-1) representing the
            reference labels for all samples in the minibatch.

Hi, I am confused about the labels. Why should the shape be U-1?
Shouldn't <eos> be included in the labels? @1ytic
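The shape convention in the docstring can be sketched in plain Python (hypothetical numbers, not taken from the library; warp-rnnt itself is not imported here):

```python
# Hypothetical minibatch illustrating the docstring shapes above.
N, T, V = 2, 50, 30             # batch size, max input frames, vocab incl. blank
labels = [[4, 11, 7],           # sample 0: three reference labels
          [9, 2, 5]]            # sample 1: padded to the same length
U = len(labels[0]) + 1          # U counts the empty output prefix as well
log_probs_shape = (N, T, U, V)  # what the joint network must produce
labels_shape = (N, U - 1)       # the labels themselves: no blank, no <eos>
```

Note that U-1 is just the number of actual reference labels; the extra slot in U belongs to the log_probs tensor, not to the labels.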

songtaoshi commented 3 years ago

Also, in the training code I see that the LM input ys is the same as the target ys.
Shouldn't it be <sos>+text as input and text+<eos> as output?

1ytic commented 3 years ago

If I remember correctly, U includes the "empty" output. It is very similar to the first element in the scoring matrix when you align two sequences, for example here: https://en.wikipedia.org/wiki/Smith–Waterman_algorithm
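That idea can be made concrete with a toy forward algorithm (a pure-Python, unbatched sketch with names of my own choosing, not warp-rnnt's CUDA implementation): the forward variable lives on a (T, U) grid where row u = 0 is the "empty output" state, just like the first row/column of an alignment scoring matrix.

```python
import math

def rnnt_forward_logprob(log_probs, labels, blank=0):
    """Toy RNN-T forward pass for one utterance.

    log_probs: nested list of shape (T, U, V) with log-probabilities,
    where U = len(labels) + 1: cell u = 0 is the "empty output" state.
    Returns the total log-probability of emitting `labels`.
    """
    T = len(log_probs)
    U = len(log_probs[0])
    assert U == len(labels) + 1
    NEG_INF = float("-inf")

    def logsumexp(a, b):
        if a == NEG_INF: return b
        if b == NEG_INF: return a
        m = max(a, b)
        return m + math.log(math.exp(a - m) + math.exp(b - m))

    # alpha[t][u]: log-prob of having emitted the first u labels by frame t
    alpha = [[NEG_INF] * U for _ in range(T)]
    alpha[0][0] = 0.0  # start in the empty-output state
    for t in range(T):
        for u in range(U):
            if t == 0 and u == 0:
                continue
            stay = NEG_INF   # arrive by emitting blank from (t-1, u)
            if t > 0:
                stay = alpha[t - 1][u] + log_probs[t - 1][u][blank]
            emit = NEG_INF   # arrive by emitting labels[u-1] from (t, u-1)
            if u > 0:
                emit = alpha[t][u - 1] + log_probs[t][u - 1][labels[u - 1]]
            alpha[t][u] = logsumexp(stay, emit)
    # finish at (T-1, U-1) and emit a final blank
    return alpha[T - 1][U - 1] + log_probs[T - 1][U - 1][blank]
```

With uniform log-probs (T = 2, one label, V = 2, every entry log 0.5) there are two monotonic paths of three emissions each, so the total probability is 2 × 0.5³ = 0.25.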

zhaoyang9425 commented 3 years ago
        log_probs (torch.FloatTensor): Input tensor with shape (N, T, U, V)
            where N is the minibatch size, T is the maximum number of
            input frames, U is the maximum number of output labels and V is
            the vocabulary of labels (including the blank).
        labels (torch.IntTensor): Tensor with shape (N, U-1) representing the
            reference labels for all samples in the minibatch.

Hi, I am confused about the labels. Why should the shape be U-1? Shouldn't <eos> be included in the labels? @1ytic

I have the same doubt. Did you figure it out? Why should the shape of labels be U-1?

NiHaoUCAS commented 2 years ago

I guess U = len(<blank>) + len(labels), with len(<blank>) = 1. The <blank> shouldn't be in the labels, but in the encoder logits.
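That guess can be written out as a tiny sketch (hypothetical values; this mirrors the usual RNN-T convention where the prediction network consumes blank-prepended labels, which is where the extra +1 in U comes from):

```python
# Hypothetical single utterance; names are illustrative, not warp-rnnt's API.
blank = 0
labels = [7, 3, 9]               # reference labels, shape (U-1,) -- no blank here
pred_net_input = [blank] + labels  # prediction-network input
U = len(pred_net_input)          # one state per prefix, incl. the empty prefix
```

So the blank contributes a row to the (T, U, V) log_probs grid, but never appears in the labels tensor, which is why labels have shape (N, U-1).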