awni / speech

A PyTorch Implementation of End-to-End Models for Speech-to-Text
Apache License 2.0

RNN Transducer training problem #27

Closed · HawkAaron closed this issue 6 years ago

HawkAaron commented 6 years ago

Hi, it seems that your implementation of the RNN Transducer loss function is correct. But when I train on TIMIT following Graves (2012), the loss decreases while the PER increases, no matter how I adjust the learning rate. (With a small learning rate, the PER first decreases, then keeps increasing.)

In your training procedure the RNNT loss does decrease, but if you output the PER, it increases! So what's wrong?
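For reference, the PER I report is the edit distance between the decoded and reference phone sequences, divided by the reference length. A minimal sketch (the `per` helper is hypothetical):

    def per(ref, hyp):
        # phone error rate: Levenshtein distance / reference length
        d = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
        for i in range(1, len(ref) + 1):
            d[i][0] = i
        for j in range(1, len(hyp) + 1):
            d[0][j] = j
        for i in range(1, len(ref) + 1):
            for j in range(1, len(hyp) + 1):
                sub = d[i - 1][j - 1] + (ref[i - 1] != hyp[j - 1])
                d[i][j] = min(d[i - 1][j] + 1, d[i][j - 1] + 1, sub)
        return float(d[len(ref)][len(hyp)]) / len(ref)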

HawkAaron commented 6 years ago

Hi, recently I wrote a greedy decoder for the transducer:

    def greedy_decode(self, batch):
        x, y, x_lens, y_lens = self.collate(*batch)
        x = self.encode(x)[0]
        # label placeholder, reused as input to the embedding layer
        vy = autograd.Variable(torch.LongTensor([0]), volatile=True).view(1, 1)
        # prime the prediction network with a zero vector (start symbol)
        y, h = self.dec_rnn(autograd.Variable(torch.zeros((1, 1, 256)), volatile=True))
        y_seq = []
        logp = 0
        for i in x:
            # joint network: combine encoder and prediction network outputs
            out = self.fc1(i) + self.fc1(y[0][0])
            out = nn.functional.relu(out)
            out = self.fc2(out)
            out = F.log_softmax(out, dim=0)
            p, pred = torch.max(out, dim=0)
            pred = int(pred)
            logp += float(p)
            if pred != self.blank:
                y_seq.append(pred)
                vy.data[0][0] = pred  # feed the emitted label back
                y = self.embedding(vy)
                y, h = self.dec_rnn(y, h)
        return [y_seq]

This code can be placed in your transducer_model.py file. It only supports one sequence per batch.
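A hypothetical usage sketch (assuming one utterance per batch, `batch[1]` holding the reference label sequences, and the `editdistance` package for scoring):

    import editdistance

    model.eval()
    total_dist, total_len = 0, 0
    for batch in dev_batches:  # hypothetical iterable, one utterance per batch
        hyp = model.greedy_decode(batch)[0]
        ref = list(batch[1][0])  # assumed reference label sequence
        total_dist += editdistance.eval(ref, hyp)
        total_len += len(ref)
    print("PER: %.2f%%" % (100.0 * total_dist / total_len))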

After decoding, the PER with your trained model is 24.4% (it may not have fully converged).

It seems that the training PER calculated in the CTC way doesn't make sense.

Beam search would be much better.

smolendawid commented 6 years ago

> It seems that the training PER calculated in the CTC way doesn't make sense.

@HawkAaron what do you mean?

HawkAaron commented 6 years ago

@smolendawid What you said is exactly what I mean. Since the CTC transition topology is almost the same as an HMM's, greedy search can be treated as Viterbi decoding. But the RNN Transducer's transitions are two-dimensional (over time and over labels), so greedily decoding along only one of them doesn't make sense.
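By "the CTC way" I mean taking the per-frame argmax of the transducer output, then collapsing repeats and dropping blanks, roughly (hypothetical helper):

    def ctc_greedy_collapse(frame_argmax, blank=0):
        # CTC-style post-processing: merge repeated labels, drop blanks.
        # Applied to transducer outputs this ignores the prediction network state.
        seq, prev = [], None
        for k in frame_argmax:
            if k != prev and k != blank:
                seq.append(k)
            prev = k
        return seq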

Duum commented 6 years ago

@HawkAaron your greedy decode implementation is wrong: it will not find the best path. Here is my implementation; you may need some changes for your decoder.
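The idea is that after emitting a non-blank label, the decoder should re-score the same frame with the updated prediction network state, and only move to the next frame on a blank. A minimal sketch (the name `greedy_decode_2d` and the `max_emit` cap are hypothetical; it reuses the members from the snippet above):

    def greedy_decode_2d(self, batch, max_emit=3):
        # greedy decoding over the (time, label) lattice: stay on the
        # same frame after a non-blank emission
        x, y, x_lens, y_lens = self.collate(*batch)
        x = self.encode(x)[0]
        vy = autograd.Variable(torch.LongTensor([[0]]), volatile=True)
        y, h = self.dec_rnn(autograd.Variable(torch.zeros((1, 1, 256)), volatile=True))
        y_seq = []
        for i in x:
            for _ in range(max_emit):
                out = self.fc2(nn.functional.relu(self.fc1(i) + self.fc1(y[0][0])))
                out = F.log_softmax(out, dim=0)
                p, pred = torch.max(out, dim=0)
                if int(pred) == self.blank:
                    break  # blank: advance to the next frame
                y_seq.append(int(pred))
                vy.data[0][0] = int(pred)  # feed the emitted label back
                y, h = self.dec_rnn(self.embedding(vy), h)
        return [y_seq]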

HawkAaron commented 6 years ago

@Duum My greedy decoder is based on the assumption that one acoustic feature frame has at most one corresponding label.

To explore more candidate paths, you can refer to my beam search implementation.
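Not the linked code, but a minimal sketch of such a search, loosely following Graves (2012), Algorithm 1, without prefix merging (`beam_size` and `max_expand` are hypothetical knobs; same model members as above):

    def beam_search_decode(self, batch, beam_size=4, max_expand=3):
        x, _, _, _ = self.collate(*batch)
        x = self.encode(x)[0]
        y0, h0 = self.dec_rnn(autograd.Variable(torch.zeros((1, 1, 256)), volatile=True))
        # hypothesis: (log prob, label sequence, pred-net output, hidden state)
        B = [(0.0, [], y0, h0)]
        for i in x:
            A, B = B, []
            # allow a few label emissions per frame before moving on
            for _ in range(max_expand):
                new_A = []
                for logp, seq, y, h in A:
                    out = self.fc2(nn.functional.relu(self.fc1(i) + self.fc1(y[0][0])))
                    out = F.log_softmax(out, dim=0)
                    # taking blank finishes this frame for the hypothesis
                    B.append((logp + float(out[self.blank]), seq, y, h))
                    # a non-blank extension stays on the same frame
                    probs, labels = torch.topk(out, beam_size + 1)
                    for p, k in zip(probs, labels):
                        k = int(k)
                        if k == self.blank:
                            continue
                        vy = autograd.Variable(torch.LongTensor([[k]]), volatile=True)
                        ny, nh = self.dec_rnn(self.embedding(vy), h)
                        new_A.append((logp + float(p), seq + [k], ny, nh))
                A = sorted(new_A, key=lambda c: c[0], reverse=True)[:beam_size]
            B = sorted(B, key=lambda c: c[0], reverse=True)[:beam_size]
        return [B[0][1]]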