k2-fsa / k2

FSA/FST algorithms, differentiable, with PyTorch compatibility.
https://k2-fsa.github.io/k2
Apache License 2.0

Pseudocode for confidence estimation #175

Open danpovey opened 3 years ago

danpovey commented 3 years ago

Guys, I am creating this issue just as a way to show this pseudocode; it demonstrates parts of where we are going with k2. One feature is that the fsa objects will have fields 'per_arc' which contain arbitrary tensors whose first dimension is the values.Dim() of the arcs. (They are accessed as if they were class members but are really members of a dict.) When we do operations on FSAs, these per_arc quantities are propagated. (This is easy given the arc_map objects.)
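
As an aside, here is a rough sketch (a hypothetical helper, not the real k2 interface) of how such per_arc tensors could be carried through an FSA operation using the arc_map that the operation produces:

import torch

def propagate_per_arc(per_arc: dict, arc_map: torch.Tensor) -> dict:
    # arc_map[i] is the index of the source arc that produced output arc i,
    # or -1 for arcs with no source arc.
    out = {}
    for name, value in per_arc.items():
        propagated = value[arc_map.clamp(min=0)]
        propagated[arc_map < 0] = 0  # neutral value for newly created arcs
        out[name] = propagated
    return out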

The python-level fsa object is going to be a more complicated object than the C++ one. You can think of it as containing the C++ object as just one member (maybe we could make the C++ object a member called arcs or something).

loglikes = eval_nnet(minibatch.data)
dense_fsas = k2fsa.dense_fsas(loglikes)
lattices = k2fsa.pruned_compose(dense_fsas, decoding_graph)
if train:
    oracle = k2fsa.pruned_compose(dense_fsas, minibatch.supervision_graphs)
    # `oracle` inherited per_arc.word_labels from `supervision_graphs`
    oracle.per_arc.frames = dense_fsas.get_idx01(oracle.src_arcs_a())
    oracle = k2fsa.best_path(oracle)
    ref_phone_labels = torch.zeros([dense_fsas.shape.TotSize(1)], dtype=torch.long)
    ref_word_labels = ref_phone_labels.clone()
    ref_phone_labels[oracle.per_arc.frames] = oracle.get_labels()
    # note: word_labels were propagated from `supervision_graphs.per_arc.word_labels`
    ref_word_labels[oracle.per_arc.frames] = oracle.per_arc.word_labels

lattices.per_arc.frames = dense_fsas.get_idx01(lattices.src_arcs_a())
# gives unique identifiers for frames of input,
# different for different utterances.
arc_post = k2fsa.arc_post(lattices)

num_arcs = lattices.arcs.shape[0]
phones_one_hot = torch.zeros(num_arcs, num_phones)
phones_one_hot[range(num_arcs), lattices.get_labels()] = 1.0
# initial features are (acoustic, LM, posterior, LLR vs. best path,
# one-hot phone labels)
lattices.per_arc.feats = torch.cat(
    (torch.stack((lattices.inputs.scores_a(),
                  lattices.inputs.scores_b(),
                  torch.log(arc_post),
                  k2fsa.llr(lattices)), dim=1),
     phones_one_hot), dim=1)
# augment with the features above but averaged on each frame, over all paths.
lattices.per_arc.feats = torch.cat(
    (lattices.per_arc.feats,
     pool_and_redistribute(lattices.per_arc.feats,
                           weights=arc_post,
                           buckets=lattices.per_arc.frames)), dim=1)

if train:
    lattices.per_arc.phone_correct = (lattices.get_labels() ==
                                      ref_phone_labels[lattices.per_arc.frames])
    # note: word_labels were propagated from `decoding_graph.per_arc.word_labels`
    lattices.per_arc.word_correct = (lattices.per_arc.word_labels ==
                                     ref_word_labels[lattices.per_arc.frames])

# Convert the lattices into n-best lists
# such that each arc appears in at least one linear sequence.
nbest = k2fsa.covering_nbest(lattices)

(feats, frames, phone_correct, word_correct) = \
   k2.ragged_to_tensor(nbest.arcs.shape, nbest.per_arc.feats,
                       nbest.per_arc.frames, nbest.per_arc.phone_correct,
                       nbest.per_arc.word_correct)

loglikes = dense_fsas.loglikes[frames]

(word_confidence, phone_confidence) = confidence_model(loglikes, feats)

if train:
    weights = 1.0 / (1.0 + nbest.num_paths[nbest.per_fsa.src_indexes])
    objf += confidence_objf(word_confidence, word_correct, weights) + \
            confidence_objf(phone_confidence, phone_correct, weights)

(nbest.per_arc.phone_confidence,
 nbest.per_arc.word_confidence) = k2.ragged_to_tensor_inv(nbest.arcs.shape,
                                                          phone_confidence,
                                                          word_confidence)

recombined = k2fsa.union(nbest, row_ids=nbest.per_fsa.src_indexes)
# Use phone and word confidences as the scores.  Note, these include
# confidences for epsilons so we won't just delete everything.
recombined.arcs.scores[:] = (recombined.per_arc.phone_confidence +
                             3 * recombined.per_arc.word_confidence)
result = k2fsa.best_path(recombined)
# can access result.per_arc.{phone,word}_confidence and the like...
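
The pseudocode above leaves pool_and_redistribute undefined; one plausible reading, sketched in plain PyTorch (the names are mine, not k2's), is a weighted average of the feature rows that share a bucket (here, a frame id), copied back to every row of that bucket:

import torch

def pool_and_redistribute(feats, weights, buckets):
    # feats: (num_arcs, feat_dim); weights: (num_arcs,);
    # buckets: (num_arcs,) long tensor of frame ids.
    num_buckets = int(buckets.max().item()) + 1
    w = weights.unsqueeze(1)
    sums = torch.zeros(num_buckets, feats.shape[1])
    sums.index_add_(0, buckets, feats * w)     # weighted feature sum per frame
    totals = torch.zeros(num_buckets, 1)
    totals.index_add_(0, buckets, w)           # total weight per frame
    means = sums / totals.clamp(min=1e-10)     # weighted mean per frame
    return means[buckets]                      # one averaged row per arc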
csukuangfj commented 3 years ago

@danpovey

dense_fsas = k2fsa.dense_fsas(loglikes)

Are you implementing DenseFsaVec? How is a DenseFsaVec constructed from loglikes?

danpovey commented 3 years ago

Look in fsa.h. It is just a struct. Needs a constructor, that's all.

danpovey commented 3 years ago

Note, something would need to be done in Python in general, when constructing it, to add the -infinity values in the right place. (This needs to be done at the Python level, using the base toolkit, so that autograd will work.) There are two cases for construction: regular and irregular. In the irregular case, please note that Lhotse keeps track of the supervision start/end times separately from the feature start/end times, so even if the features all have the same length the supervisions may not. We probably shouldn't interface directly with Lhotse but should give it some reasonably usable interface. In general, each sequence (i.e., at the output of the nnet, a sequence of log-likelihoods) will be associated with zero or more supervision objects, possibly overlapping in time.
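
For illustration, a minimal sketch (my own, with made-up layout conventions rather than the actual constructor) of placing the -infinity values when stacking per-supervision log-prob segments, assuming column 0 holds the score of the final symbol -1, which is -infinity on real frames and 0 on one extra terminating frame:

import torch

def make_dense_scores(segments):
    # segments: list of (num_frames, num_symbols) log-prob tensors,
    # one per supervision; lengths may differ.
    rows = []
    for seg in segments:
        t, s = seg.shape
        neg_inf = torch.full((t, 1), float('-inf'))
        body = torch.cat([neg_inf, seg], dim=1)       # col 0: score of -1
        last = torch.full((1, s + 1), float('-inf'))
        last[0, 0] = 0.0                              # -1 allowed only at the end
        rows.append(torch.cat([body, last], dim=0))
    # a RaggedShape would record where each segment starts and ends
    return torch.cat(rows, dim=0)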

csukuangfj commented 3 years ago

Look in fsa.h. It is just a struct. Needs a constructor, that's all.

DenseFsaVec contains only two members: RaggedShape and Array2<float>. Is it equivalent to the emission graph in Figure 2(d) of this paper (https://arxiv.org/pdf/2010.01003.pdf)?

lattices = k2fsa.pruned_compose(dense_fsas, decoding_graph)

Since dense_fsas contains ragged shape instead of FSAs, do we construct the FSAs dynamically during intersection?

danpovey commented 3 years ago

Yes, the comments there should explain it, but it contains the emission probabilities (as matrices) for a number of pieces of supervised audio that may not all be the same size. It also contains -1's at specific locations. The matrices are all appended into one Array2, and the RaggedShape says where each piece of audio starts and ends in that matrix.
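
To make the layout concrete (a toy example, not the C++ interface): with three pieces of audio occupying 4, 2 and 3 rows respectively, everything is stacked into one matrix and the ragged shape just stores the boundaries:

import torch

num_symbols = 10
scores = torch.randn(4 + 2 + 3, num_symbols + 1)  # all pieces appended
row_splits = torch.tensor([0, 4, 6, 9])
# piece i occupies rows row_splits[i]:row_splits[i+1] of `scores`
piece1 = scores[row_splits[1]:row_splits[2]]      # the 2-frame piece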
