k2-fsa / snowfall

Moved to https://github.com/k2-fsa/icefall
Apache License 2.0

Trace frame-level scores using lattice #248

Open zrsjta opened 3 years ago

zrsjta commented 3 years ago

Hi, team!

I have encountered a problem with k2 in my code. Below is the description of this problem.

For an nnet_output with shape [B, T, D], I am trying to compute the score on a graph (MMI numerator or denominator) for every prefix of nnet_output, namely nnet_output[:, :t, :], where t ranges from 1 to T (the total number of frames).

Currently, I implement this with a loop, but it requires a lot of computation. My code is below:

        graph, _ = self.graph_compiler.compile(texts, self.P, replicate_den=True)
        T = x.size()[1]
        tot_scores = []
        for t in range(T, 0, -1):
            supervision = torch.Tensor([[0, 0, t]]).to(torch.int32) # [idx, start, length]
            dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision)
            lats = k2.intersect_dense(graph, dense_fsa_vec, output_beam=10.0)
            frame_score = lats.get_tot_scores(log_semiring=True, use_double_scores=True)
            tot_scores.append(frame_score)
        tot_scores = torch.cat(tot_scores)
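
To put a rough number on "a lot of computation", assuming the cost of k2.intersect_dense grows roughly linearly with the number of frames it is fed, the loop above processes

```latex
\sum_{t=1}^{T} t = \frac{T(T+1)}{2} \ \text{frames in total, versus } T \text{ frames for a single call on the whole utterance.}
```

For T = 100 that is 5050 frames versus 100.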

Could these scores be computed from the lattice obtained from the whole nnet_output, i.e., with only one call to k2.intersect_dense? An approximation is also OK for me.

Thanks for your help! :)

zrsjta commented 3 years ago

Additional information:

  1. We do NOT need this process to be differentiable.
  2. It seems that the scores of each state and arc are accessible. Could we solve this problem using those scores on the lattice?

csukuangfj commented 3 years ago

Please refer to the help doc of k2.intersect_dense: https://k2-fsa.github.io/k2/python_api/api.html#k2.intersect_dense

There are two extra optional arguments that are relevant here, seqframe_idx_name and frame_idx_name:

def intersect_dense(a_fsas: Fsa,
                    b_fsas: DenseFsaVec,
                    output_beam: float,
                    a_to_b_map: Optional[torch.Tensor] = None,
                    seqframe_idx_name: Optional[str] = None,
                    frame_idx_name: Optional[str] = None) -> Fsa:

You can use

lats = k2.intersect_dense(graph, dense_fsa_vec, output_beam=10.0, seqframe_idx_name='seqframe', frame_idx_name='frame')

After the above call, the resulting lats has two extra 1-D tensor attributes: seqframe and frame. That is, you can access them using

lats.seqframe
lats.frame

In your case, the supervision contains only one utterance, so lats.seqframe and lats.frame should be the same. lats.frame contains values in the range from 0 to T (inclusive; the arcs with frame index T are the ones entering the final state), and it has as many entries as lats.num_arcs.

You can invoke k2.intersect_dense only once, feeding all T frames, and then call lats.get_forward_scores to get the forward score of each state. After that, you can use lats.frame to identify which states correspond to the t-th frame, and use log-sum-exp to combine the scores of those states.

[EDITED]: The whole process is also differentiable.
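
A minimal plain-Python sketch of this single-pass recipe, written without k2 so it runs on its own. It assumes the lattice data have already been converted to lists, e.g. `lats.arcs.values()[:, :2].tolist()` for the (src, dst) pairs, `lats.frame.tolist()`, and `forward_scores.tolist()` where `forward_scores = lats.get_forward_scores(use_double_scores=True, log_semiring=True)`. The helper names and the toy numbers are hypothetical.

```python
import math
from typing import Dict, List, Sequence, Tuple


def log_sum_exp(xs: Sequence[float]) -> float:
    """Numerically stable log(sum(exp(x))) over a non-empty sequence."""
    m = max(xs)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(x - m) for x in xs))


def frame_level_scores(
    arcs: Sequence[Tuple[int, int]],  # (src_state, dst_state) of each arc
    frames: Sequence[int],            # frame index of each arc (like lats.frame)
    forward_scores: Sequence[float],  # log-semiring forward score of each state
) -> Dict[int, float]:
    """Map each frame index to the log-sum-exp of the forward scores of the states it enters."""
    frame2states: Dict[int, List[int]] = {}
    for (_, dst), frame_idx in zip(arcs, frames):
        states = frame2states.setdefault(frame_idx, [])
        if dst not in states:  # count each destination state only once
            states.append(dst)
    return {f: log_sum_exp([forward_scores[s] for s in states])
            for f, states in frame2states.items()}


# Hypothetical toy lattice: state 0 is the start state, states 1 and 2 are
# entered at frame 0, state 3 at frame 1, and the final state 4 at frame 2.
arcs = [(0, 1), (0, 2), (1, 3), (2, 3), (3, 4)]
frames = [0, 0, 1, 1, 2]
forward = [0.0, -1.0, -2.0, -1.5, -1.5]
scores = frame_level_scores(arcs, frames, forward)
print(scores)
```

With torch tensors, the per-frame step is simply `torch.logsumexp(forward_scores[states], dim=0)`; the plain-Python version only makes the grouping explicit.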

zrsjta commented 3 years ago

Thanks for the reply

It seems that lats.frame represents the frame index of each arc. But I'm not sure we can recover the mapping from state to frame by lats.frame only.

Previously, I had not noticed lats.frame. However, I have recovered this mapping in another way (see the code below):

def trace_frame(lats):
    arcs = lats[0].as_dict()['arcs']

    frame2state = []
    prev_buf, cur_buf = [0], []

    for arc in arcs:
        f, t, _, _ = arc
        f, t = int(f), int(t)

        if f in prev_buf:
            if not t in cur_buf:
                cur_buf.append(t)

        else:
            frame2state.append(prev_buf)
            prev_buf = cur_buf
            cur_buf = [t]

    frame2state.append(prev_buf) # last frame
    frame2state.append([t]) # final state
    return frame2state

Then I tried to compute the frame-level scores from frame2state and lats.get_forward_scores, followed by log-sum-exp, but there is a large gap between these results and those obtained with the loop (as I presented first). The code is below:

        _, den = self.graph_compiler.compile(texts, self.P, replicate_den=True)

        T = x.size()[1]
        print(T)
        den_scores = []
        for t in range(T, 0, -1):
            supervision = torch.Tensor([[0, 0, t]]).to(torch.int32) # [idx, start, length]
            dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision)
            den_lats = k2.intersect_dense(den, dense_fsa_vec, output_beam=10.0)
            den_tot_scores = den_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
            den_scores.append(den_tot_scores)
        den_scores = torch.cat(den_scores).unsqueeze(0) # [T] -> [B, T]
        print("den score computed from previous version: ", den_scores)

        # new implementation
        supervision = torch.Tensor([[0, 0, T]]).to(torch.int32)
        dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision)
        den_lats = k2.intersect_dense(den, dense_fsa_vec, output_beam=10.0,\
                   seqframe_idx_name='seqframe', frame_idx_name='frame')

        frame2state = trace_frame(den_lats)
        den_forward_scores = den_lats.get_forward_scores(log_semiring=True, use_double_scores=True)
        assert len(frame2state) == T + 2 # extra start and end state
        den_scores_new = []
        for t in range(T, 0, -1):
            states = frame2state[t]
            den_score = torch.logsumexp(den_forward_scores[states], dim=0)
            den_scores_new.append(den_score)
        den_scores_new = torch.stack(den_scores_new, dim=-1)
        print("den score computed from new version: ", den_scores_new)
        print("diffence of the two version: ", den_scores_new - den_scores)

The results of the two versions do not match:

den score computed from previous version:  tensor([[783.3993, 781.7219, 780.0333, 778.3439, 776.5803, 774.7105, 772.7158,
         770.5528, 768.1100, 762.4893, 754.7248, 746.1509, 736.5164, 726.7509,
         716.2331, 706.6452, 697.4263, 686.9293, 681.2793, 674.0765, 664.6529,
         657.2257, 650.4535, 643.3796, 634.3289, 628.0670, 619.4294, 609.9516,
         600.0689, 594.3592, 590.8794, 585.2605, 581.2088, 578.4872, 572.3086,
         571.1170, 566.4870, 558.6632, 552.9537, 545.4799, 541.9318, 530.6125,
         521.3177, 511.8524, 502.7617, 492.8600, 487.0916, 479.4112, 471.2255,
         462.2586, 454.3531, 442.9971, 436.3475, 428.7141, 420.9272, 413.6736,
         406.7801, 397.0022, 391.5931, 386.9415, 379.1681, 370.1025, 361.0875,
         351.7468, 341.6609, 333.3474, 325.8754, 318.5557, 309.9469, 303.1736,
         296.7173, 286.5587, 280.5947, 273.7012, 266.3182, 258.1839, 250.6710,
         242.8872, 236.6474, 229.7769, 221.9658, 213.3599, 204.9579, 192.9310,
         186.8552, 179.3116, 170.0203, 160.0337, 150.8213, 142.6257, 134.6330,
         127.6783, 116.1433, 103.3247,  91.5196,  80.5909,  69.7985,  59.0771,
          48.1386,  36.4054,  23.1133,  11.1909,  -2.9392]],
       dtype=torch.float64)
den score computed from new version:  tensor([783.3993, 781.7252, 780.0382, 778.3510, 776.5909, 774.7267, 772.7416,
        770.5969, 768.1950, 762.6968, 755.2699, 747.3608, 738.8629, 730.4765,
        725.2539, 716.5815, 707.1374, 698.3506, 691.5351, 679.1446, 670.6063,
        663.6814, 656.9236, 651.3605, 647.0456, 640.6885, 626.9087, 617.5355,
        610.5875, 604.0881, 598.1472, 590.7349, 588.7432, 585.2757, 583.0598,
        581.2615, 570.5313, 566.4218, 561.7389, 557.9919, 554.6887, 543.3505,
        529.1934, 521.2846, 512.1856, 504.0500, 500.0110, 490.2833, 477.4947,
        471.4613, 463.3775, 455.1872, 449.6866, 438.6009, 426.2327, 421.9648,
        415.6079, 408.7287, 403.9518, 395.7307, 387.3507, 378.8256, 370.0204,
        361.4262, 355.0736, 346.0511, 331.8522, 322.9216, 317.4937, 311.1927,
        304.5424, 297.9715, 292.8139, 282.6298, 272.3836, 264.5364, 257.1948,
        250.8738, 245.0846, 235.5465, 227.8981, 221.6154, 213.5615, 205.2312,
        199.0696, 187.7953, 177.1561, 169.5724, 160.6012, 151.7466, 146.6541,
        140.5165, 129.1156, 116.3533, 104.6344,  93.8204,  83.1367,  72.5044,
         61.6243,  49.9238,  36.6502,  24.7326,  10.5949], dtype=torch.float64)
diffence of the two version:  tensor([[    0.0000,     0.0034,     0.0049,     0.0071,     0.0106,     0.0162,
             0.0258,     0.0441,     0.0850,     0.2075,     0.5451,     1.2099,
             2.3464,     3.7256,     9.0209,     9.9363,     9.7111,    11.4214,
            10.2558,     5.0681,     5.9534,     6.4558,     6.4701,     7.9809,
            12.7167,    12.6216,     7.4792,     7.5838,    10.5185,     9.7289,
             7.2677,     5.4744,     7.5344,     6.7885,    10.7512,    10.1445,
             4.0443,     7.7586,     8.7853,    12.5120,    12.7569,    12.7380,
             7.8757,     9.4322,     9.4238,    11.1901,    12.9194,    10.8722,
             6.2692,     9.2026,     9.0244,    12.1901,    13.3391,     9.8868,
             5.3055,     8.2912,     8.8277,    11.7265,    12.3586,     8.7891,
             8.1826,     8.7230,     8.9329,     9.6794,    13.4127,    12.7037,
             5.9769,     4.3660,     7.5468,     8.0192,     7.8250,    11.4128,
            12.2193,     8.9286,     6.0653,     6.3525,     6.5239,     7.9867,
             8.4372,     5.7696,     5.9323,     8.2554,     8.6036,    12.3002,
            12.2144,     8.4837,     7.1358,     9.5387,     9.7798,     9.1209,
            12.0211,    12.8382,    12.9724,    13.0286,    13.1148,    13.2295,
            13.3382,    13.4273,    13.4856,    13.5184,    13.5370,    13.5417,
            13.5341]], dtype=torch.float64)

At the moment, I don't know which version is correct. The motivation is to compute the probability P(W|O_{1:t}) as defined in LF-MMI, but using only the first t frames.

csukuangfj commented 3 years ago
forward_scores = lats.get_forward_scores(use_double_scores=True, log_semiring=True)

arcs = lats.arcs.values()[:, :2]
# arcs is a 2-D torch.int32 tensor; column 0 is src_state, column 1 is dst_state
frames = lats.frame.tolist()

frame2states = {}
for idx, (_, dst) in enumerate(arcs.tolist()):
    # src is not used, hence the underscore
    frame_idx = frames[idx]
    # now you know the state `dst` belongs to the frame `frame_idx`
    states = frame2states.setdefault(frame_idx, [])
    # Caution: you have to avoid adding `dst` multiple times
    if dst not in states:
        states.append(dst)

# At this point, you know the states corresponding to each frame, and you can use
# log-sum-exp, e.g. torch.logsumexp(forward_scores[states], dim=0), to combine them.

csukuangfj commented 3 years ago

> But I'm not sure we can recover the mapping from state to frame by lats.frame only.

As I posted above, you can iterate over the arcs; for each arc, you can get its frame_idx and dest_state. This dest_state belongs to the frame frame_idx.

If you have multiple utterances, then you have to use seqframe_idx.

csukuangfj commented 3 years ago

Note: I have re-edited the demo code.

csukuangfj commented 3 years ago

> Previously, I did not notice the API lats.frame. However, I have recovered this mapping in another way (see code below)

Please use a small lats to verify that your code is correct. (You can print the resulting lats, its frame attribute, its states, and the returned frame2state to check your code.)

~Looks like your frame2state is a 1-d list, which is not correct.~

A frame can correspond to multiple states, while a state belongs to only one frame.
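
One way to check this invariant on a small lattice is to build a state-to-frame map from the arcs and fail loudly if a state is entered at two different frames. A plain-Python sketch, assuming the (src, dst) pairs and the per-arc frame indices have been converted to lists (e.g. from `lats.arcs.values()[:, :2].tolist()` and `lats.frame.tolist()`); the name `state_to_frame` is hypothetical.

```python
from typing import Dict, Sequence, Tuple


def state_to_frame(arcs: Sequence[Tuple[int, int]], frames: Sequence[int]) -> Dict[int, int]:
    """Return {dst_state: frame_idx}, raising if a state is entered at two frames."""
    mapping: Dict[int, int] = {}
    for (_, dst), frame_idx in zip(arcs, frames):
        prev = mapping.setdefault(dst, frame_idx)
        if prev != frame_idx:
            raise ValueError(f"state {dst} is entered at frames {prev} and {frame_idx}")
    return mapping


# Several states may share a frame (states 1 and 2 at frame 0),
# but each state belongs to exactly one frame.
ok = state_to_frame([(0, 1), (0, 2), (1, 3), (2, 3)], [0, 0, 1, 1])
print(ok)  # {1: 0, 2: 0, 3: 1}
```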

csukuangfj commented 3 years ago

You can note down which state belongs to which frame. For example

frame_0_states = [1, 2, 3]
frame_1_states = [4, 5, 6, 7, 8, 9]
....

You can get the total scores for frame 0 using

forward_scores = lats.get_forward_scores(use_double_scores=True, log_semiring=True)

frame_0_tot_scores = forward_scores[frame_0_states].exp().sum().log()

Note: For the last frame T-1, you have to consider the scores on the arcs entering the final state if those scores are not zero.
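
A caution on computing log-sum-exp as `exp().sum().log()`: in float64, `exp(x)` overflows once x exceeds roughly 709, and the den scores earlier in this thread reach about 783. `torch.logsumexp` avoids this by subtracting the maximum first. A plain-Python sketch of the difference (the input values are illustrative, chosen only to be of a similar magnitude):

```python
import math
from typing import Sequence


def naive_log_sum_exp(xs: Sequence[float]) -> float:
    # Same formula as forward_scores[states].exp().sum().log();
    # math.exp raises OverflowError where torch would produce inf.
    return math.log(sum(math.exp(x) for x in xs))


def stable_log_sum_exp(xs: Sequence[float]) -> float:
    # Shift by the maximum so the largest term is exp(0) = 1.
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


scores = [783.0, 780.0]
stable = stable_log_sum_exp(scores)
try:
    naive = naive_log_sum_exp(scores)
except OverflowError:
    naive = math.inf
print(stable, naive)
```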

zrsjta commented 3 years ago

@csukuangfj As suggested, I restricted nnet_output to 3 frames and checked frame2state. I believe my frame2state is correct.

den score computed from previous version:  tensor([[23.1133, 11.1909, -2.9392]], dtype=torch.float64)
arcs: tensor([[          0,           1,           0,  1083890428],
        [          0,           2,         147,  1093236685],
        [          1,          13,         147,  1093883163],
        [          2,           3,          21, -1083650674],
        [          2,           4,          35,  1042883936],
        [          2,           5,          43, -1103412784],
        [          2,           6,          48, -1091424240],
        [          2,           7,         111, -1081929866],
        [          2,           8,         112, -1090075914],
        [          2,           9,         114,  1050298232],
        [          2,          10,         115, -1070879172],
        [          2,          11,         140, -1089980508],
        [          2,          12,         141, -1078741436],
        [          2,          13,         147,  1096957993],
        [          2,          14,         153, -1097704528],
        [          2,          15,         189, -1077760680],
        [          3,          16,           1,  1081681798],
        [          4,          16,           1,  1079942965],
        [          5,          16,           1,  1078807868],
        [          6,          16,           1,  1082282282],
        [          7,          16,           1,  1080307114],
        [          8,          16,           1,  1080573812],
        [          9,          16,           1,  1082305666],
        [         10,          16,           1,  1085392658],
        [         11,          16,           1,  1081433908],
        [         12,          16,           1,  1082813821],
        [         13,          18,           0,  1092706735],
        [         13,          16,           1, -1058648121],
        [         13,          17,         147,  1094443926],
        [         14,          16,           1,  1081785469],
        [         15,          16,           1,  1082840980],
        [         16,          19,          -1, -1221591040],
        [         17,          19,          -1, -1051153366],
        [         18,          19,          -1, -1051153366]],
       dtype=torch.int32)
lats.frame:  tensor([0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 3, 3, 3], dtype=torch.int32)
frame2state:  [[0], [1, 2], [13, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15], [16, 18, 17], [19]]
den score computed from new version:  tensor([36.6502, 24.7326, 10.5949], dtype=torch.float64)
diffence of the two version:  tensor([[13.5370, 13.5417, 13.5341]], dtype=torch.float64)
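
The fourth column of the `arcs` tensor printed above is not an integer score: it holds each arc's float32 score stored bit-for-bit in an int32, which is why the values look like large integers. Reinterpreting the bits recovers the float; in PyTorch something like `arcs[:, 3].view(torch.float32)` should do it, and k2 also exposes the float scores directly as `lats.scores`. A standard-library sketch using two values from the dump:

```python
import struct


def int32_bits_to_float32(bits: int) -> float:
    """Reinterpret a signed 32-bit integer as an IEEE-754 float32."""
    return struct.unpack("<f", struct.pack("<i", bits))[0]


# Score column of arcs 17->19 and 16->19 in the dump above.
print(int32_bits_to_float32(-1051153366))  # about -13.54
print(int32_bits_to_float32(-1221591040))  # about -1.05e-05, effectively 0
```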
zrsjta commented 3 years ago

> Note: For the last frame T-1, you have to consider the scores on the arcs entering the final state if those scores are not zero.

I'm also worried about the scores on the final states. If nnet_output[:, :t, :] is used in intersect_dense, the states that belong to the t-th frame are treated as final states and their final scores are included (k2 implements this with an additional final state, reached via extra arcs that carry these final scores). This is what happens in my original implementation. When tracing the full lattice, however, states that belong to any frame other than the last are not treated as final states, so the extra final scores are not added. That may cause a difference in the scores.

csukuangfj commented 3 years ago

> frame2state: [[0], [1, 2], [13, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15], [16, 18, 17], [19]]

[0] is the start state and does not belong to any frame. Looks like there is an offset-by-one error.

For arc i, we get its dest_state (not src_state) and frame_idx = lats.frame[i], then dest_state belongs to frame frame_idx.
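
Applying this rule to the 3-frame lattice dumped earlier in this thread: the sketch below copies the (src, dst) columns of `arcs` and the values of `lats.frame` from that dump and groups destination states by frame in plain Python, with no k2 needed. Because only destination states are collected, the start state 0 belongs to no frame.

```python
from typing import List

# (src, dst) of the 34 arcs and lats.frame, copied from the 3-frame dump.
arcs = [
    (0, 1), (0, 2), (1, 13),
    (2, 3), (2, 4), (2, 5), (2, 6), (2, 7), (2, 8), (2, 9),
    (2, 10), (2, 11), (2, 12), (2, 13), (2, 14), (2, 15),
    (3, 16), (4, 16), (5, 16), (6, 16), (7, 16), (8, 16), (9, 16),
    (10, 16), (11, 16), (12, 16), (13, 18), (13, 16), (13, 17),
    (14, 16), (15, 16), (16, 19), (17, 19), (18, 19),
]
frames = [0] * 2 + [1] * 14 + [2] * 15 + [3] * 3

frame2state: List[List[int]] = [[] for _ in range(max(frames) + 1)]
for (_, dst), frame_idx in zip(arcs, frames):
    # dst belongs to the frame of the arc entering it; record it once
    if dst not in frame2state[frame_idx]:
        frame2state[frame_idx].append(dst)
print(frame2state)
# [[1, 2], [13, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15], [16, 18, 17], [19]]
```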

zrsjta commented 3 years ago

I have dumped the lattice below. If frames are indexed from 1, I suppose we should apply logsumexp to the following groups:

frame_1_states = [1, 2]
frame_2_states = [13, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15]
frame_3_states = [16, 18, 17]

[Screenshot of the dumped lattice omitted]

> [0] is the start state and does not belong to any frame. Looks like there is an offset-by-one error.

In my frame2state, the start and end states are the first and last entries, respectively. Note that the loop for t in range(T, 0, -1) with T=3 only visits t = 3, 2, 1, so both [0] and [19] are ignored.

> For arc i, we get its dest_state (not src_state) and frame_idx = lats.frame[i], then dest_state belongs to frame frame_idx.

Following this advice, I have also revised my trace_frame function:

def trace_frame(lats):
    arcs = lats.arcs.values()[:, :2]
    T = max(lats.frame).item()
    frame2state = [[] for _ in range(T+1)]

    for idx, (_, dst) in enumerate(arcs.tolist()):
        frame_idx = lats.frame[idx]
        if dst not in frame2state[frame_idx]:
            frame2state[frame_idx].append(dst)

    print("frame2state: ", frame2state)
    return frame2state

Its output is below. It only drops the start state [0]; the other entries are unchanged:

frame2state: [[1, 2], [13, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15], [16, 18, 17], [19]]

Using the new frame2state, the difference between the two versions is unchanged:

den score computed from previous version:  tensor([[23.1133, 11.1909, -2.9392]], dtype=torch.float64)
frame2state:  [[1, 2], [13, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15], [16, 18, 17], [19]]
logsumexp on [16, 18, 17] is 36.650240146578824
logsumexp on [13, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15] is 24.732586652356957
logsumexp on [1, 2] is 10.594917989250451
den score computed from new version:  tensor([36.6502, 24.7326, 10.5949], dtype=torch.float64)
diffence of the two version:  tensor([[13.5370, 13.5417, 13.5341]], dtype=torch.float64)

Note:

  1. I have changed assert len(frame2state) == T + 2 to assert len(frame2state) == T + 1, since [0] no longer exists.
  2. I have changed states = frame2state[t] to states = frame2state[t-1], also because [0] no longer exists.

The log above shows all the states used in each logsumexp, so I suppose we no longer have the offset-by-one error.

zrsjta commented 3 years ago

> Note: For the last frame T-1, you have to consider the scores on the arcs entering the final state if those scores are not zero.

As shown in the lattice above, arcs 17->19 and 18->19 have scores of -13.54, so the scores of states 17 and 18 should be adjusted before the logsumexp (so far I have not done this). I'm curious whether a similar modification should be made for states that belong to other frames, since in my previous implementation each frame is treated as the final frame once. If so, how can this be done?

Thanks :)
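
For the last frame, adding the final-arc scores before the log-sum-exp amounts to computing the forward score of the final state itself: the log-sum-exp, over the arcs entering it, of (forward score of the source state + arc score). A plain-Python sketch with hypothetical forward scores for states 16, 17 and 18, and arc scores rounded from the dump above (about 0 for 16->19 and -13.54 for 17->19 and 18->19):

```python
import math
from typing import Dict, Sequence, Tuple


def final_state_forward_score(
    forward: Dict[int, float],
    final_arcs: Sequence[Tuple[int, float]],  # (src_state, arc_score) of arcs entering the final state
) -> float:
    """Log-semiring forward score of the final state."""
    terms = [forward[src] + score for src, score in final_arcs]
    m = max(terms)
    return m + math.log(sum(math.exp(t - m) for t in terms))


forward = {16: 30.0, 17: 35.0, 18: 34.0}  # hypothetical forward scores
final_arcs = [(16, 0.0), (17, -13.54), (18, -13.54)]
with_final = final_state_forward_score(forward, final_arcs)

# Ignoring the final-arc scores overweights states 17 and 18.
m = max(forward.values())
without_final = m + math.log(sum(math.exp(v - m) for v in forward.values()))
print(with_final, without_final)
```

In k2 the final state is the last state, so this quantity is what `forward_scores[-1]` holds; the final version of the code in this thread uses exactly that for the last frame.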

csukuangfj commented 3 years ago

I just created a colab notebook (see https://colab.research.google.com/drive/1iyc_q8aHuKd-RZxtYv9EqfyjB2QZDSOx?usp=sharing) to verify the idea.

The following code you posted:

        graph, _ = self.graph_compiler.compile(texts, self.P, replicate_den=True)
        T = x.size()[1]
        tot_scores = []
        for t in range(T, 0, -1):
            supervision = torch.Tensor([[0, 0, t]]).to(torch.int32) # [idx, start, length]
            dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision)
            lats = k2.intersect_dense(graph, dense_fsa_vec, output_beam=10.0)
            frame_score = lats.get_tot_scores(log_semiring=True, use_double_scores=True)
            tot_scores.append(frame_score)
        tot_scores = torch.cat(tot_scores)

is not equivalent to the one where the tot_scores is computed from the whole lattice (when you feed T frames at once)

zrsjta commented 3 years ago


My observation is the same: the tot_scores computed by the two methods are different. Is there anything wrong with my code, or do you have any suggestions about this mismatch?

csukuangfj commented 3 years ago

For the following dense_fsa_vec (the code for all of the comments below is in the colab notebook linked above):

[Screenshot: the dense_fsa_vec]

with the following decoding graph:

[Screenshot: the decoding graph]

Feed T frames at once

If you feed T frames at once, the resulting lats is:

[Screenshot: the lattice obtained from all T frames]

lats.frame and lats.get_forward_scores are

[Screenshot: lats.frame and the output of lats.get_forward_scores]

The tot_scores for frame 0 is

[Screenshot: computing the tot_scores for frame 0]

(Note: State 1 and state 2 belong to frame 0, so torch.tensor([1, 2]) is used above)

Feed frames separately

(1) If you feed only 1 frame, i.e., frame 0, the resulting lats is

[Screenshot: the lattice obtained from frame 0 only]

Caution: The lats is different from the one when you feed 3 frames at once, so the tot_scores computed from this lats is different from the one computed from the whole lattice.

(2) If you feed only 2 frames, i.e., frame 0 and frame 1, the resulting lats is

[Screenshot: the lattice obtained from frames 0 and 1]

Caution: The states for frame 0 are the same as those in the whole lattice. However, the states for frame 1 are different from those in the whole lattice, so the tot_scores computed for frame 1 from this lattice differs from the one computed from the whole lattice.

Note: tot_scores for frame 0 computed from this lats is the same as the one computed from the whole lattice.


So, in the initial version of your code, I recommend feeding t+1 frames to compute the tot_scores for frame t. If you feed only t frames, the result is not correct.
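
The indexing that follows from this advice, as used in the final version of the code in this thread, can be sketched as a small pure-Python helper; `prefix_plan` is a hypothetical name. For a prefix of t frames (1 <= t <= T), it returns how many frames to feed to k2.intersect_dense and which 0-based entry of the frame-to-states list to log-sum-exp over (None when t == T, where the forward score of the final state is used instead).

```python
from typing import List, Optional, Tuple


def prefix_plan(T: int) -> List[Tuple[int, int, Optional[int]]]:
    """For t = T, ..., 1 return (t, frames_to_feed, frame_idx)."""
    plan = []
    for t in range(T, 0, -1):
        if t == T:
            # whole utterance: use the final state's forward score
            plan.append((t, T, None))
        else:
            # one extra frame so the states of the first t frames match the full lattice
            plan.append((t, t + 1, t - 1))
    return plan


print(prefix_plan(3))  # [(3, 3, None), (2, 3, 1), (1, 2, 0)]
```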

zrsjta commented 3 years ago

If I understand correctly: given the lattice obtained from t+1 frames, only the first t frames of this lattice form a sub-graph of the whole lattice, and the states that belong to the (t+1)-th frame should be ignored.

If so, I should compute the scores for the first T-1 frames with the logsumexp method, while the score for the T-th frame comes from the forward score of the final state.

Also, I should never call lats.get_tot_scores in my implementation.

I'm still curious about one thing. As defined by LF-MMI, log P(W|O_{1:t}) is the difference between the numerator and denominator scores. If I use only the first t frames to build the lattices and then call lats.get_tot_scores for both the numerator and the denominator, is that wrong in this scenario? In other words, would lats.get_tot_scores give a conceptually wrong number for P(W|O_{1:t}) if t is not equal to T?
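
For reference, a simplified form of the quantity being asked about, ignoring acoustic scaling, priors and other details of any particular LF-MMI recipe. Here G_num and G_den denote the numerator and denominator graphs, and each term is the total log-semiring score (what get_tot_scores(log_semiring=True, ...) returns) of intersecting that graph with the first t frames:

```latex
\log P(W \mid O_{1:t}) \approx \log p(O_{1:t} \mid \mathcal{G}_{\mathrm{num}}) - \log p(O_{1:t} \mid \mathcal{G}_{\mathrm{den}})
```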

csukuangfj commented 3 years ago

> In other words, would the lats.get_tot_scores lead to the wrong number in some concept for P(W|O_{1:t}) if t is not equal to T?

Maybe @danpovey has more to say about it.


> Given the lattice dumped by t+1 frames, only the first t frames of this lattice is a sub-graph of the whole lattice while states that belong to t+1-th frame should be ignored.

I was explaining why the tot_scores obtained by feeding frames separately are different from the ones computed from the whole lattice. If you want identical results, you have to feed t+1 frames to compute the tot_scores for the t-th frame.

zrsjta commented 3 years ago

Thanks!

As @csukuangfj advised, I have rewritten the two methods, and the results now match. The key points are:

  1. for the last frame, we use the last element of forward_scores (scores on the final state)
  2. for other frames, we feed t+1 frames to obtain the scores on t-th frame

import k2
import torch


def trace_lattice(lats):
    """Return frame2states, where frame2states[f] lists the states entered at frame f."""
    arcs = lats.arcs.values()[:, :2]
    frames = lats.frame.tolist()
    T = max(frames)
    frame2states = [[] for _ in range(T + 1)]

    for (_, dst), frame_idx in zip(arcs.tolist(), frames):
        # a state belongs to the frame of the arcs entering it; add it only once
        if dst not in frame2states[frame_idx]:
            frame2states[frame_idx].append(dst)

    return frame2states


def compute_frame_level_scores(graph, nnet_output):
    T = nnet_output.size(1)

    # build the lattice from all T frames at once
    supervision = torch.tensor([[0, 0, T]], dtype=torch.int32)  # [idx, start, length]
    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision)
    lats = k2.intersect_dense(graph, dense_fsa_vec, output_beam=10.0,
                              seqframe_idx_name='seqframe', frame_idx_name='frame')

    # compute frame-level scores
    forward_scores = lats.get_forward_scores(use_double_scores=True, log_semiring=True)
    frame2states = trace_lattice(lats)
    assert len(frame2states) == T + 1  # extra frame for the arcs entering the final state

    tot_scores = []
    for t in range(T, 0, -1):
        if t == T:
            # last frame: forward score of the final state
            tot_scores.append(forward_scores[-1])
        else:
            # other frames: log-sum-exp over the states of frame t-1 (0-based)
            states = frame2states[t - 1]
            frame_score = torch.logsumexp(forward_scores[states], dim=-1)
            tot_scores.append(frame_score)
            print(f"scores computed from states {states} is {frame_score}")
    return torch.stack(tot_scores, dim=0)


def compute_frame_level_scores_loop(graph, nnet_output):
    T = nnet_output.size(1)

    tot_scores = []
    for t in range(T, 0, -1):
        # feed one more frame if it's not the last frame,
        # so that the states in the first t frames are identical to
        # those in the whole lattice
        t_ = t if t == T else t + 1
        supervision = torch.tensor([[0, 0, t_]], dtype=torch.int32)
        dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision)
        lats = k2.intersect_dense(graph, dense_fsa_vec, output_beam=10.0,
                                  seqframe_idx_name='seqframe', frame_idx_name='frame')

        forward_scores = lats.get_forward_scores(use_double_scores=True, log_semiring=True)
        frame2states = trace_lattice(lats)

        if t == T:
            tot_scores.append(forward_scores[-1])
        else:
            assert len(frame2states) == t + 2
            states = frame2states[t - 1]
            frame_score = torch.logsumexp(forward_scores[states], dim=-1)
            tot_scores.append(frame_score)
            print(f"scores computed from states {states} is {frame_score}")
    return torch.stack(tot_scores, dim=0)


if __name__ == '__main__':
    nnet_output = torch.tensor(
        [
            [0.1, 0.22, 0.28, 0.4],
            [0.1, 0.13, 0.07, 0.7],
            [0.6, 0.2, 0.05, 0.15],
        ], dtype=torch.float32
    ).unsqueeze(0)  # [B=1, T=3, D=4]
    nnet_output = torch.nn.functional.log_softmax(nnet_output, -1)

    graph = k2.ctc_graph([[1]])

    scores = compute_frame_level_scores(graph, nnet_output)
    print("Scores computed by new version: ", scores)

    scores = compute_frame_level_scores_loop(graph, nnet_output)
    print("Scores computed by original version: ", scores)

The results:

scores computed from states [3, 4, 5] is -1.652332052730104
scores computed from states [1, 2] is -0.787191693952512
Scores computed by new version:  tensor([-2.4776, -1.6523, -0.7872], dtype=torch.float64)
scores computed from states [3, 4, 5] is -1.652332052730104
scores computed from states [1, 2] is -0.787191693952512
Scores computed by original version:  tensor([-2.4776, -1.6523, -0.7872], dtype=torch.float64)

> Maybe @danpovey has more to say about it.

Looking forward to Dan's reply :)