k2-fsa / k2

FSA/FST algorithms, differentiable, with PyTorch compatibility.
https://k2-fsa.github.io/k2
Apache License 2.0

Question about frame number attribute on lattice #1229

Closed. desh2608 closed this issue 1 year ago.

desh2608 commented 1 year ago

Suppose I have a recognition lattice (for example, obtained from fast_beam_search). I am trying to add an attribute to the lattice FSA that can be interpreted as the timestep of each arc. For example, arcs at time-step "t" will have this attribute set to "t" (an int or a float, depending on whether we want to store the frame number or the time w.r.t. some frame shift).

What would be the best way to do this? I imagine it would require creating a RaggedTensor of the same shape as the aux_labels, but I am not sure how to get the time-step corresponding to the arc.

csukuangfj commented 1 year ago

@desh2608 sorry for the late reply (I just checked my email).

Maybe @pkufool can help you with it.

pkufool commented 1 year ago

@desh2608 I think if we have the arc_map_b (the arc map to log_prob), it is not hard to get the t for each arc. Here is the relevant part of intersect_dense_pruned: https://github.com/k2-fsa/k2/blob/821ebc378e7fb99b8adc81950227963332821e01/k2/python/k2/autograd.py#L593-L606
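
The rough idea could be sketched like this (a hypothetical sketch, not code from k2; it assumes arc_map_b holds, for each arc of the output lattice, a flat index into dense_fsa_vec.scores):

# dense_fsa_vec.scores has roughly the shape (total_frames, num_symbols); the row an
# arc's log-prob was read from is the frame it was emitted at.
num_symbols = dense_fsa_vec.scores.shape[-1]
frame_of_arc = arc_map_b.long() // num_symbols  # frame index across all utterances
# A per-utterance t would still need the utterance's starting row subtracted.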

desh2608 commented 1 year ago

@pkufool Thanks for the pointer! I'll try to work out the details and ask you if I get stuck.

desh2608 commented 1 year ago

@csukuangfj @pkufool I have a question about manipulating RaggedTensor. Here is the situation:

Suppose I obtain a lattice using fast_beam_search_nbest_LG, so the labels are token IDs and the aux_labels are word IDs. I then select some random paths from the lattice using k2.random_paths and obtain the word ID sequences as a RaggedTensor from the aux_labels (see code below):

saved_scores = lattice.scores.clone()
lattice.scores *= nbest_scale
path = k2.random_paths(
    lattice, num_paths=num_paths, use_double_scores=use_double_scores
)
lattice.scores = saved_scores
word_seq = lattice.aux_labels.index(path)

At this point, word_seq has 4 axes [utt][path][arc][word_id] (an arc's word_id sublist may be empty). Now, I want to compute a RaggedTensor delays, which basically gives the timestamp for each word (based on some frame shift). A very naive way to do this may be as follows:

import torch
import k2


def compute_delays(word_seq: k2.RaggedTensor, frame_shift: float) -> k2.RaggedTensor:
    """Compute delays for a word sequence.

    Args:
      word_seq:
        It is a k2.RaggedTensor with 4 axes [utt][path][arc][word_id].
        It contains word IDs. Note that it also contains 0s and -1s.
        The last entry in each sublist is -1.
      frame_shift:
        The frame shift in seconds.

    Returns:
      Return a k2.RaggedTensor with 4 axes [utt][path][arc][delay].
    """
    assert word_seq.num_axes == 4
    # Each sublist on the last axis contains either a single word ID or is empty.
    assert word_seq.values.dtype == torch.int32

    delays = []
    word_seq_list = word_seq.tolist()
    for b in word_seq_list:
        delays_b = []
        for path in b:
            delays_path = []
            for i, word in enumerate(path):
                if len(word) == 0:
                    delays_path.append([])
                else:
                    delays_path.append([i * frame_shift])
            delays_b.append(delays_path)
        delays.append(delays_b)
    delays = k2.RaggedTensor(delays, dtype=torch.float32, device=word_seq.device)
    return delays

Eventually, I can do the following to remove the arc axis (and hence the empty sublists) from both word_seq and delays:

word_seq = word_seq.remove_axis(word_seq.num_axes - 2)
delays = delays.remove_axis(delays.num_axes - 2)

Of course, this method is very inefficient because of all the nested loops.

My question: is there a way to achieve this more efficiently using existing k2 methods for RaggedTensors?

csukuangfj commented 1 year ago

https://github.com/k2-fsa/k2/blob/b835546b6005d243865e0acc3d29bd9c51670b1e/k2/python/k2/rnnt_decode.py#L288-L289

How about using torch.arange() to replace log_probs.reshape(-1) and assigning an index to each arc?

And you can follow https://github.com/k2-fsa/k2/blob/b835546b6005d243865e0acc3d29bd9c51670b1e/k2/python/k2/rnnt_decode.py#L296 to add a new attribute, e.g., fsa.frames.
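
For example, roughly (a hypothetical sketch, not the actual code at those lines; log_probs and arc_map stand for the tensors used there, and log_probs is assumed to have shape (num_frames, vocab_size)):

import torch
import k2

# Gather a flat index (instead of a log-prob value) for every lattice arc.
flat_ids = torch.arange(log_probs.numel(), dtype=torch.int32, device=log_probs.device)
arc_flat_ids = k2.index_select(flat_ids, arc_map)  # one flat index per arc
# The row that index falls in is the frame the arc came from.
fsa.frames = arc_flat_ids // log_probs.shape[-1]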

You can index fsa.frames with path returned by

path = k2.random_paths(
    lattice, num_paths=num_paths, use_double_scores=use_double_scores
)

desh2608 commented 1 year ago

Thanks for the suggestion. Once I have the indexed frames, is there a way to remove the ones corresponding to empty labels from the word_seq (i.e., get frames corresponding to the word_seq obtained from word_seq.remove_axis(word_seq.num_axes - 2))?

csukuangfj commented 1 year ago

Once I have the indexed frames, is there a way to remove the ones corresponding to empty labels from the word_seq

Yes. Before using path to index fsa.frames, you can set the entries of fsa.frames whose arc's aux_label is 0 to a special value, e.g., -100. After indexing with path, you can remove the values that are equal to -100. (I think there is a method of RaggedTensor called remove_values_eq.)

desh2608 commented 1 year ago

(Quoting the suggestion above about using torch.arange() to add a new attribute fsa.frames and indexing it with the path returned by k2.random_paths().)

It seems creating fsa.frames this way makes it a Tensor instead of a RaggedTensor.

Edit: I guess I can get a RaggedTensor using:

frames = k2.RaggedTensor(fsa.aux_labels.shape, fsa.frames)

Further edit: this does not work; the tensor shape and the RaggedShape are not compatible.

csukuangfj commented 1 year ago

Please wait, let me have a think.

csukuangfj commented 1 year ago

You can first turn the ragged tensor into a 1-D tensor by using its values attribute:

indexes = path.values
selected = fsa.frames[indexes]

and then you use path.shape and selected to construct a new ragged tensor.


[EDITED]: You may need to convert indexes from torch.int32 to torch.long.
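
Putting the two steps together, the reconstruction might look like this (a sketch; it assumes fsa.frames is a 1-D torch tensor with one entry per arc):

indexes = path.values.long()  # convert int32 -> int64 for torch indexing
selected = fsa.frames[indexes]
frames = k2.RaggedTensor(path.shape, selected)  # same ragged layout as path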

desh2608 commented 1 year ago

@csukuangfj Thanks, that worked! How can I find the indices for the aux_labels whose value is 0 (i.e. it is empty)? At the moment, I have:

word_seq.values.shape
>>> torch.Size([173279])
frames.values.shape
>>> torch.Size([2032200])

csukuangfj commented 1 year ago

How can I find the indices for the aux_labels whose value is 0

First, we get the number of aux_labels per arc:

num_aux_labels_per_arc = lattice.aux_labels.row_splits(1)[1:] - lattice.aux_labels.row_splits(1)[:-1]

Note that num_aux_labels_per_arc is a 1-d tensor.

Second, we set the frames of those arcs whose num_aux_labels_per_arc is 0 to a special value, e.g., -100:

lattice.frames[num_aux_labels_per_arc == 0] = -100

Third, we obtain the frames by indexing with path:

indexes = path.values.long()
selected = lattice.frames[indexes]

Last, we remove the special values from selected by remove_values_eq:

filtered = selected.remove_values_eq(-100)

desh2608 commented 1 year ago

Thanks, Fangjun! This works perfectly. I think I understand RaggedTensor operations a little more now :)

(Minor edit: lattice.aux_labels.row_splits(1)[1:] --> lattice.aux_labels.shape.row_splits(1)[1:], since row_splits() is a method of RaggedShape)
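
For future readers, the whole recipe from this thread, with that correction applied, might be consolidated roughly as follows (a sketch; it assumes lattice.frames was attached as a 1-D per-arc tensor attribute as suggested above, and path comes from k2.random_paths()):

# 1. Give arcs with an empty aux_labels sublist a sentinel frame value.
row_splits = lattice.aux_labels.shape.row_splits(1)
num_aux_labels_per_arc = row_splits[1:] - row_splits[:-1]
lattice.frames[num_aux_labels_per_arc == 0] = -100

# 2. Gather the frames of the arcs on each sampled path.
indexes = path.values.long()
selected = lattice.frames[indexes]

# 3. Rebuild a ragged tensor with the same layout as path and drop the sentinels.
frames = k2.RaggedTensor(path.shape, selected)
frames = frames.remove_values_eq(-100)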