k2-fsa / k2

FSA/FST algorithms, differentiable, with PyTorch compatibility.
https://k2-fsa.github.io/k2
Apache License 2.0

Decode a Kaldi model with k2 #1207

Closed · francisr closed this issue 1 year ago

francisr commented 1 year ago

Hi, I'm trying to decode a Kaldi trained model with the k2 online decoder.

Here's what I'm doing:

make-pdf-to-tid-transducer final.mdl > pdf2tid.fst
fstcompose pdf2tid.fst HCLG.fst > PdfLG.fst
fstprint PdfLG.fst > PdfLG.fst.txt  # text form that k2 can parse
import k2
import kaldi_io
import torch

# Convert the text FST to a k2 fsa and remove epsilons
with open("PdfLG.fst.txt") as f:
    fsa = k2.Fsa.from_openfst(f.read(), acceptor=False)
fsa = k2.remove_epsilon(fsa)
fsa = k2.connect(fsa)
fsa.aux_labels = fsa.aux_labels.remove_values_eq(0)
fsa = k2.arc_sort(fsa)

# Decrement pdf-ids because they're 1-indexed in pdf2tid.fst
idx = fsa.labels != -1
fsa.labels[idx] = fsa.labels[idx] - 1

# Build the decoding graph vector and create the online intersecter
decode_fsas = k2.Fsa.from_fsas([fsa])
# args after the graph: num_streams=1, search_beam=11, output_beam=5,
# min_active_states=200, max_active_states=6000
online_decoder = k2.OnlineDenseIntersecter(decode_fsas, 1, 11, 5, 200, 6000)
symtab = k2.SymbolTable.from_file("words.txt")
for key, mat in kaldi_io.read_mat_ark(loglikes_path):
    loglikes = - torch.from_numpy(mat).clone()
    loglikes = loglikes.unsqueeze(0)  
    print(loglikes.shape)  # (1, T, num_pdfs)
    break
supervision_segments = torch.tensor([[0, 0, loglikes.shape[1]]], dtype=torch.int32)
fsa_vec = k2.DenseFsaVec(loglikes, supervision_segments)
lat, decoder_state = online_decoder.decode(fsa_vec, [None])  # [None]: no previous decoder state; decoder_state is reused for continuous inference
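# For continuous inference, later chunks would pass the returned state back in,
# e.g. (a sketch; next_loglikes / next_segments are hypothetical names for the next chunk):
#   next_fsa_vec = k2.DenseFsaVec(next_loglikes, next_segments)
#   lat, decoder_state = online_decoder.decode(next_fsa_vec, decoder_state)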

best_path = k2.shortest_path(lat, use_double_scores=True)
for word_id in best_path.aux_labels.values:
    print(symtab.get(word_id.item()))

However, the output is nonsense words, and I'm not sure what I'm doing wrong.

danpovey commented 1 year ago

Maybe the sign is wrong? IIRC k2 uses scores, not costs.
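
A quick way to check the sign (a minimal sketch, assuming mat is the Kaldi loglikes matrix read with kaldi_io in the snippet above): loglikes from a Kaldi model are typically negative on average, so if the values come out clearly positive after your reading code, the sign has probably been flipped.

import kaldi_io
import torch

# Peek at one matrix from the ark and inspect its range.
for key, mat in kaldi_io.read_mat_ark(loglikes_path):
    t = torch.from_numpy(mat)
    print(key, "min:", t.min().item(), "max:", t.max().item(), "mean:", t.mean().item())
    # Mostly negative values look like log-likelihoods (scores);
    # mostly positive values suggest the matrix was negated somewhere.
    break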

francisr commented 1 year ago

Hmm, quite silly: in this code I was negating the Kaldi matrix when reading it into the loglikes tensor, because I knew k2 uses scores rather than costs. But the Kaldi loglikes are already scores, so it works when I read them in as-is.
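
For reference, the fix is simply dropping the negation when reading the matrix (same variables as in the snippet above):

loglikes = torch.from_numpy(mat).clone()  # keep the Kaldi loglikes as-is: they are already scores
loglikes = loglikes.unsqueeze(0)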

Thanks.