Hi, recently I wrote a greedy decoding function for the transducer:
```python
import torch
import torch.nn.functional as F
from torch import autograd

def greedy_decode(self, batch):
    x, y, x_lens, y_lens = self.collate(*batch)
    x = self.encode(x)[0]
    # vector reserved as input to the prediction network's embedding
    vy = autograd.Variable(torch.LongTensor([0]), volatile=True).view(1, 1)
    # first decoder step from a zero input
    y, h = self.dec_rnn(autograd.Variable(torch.zeros((1, 1, 256)), volatile=True))
    y_seq = []
    logp = 0
    for i in x:
        out = self.fc1(i) + self.fc1(y[0][0])
        out = F.relu(out)
        out = self.fc2(out)
        out = F.log_softmax(out, dim=0)
        p, pred = torch.max(out, dim=0)
        pred = int(pred)
        logp += float(p)
        if pred != self.blank:
            y_seq.append(pred)
            vy.data[0][0] = pred  # update the prediction network input
            y = self.embedding(vy)
            y, h = self.dec_rnn(y, h)
    return [y_seq]
```
This code can be placed in your transducer_model.py file. It only supports one sequence per batch.
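As a rough usage sketch (`model` and `dev_loader` are placeholder names, not from this repo):

```python
# assumes the method above was added to the model class in transducer_model.py
batch = next(iter(dev_loader))     # a batch holding a single sequence
hyps = model.greedy_decode(batch)  # -> [[label_id, label_id, ...]]
```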
After decoding, the PER is 24.4% with your trained model (which may not have converged).
It seems that the training PER calculated in the CTC way doesn't make sense.
Beam search would be much better.
> It seems that the training PER calculated in the CTC way doesn't make sense.

@HawkAaron what do you mean?
@smolendawid What you said is exactly what I mean. Since the CTC transition topology is almost the same as an HMM's, greedy search can be treated as Viterbi decoding. But the RNN Transducer's transitions are two dimensional, over both time frames and label positions, so greedily decoding along a single dimension doesn't make much sense.
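To make the two-dimensional lattice concrete, here is a minimal sketch of the RNN-T forward recursion (my own illustration, not code from this repo): every node (t, u) has two outgoing transitions, blank (advance t) and label (advance u), so a per-frame greedy path covers only one of exponentially many alignments.

```python
import numpy as np

def rnnt_forward(log_blank, log_label):
    """Log-probability of a label sequence, summed over all RNN-T alignments.

    log_blank: (T, U+1), log_blank[t, u] = log P(blank | t, u) -> moves to (t+1, u)
    log_label: (T, U),   log_label[t, u] = log P(y[u+1] | t, u) -> moves to (t, u+1)
    """
    T, U1 = log_blank.shape  # U1 = U + 1
    alpha = np.full((T, U1), -np.inf)
    alpha[0, 0] = 0.0
    for t in range(T):
        for u in range(U1):
            if t + 1 < T:  # emit blank: advance in time
                alpha[t + 1, u] = np.logaddexp(
                    alpha[t + 1, u], alpha[t, u] + log_blank[t, u])
            if u + 1 < U1:  # emit the next label: advance in the output
                alpha[t, u + 1] = np.logaddexp(
                    alpha[t, u + 1], alpha[t, u] + log_label[t, u])
    # the sequence ends with a final blank at node (T-1, U)
    return alpha[T - 1, U1 - 1] + log_blank[T - 1, U1 - 1]
```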
@HawkAaron Your greedy decode implementation is wrong; it will not find the best path. Here is my implementation, you may need to make some changes for your decoder.
@Duum That is based on the assumption that one acoustic feature frame has at most one corresponding label.
To explore more candidate paths, you can refer to my beam search implementation.
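For a concrete picture, here is a minimal sketch of Graves (2012)-style RNN-T beam search. It is my own simplified illustration, not the implementation referenced above: it omits prefix merging and Graves's exact stopping criterion, and `pred_step` / `joint_log_probs` are hypothetical callables standing in for the prediction and joint networks.

```python
from dataclasses import dataclass

@dataclass
class Hyp:
    prefix: tuple    # emitted label ids so far
    logp: float      # cumulative log-probability
    dec_out: object  # prediction-network output for this prefix
    state: object    # prediction-network hidden state

def beam_search(enc_out, pred_step, joint_log_probs, blank, beam_size=4):
    # pred_step(label, state) -> (dec_out, new_state): one step of the
    # prediction network (hypothetical signature; label=None means start).
    # joint_log_probs(frame, dec_out) -> log-probs over labels incl. blank
    # (hypothetical signature).
    dec_out, state = pred_step(None, None)
    B = [Hyp(prefix=(), logp=0.0, dec_out=dec_out, state=state)]
    for frame in enc_out:
        A, B = B, []
        while A and len(B) < beam_size:
            best = max(A, key=lambda h: h.logp)
            A.remove(best)
            log_probs = joint_log_probs(frame, best.dec_out)
            for k, lp in enumerate(log_probs):
                if k == blank:
                    # blank: this hypothesis is finished with the current frame
                    B.append(Hyp(best.prefix, best.logp + lp,
                                 best.dec_out, best.state))
                else:
                    # label: extend the prefix and step the prediction network,
                    # staying on the same frame
                    d, s = pred_step(k, best.state)
                    A.append(Hyp(best.prefix + (k,), best.logp + lp, d, s))
    # length-normalized best hypothesis, as in Graves (2012)
    return max(B, key=lambda h: h.logp / (len(h.prefix) + 1)).prefix
```

The key design point is that a blank moves a hypothesis on to the next frame while a label extends the prefix on the same frame, so several labels can be emitted per frame, unlike the greedy decoder above.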
Hi, it seems that your implementation of the RNN Transducer loss function is correct. But when I train the Graves (2012) TIMIT setup, the loss decreases while the PER increases, no matter how I adjust the learning rate. (With a small lr, the PER first decreases, then keeps increasing.)
In your training procedure the RNNT loss clearly decreases, but if you output the PER, it increases! So what is wrong?