k2-fsa / k2

FSA/FST algorithms, differentiable, with PyTorch compatibility.
https://k2-fsa.github.io/k2
Apache License 2.0

pruned_ragged_to_lattice #1163

Closed glynpu closed 1 year ago

glynpu commented 1 year ago

Hi Dan, this is the code we were just discussing. @danpovey Could we get the alignment you mentioned from this lattice?

We probably need more unit tests to make sure it works as expected. So far it has only been checked by eye, by Xiaoyu and me, using the following code.

import k2
import torch

# Pruned ranges, shape (B=2, T=4, S=5).
ranges = torch.tensor(
    [[[0, 1, 2, 3, 4], [1, 2, 3, 4, 5], [2, 3, 4, 5, 6], [3, 4, 5, 6, 7]],
     [[0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [3, 4, 5, 6, 7], [3, 4, 5, 6, 7]]]
).to(torch.int32)
# Number of frames in each sequence, shape (B,).
x_lens = torch.tensor([4, 3]).to(torch.int32)
# Label sequences, shape (B, 7).
y = torch.tensor([[8, 7, 6, 5, 4, 3, 2],
                  [2, 3, 4, 5, 6, 7, 8]]).to(torch.int32)
# Logits, shape (B, T, S, V=9); the offsets below make every (b, t, s, v) entry distinct.
logits = torch.tensor([0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]).expand(2, 4, 5, 9).to(torch.float32)
logits = logits + torch.tensor([10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0]).reshape(2, 4).unsqueeze(-1).unsqueeze(-1)
logits = logits + torch.tensor([0.0, 1, 2, 3, 4]).unsqueeze(0).unsqueeze(0).unsqueeze(-1)

ofsa, arc_map = k2.self_alignment(ranges, x_lens, y, logits)
lattice = k2.Fsa(ofsa)
# Re-gather the scores through arc_map so that autograd tracks them.
scores_tracked_by_autograd = k2.index_select(logits.reshape(-1), arc_map)
assert torch.all(lattice.scores == scores_tracked_by_autograd)

lattice.scores = scores_tracked_by_autograd

Lattice generated:

[image: rendering of the generated lattice]

danpovey commented 1 year ago

So I assume best_path would be done after this. There definitely should be a way to get the alignment info, either by somehow tracing back the arc_maps (I don't remember whether these are made available), or perhaps more elegantly, by attaching some kind of integer properties about the frame indexes and text-position indexes to the FSA at the point we create it, and then accessing those after best-path. Perhaps someone can figure out how to do this.
I think this is something we will find a lot of uses for, but we need to have a function in icefall that makes it available in an easy way. We could perhaps have an option to the RNN-T training code, to return the alignment.
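
One way to picture the "tracing back the arc_maps" option: in the snippet earlier in this thread, arc_map indexes into logits.reshape(-1), so each entry can be unravelled into a (batch, frame, range slot, symbol) tuple, and the text position then comes from ranges. The following is a minimal plain-Python sketch of that arithmetic, not code from this PR. The names `ArcPosition` and `decode_arc_index` are hypothetical, and it assumes that a negative arc_map entry marks an arc with no logit behind it (for example an arc entering the final state).

```python
from typing import NamedTuple, Optional, Sequence


class ArcPosition(NamedTuple):
    batch: int       # sequence index b
    frame: int       # frame index t
    range_slot: int  # offset s inside the pruned range of frame t
    symbol: int      # index v along the last axis of logits
    text_pos: int    # text position, looked up as ranges[b][t][s]


def decode_arc_index(
    flat_index: int,
    logits_shape: Sequence[int],
    ranges: Sequence[Sequence[Sequence[int]]],
) -> Optional[ArcPosition]:
    """Undo logits.reshape(-1) for a single arc_map entry (hypothetical helper)."""
    if flat_index < 0:
        # Assumption: negative entries have no corresponding logit.
        return None
    _, T, S, V = logits_shape
    rest, v = divmod(flat_index, V)
    rest, s = divmod(rest, S)
    b, t = divmod(rest, T)
    return ArcPosition(b, t, s, v, ranges[b][t][s])


# Same ranges as in the example above, written as nested lists.
ranges = [
    [[0, 1, 2, 3, 4], [1, 2, 3, 4, 5], [2, 3, 4, 5, 6], [3, 4, 5, 6, 7]],
    [[0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [3, 4, 5, 6, 7], [3, 4, 5, 6, 7]],
]
position = decode_arc_index(301, (2, 4, 5, 9), ranges)
```

Applying this to the arc indexes that survive best-path would give frame and text-position pairs. The alternative suggested above, attaching the frame and text-position indexes as integer attributes when the FSA is created, would avoid relying on the memory layout of logits.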

marcoyang1998 commented 1 year ago

> have a function in icefall

Actually, I have been using this feature for a while, and I have a Python function that converts the alignment to timestamp information. I will make a PR in icefall.
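
For readers who have not seen such a conversion, here is a rough sketch of the idea. It is not the function mentioned above. It assumes the alignment is already a list of (frame, token) pairs, and the name `alignment_to_timestamps`, the 10 ms frame shift, and the subsampling factor of 4 are illustrative assumptions rather than values taken from this thread.

```python
from typing import List, Tuple


def alignment_to_timestamps(
    alignment: List[Tuple[int, int]],
    frame_shift_ms: int = 10,     # assumed feature frame shift
    subsampling_factor: int = 4,  # assumed encoder subsampling
) -> List[Tuple[int, int]]:
    """Map (frame, token) pairs to (token, start time in milliseconds)."""
    ms_per_frame = frame_shift_ms * subsampling_factor
    return [(token, frame * ms_per_frame) for frame, token in alignment]


# Three tokens emitted on encoder frames 0, 1 and 3.
timestamps = alignment_to_timestamps([(0, 8), (1, 7), (3, 6)])
```

Working in integer milliseconds keeps the result exact; a caller that wants seconds can divide by 1000 at the end.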

danpovey commented 1 year ago

Great! Bear in mind we may need the scores/probabilities as well.

danpovey commented 1 year ago

Looks OK to me from a brief glance! So are we OK to merge it? Let's merge today unless there are any immediate objections?

pkufool commented 1 year ago

@glynpu Could you take a look and check whether the failing cases are related to this PR?

glynpu commented 1 year ago

> @glynpu Could you take a look and check whether the failing cases are related to this PR?

The failing cases are not related to this PR. They are mainly caused by a torch 1.13.1 installation problem.

pkufool commented 1 year ago

I think it is a fairly safe and independent change (i.e. will not affect other functions), merging now.