k2-fsa / k2

FSA/FST algorithms, differentiable, with PyTorch compatibility.
https://k2-fsa.github.io/k2
Apache License 2.0

Creating label checker fsa. #1155

Closed. tomotious closed this issue 1 year ago

tomotious commented 1 year ago

I want to create a label-check FSA to use with an RNN-T model. Here is a schematic diagram of the FSA (attached screenshot: Screenshot from 2023-02-06 09-05-28).

Here is my code to create the decoding FSA for the RNN-T model:

from typing import List

import k2


def compile_align_fsa(
    num_tokens: int,
    labels: List[int],
    ins_penalty: float,
    del_penalty: float,
    first_token_disambig_id: int,
    insert_start_wid: int,
    insert_end_wid: int,
    deletion_wid: int,
) -> k2.Fsa:
    arcs = []
    filler_start, filler_end = 1, 2
    eps = 0

    # Filler section: one arc per real token, excluding 0 (blank) and 1 (<sos/eos>).
    for i in range(2, num_tokens):
        arcs.append([filler_start, filler_end, i, i, ins_penalty])

    # Loop back from filler_end to filler_start so the filler section can
    # absorb several tokens in a row.
    arcs.append([filler_end, filler_start, first_token_disambig_id, eps, 0])

    # Alignment path, with an optional filler detour at every position.
    cur_state = 3
    prev_state = 0

    for label in labels:
        # Insertion or substitution: detour through the filler section.
        arcs.append([prev_state, filler_start, first_token_disambig_id, insert_start_wid, 0.0])
        arcs.append([filler_end, prev_state, first_token_disambig_id, insert_end_wid, 0.0])
        # Correct: consume the expected label.
        arcs.append([prev_state, cur_state, label, label, 0])
        # Deletion: skip the expected label.
        arcs.append([prev_state, cur_state, first_token_disambig_id, deletion_wid, del_penalty])

        prev_state = cur_state
        cur_state += 1

    final_state = cur_state
    # Optionally allow a trailing filler before entering the final state.
    arcs.append([prev_state, filler_start, first_token_disambig_id, insert_start_wid, 0.0])
    arcs.append([filler_end, prev_state, first_token_disambig_id, insert_end_wid, 0.0])
    arcs.append([prev_state, final_state, -1, -1, 0])
    arcs.append([final_state])

    # Sort arcs by source state (the final-state line sorts last) and
    # serialize them into the text format expected by k2.Fsa.from_str().
    arcs = sorted(arcs, key=lambda arc: arc[0])
    arcs = [[str(i) for i in arc] for arc in arcs]
    arcs = [" ".join(arc) for arc in arcs]
    arcs = "\n".join(arcs)

    fsa = k2.Fsa.from_str(arcs, acceptor=False)
    # Turn the disambig symbol on the input side into epsilon so those arcs
    # do not consume any token.
    fsa.labels[fsa.labels >= first_token_disambig_id] = 0
    return fsa
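
For illustration only (these numbers are hypothetical, not taken from the test below): with num_tokens=6, labels=[3, 4], first_token_disambig_id=6, insert_start_wid=7, insert_end_wid=8, deletion_wid=9 and both penalties set to -1.0, the string handed to k2.Fsa.from_str() would look roughly as follows, where each line is "src dst label aux_label score", state 0 is the start, states 1 and 2 form the filler section, states 3 and 4 continue the alignment path, and state 5 is the final state:

    0 1 6 7 0.0
    0 3 3 3 0
    0 3 6 9 -1.0
    1 2 2 2 -1.0
    1 2 3 3 -1.0
    1 2 4 4 -1.0
    1 2 5 5 -1.0
    2 1 6 0 0
    2 0 6 8 0.0
    2 3 6 8 0.0
    2 4 6 8 0.0
    3 1 6 7 0.0
    3 4 4 4 0
    3 4 6 9 -1.0
    4 1 6 7 0.0
    4 5 -1 -1 0
    5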

Test code to draw the FSA:

    # `lexicon` is defined in the surrounding decode.py (presumably an icefall Lexicon).
    labels = [3, 4, 5]
    num_tokens = max(lexicon.tokens) + 1

    # Register the auxiliary marker symbols in the token table.
    lexicon.token_table.add("#sis")
    lexicon.token_table.add("#eis")
    lexicon.token_table.add("#del")
    insert_start_wid = lexicon.token_table["#sis"]
    insert_end_wid = lexicon.token_table["#eis"]
    deletion_wid = lexicon.token_table["#del"]

    first_token_disambig_id = lexicon.token_table["#0"]

    align_fsa = compile_align_fsa(
        num_tokens=num_tokens,
        labels=labels,
        first_token_disambig_id=first_token_disambig_id,
        ins_penalty=-1.0,
        del_penalty=-1.0,
        insert_start_wid=insert_start_wid,
        insert_end_wid=insert_end_wid,
        deletion_wid=deletion_wid,
    )
    align_fsa.labels_sym = lexicon.token_table
    align_fsa.aux_labels_sym = lexicon.token_table
    align_fsa.draw('align_fsa.svg')
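
Not part of the original test, but as a quick sanity check one could intersect the compiled FSA with a linear FSA over a token sequence and confirm that the topology accepts it. A minimal sketch, assuming `align_fsa` and `labels` from the test code above are in scope:

    # Hypothetical sanity check: the exact reference label sequence should be
    # accepted by the align FSA, so the connected intersection is non-empty.
    hyp = k2.linear_fsa(labels)  # here: the reference labels [3, 4, 5]
    lattice = k2.intersect(
        k2.arc_sort(align_fsa),
        k2.arc_sort(hyp),
        treat_epsilons_specially=True,
    )
    lattice = k2.connect(lattice)
    assert lattice.num_arcs > 0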

The resulting test FSA looks like this (attached image: align_fsa).

I put this code in decode.py and tested it on an audio file whose reference label is "但是从尺寸来讲呢往往呢就这个尺寸一有一个城市都是大于一般欧洲的一个国家的". Decoding this audio with modified_beam_search gives "但是从尺寸来讲呢往往呢就这个尺寸有的一个城市都是大于一般欧洲的一个国家的". I then decoded the same audio with fast_beam_search_one_best(), using the generated label-check FSA as the decoding graph. I expected the RNN-T model to produce something like "但是从尺寸来讲呢往往呢就这个尺寸#del有#sis的#eis一个城市都是大于一般欧洲的一个国家的", but the actual result is "但是从尺寸来讲呢往往呢就这个尺寸#sis有#eis#sis的#eis#sis一#eis#sis个#eis#sis城#eis#sis市#eis#sis都#eis是#sis大#eis#sis于#eis#sis一#eis#sis般#eis#sis欧#eis#sis洲#eis#sis的#eis#sis一#eis#sis个#eis#sis国#eis#sis家#eis#sis的#eis". Why does this happen?
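
For context, the decoding call is roughly the following sketch, following the icefall-style fast_beam_search_one_best() API; the surrounding names (`model`, `encoder_out`, `encoder_out_lens`, `params`, `device`) are assumed to come from a typical icefall decode.py and are not shown in the code above:

    # Hypothetical sketch of using the compiled FSA as the decoding graph;
    # variable names outside `align_fsa` are assumptions, not from the issue.
    decoding_graph = k2.arc_sort(align_fsa).to(device)
    hyp_tokens = fast_beam_search_one_best(
        model=model,
        decoding_graph=decoding_graph,
        encoder_out=encoder_out,
        encoder_out_lens=encoder_out_lens,
        beam=params.beam,
        max_contexts=params.max_contexts,
        max_states=params.max_states,
    )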