angusturner opened this issue 1 year ago
Actually, while I'm here, could you also clarify how to interpret the dimensions referenced in the docs?
> event_shape (N x M x 3), e.g. phi(i, j, op)
> Ops are 0 -> j-1, 1 -> i-1, j-1, and 2 -> i-1
I am a bit confused about how to interpret this. For example, should phi[i, j, 0] be read as something like "given that we are at frame j in state i, what is the log-probability that we arrived from (i, j-1)?"
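One way to read the docs (an assumption, not confirmed by the library): phi[i, j, op] is the score for *entering* cell (i, j) from the predecessor named by op, so a path's score is the sum of phi over its moves, indexed by the destination cell. A minimal plain-Python sketch of that reading; `path_score` and `STEP_TO_OP` are hypothetical names, and the start cell (0, 0) is also assumed.

```python
# Hypothetical helper, not part of the library. Assumes phi[i][j][op] scores
# entering cell (i, j) from the predecessor named by op:
#   op 0: (i, j-1)    op 1: (i-1, j-1)    op 2: (i-1, j)
STEP_TO_OP = {(0, 1): 0, (1, 1): 1, (1, 0): 2}  # (di, dj) of a step -> op


def path_score(phi, path):
    """Sum of phi over the moves of `path`, a list of (i, j) cells starting at (0, 0)."""
    assert path[0] == (0, 0), "paths are assumed to start at (0, 0)"
    total = 0.0
    for (pi, pj), (i, j) in zip(path, path[1:]):
        op = STEP_TO_OP[(i - pi, j - pj)]  # KeyError for a disallowed step
        total += phi[i][j][op]  # the score is indexed by the destination cell
    return total


# Toy 2 x 3 lattice with phi[i][j][op] = 9*i + 3*j + op.
phi = [[[9 * i + 3 * j + op for op in range(3)] for j in range(3)] for i in range(2)]
print(path_score(phi, [(0, 0), (0, 1), (1, 2)]))  # phi[0][1][0] + phi[1][2][1] = 3 + 16 -> 19.0
```

Under this reading, phi[i, j, 0] is simply the (unnormalized) score of the move (i, j-1) -> (i, j); whether it behaves like a log-probability depends on how the model is normalized.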
Oh interesting. Yes, I should update these examples for PyTorch 2; that might speed things up a lot.
> Furthermore, I wonder if there is a bug in the argmax decoding shown in the CTC notebook, where it seems like one of the frames is aligned to two characters? (Unless I am misinterpreting this plot).
I don't think it's a bug. You're right that in speech you would never want this to happen, but the way I set up the problem simply doesn't forbid it. You could forbid it by setting the score of the "down step" move from (i-1, j), i.e. op 2, to -inf.
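A minimal sketch of the -inf suggestion, in plain Python rather than the library's API: a max-plus (Viterbi) pass over the same lattice with every op 2 score set to -inf. The op convention (op 0 from (i, j-1), op 1 from (i-1, j-1), op 2 from (i-1, j)), the start cell (0, 0), the end cell (N-1, M-1) and the `viterbi` helper are all assumptions for illustration.

```python
import math
import random

NEG_INF = -math.inf
# Assumed convention: op -> offset of the predecessor of (i, j).
PREDS = {0: (0, -1), 1: (-1, -1), 2: (-1, 0)}


def viterbi(phi, N, M):
    """Max-plus pass from (0, 0) to (N-1, M-1); returns (best score, path)."""
    best = [[NEG_INF] * M for _ in range(N)]
    back = [[None] * M for _ in range(N)]
    best[0][0] = 0.0  # assumed: paths start at (0, 0) with score 0
    for j in range(M):
        for i in range(N):
            for op, (di, dj) in PREDS.items():
                pi, pj = i + di, j + dj
                if pi < 0 or pj < 0:
                    continue
                cand = best[pi][pj] + phi[i][j][op]
                if cand > best[i][j]:
                    best[i][j], back[i][j] = cand, (pi, pj)
    # Follow back-pointers from the end cell to (0, 0).
    path, cell = [], (N - 1, M - 1)
    while cell is not None:
        path.append(cell)
        cell = back[cell[0]][cell[1]]
    return best[N - 1][M - 1], path[::-1]


random.seed(0)
N, M = 3, 6  # 3 characters, 6 frames (assumes M >= N so a path survives)
phi = [[[random.gauss(0.0, 1.0) for _ in range(3)] for _ in range(M)] for _ in range(N)]
for row in phi:
    for op_scores in row:
        op_scores[2] = NEG_INF  # forbid the "down step" from (i-1, j)

score, path = viterbi(phi, N, M)
# Ops 0 and 1 each advance j by one, so every frame now holds exactly one character.
print([j for _, j in path] == list(range(M)))  # True
```

With op 2 masked, every surviving move advances the frame index by exactly one, so the decoded path visits each frame once and no frame can be aligned to two characters.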
> For example, is the interpretation of phi[i, j, 0] something like "Given that we are at frame j in state i, what is the log-probability we arrived from i, j-1"?
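The model is specified as a CRF, which means it is globally normalized. So you can set these to log-probs if you want, but they can be any real-valued score. The algorithm computes p(alignment), which is the exponentiated score of the chosen path divided by the sum of exponentiated scores over all paths:

$$p(a) = \frac{\exp\left(\sum_{(i, j, \mathrm{op}) \in a} \phi(i, j, \mathrm{op})\right)}{\sum_{x \in \mathcal{X}} \exp\left(\sum_{(i, j, \mathrm{op}) \in x} \phi(i, j, \mathrm{op})\right)}$$

where $\mathcal{X}$ is the set of all paths and each path is treated as the set of moves (i, j, op) it takes.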
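To make the global normalization concrete, a small plain-Python sketch (not the library's API): it computes log Z with a forward pass over the lattice, checks it against brute-force enumeration of every path, and confirms that p(a) = exp(score(a) - log Z) sums to one even though phi is arbitrary Gaussian noise rather than log-probs. The start cell (0, 0) with score 0, the end cell (N-1, M-1), the op convention and all helper names are assumptions for illustration.

```python
import math
import random

NEG_INF = -math.inf
# Assumed convention: op -> (di, dj) of the step that enters cell (i, j).
MOVES = {0: (0, 1), 1: (1, 1), 2: (1, 0)}


def logsumexp(xs):
    m = max(xs)
    if m == NEG_INF:
        return NEG_INF
    return m + math.log(sum(math.exp(x - m) for x in xs))


def log_partition(phi, N, M):
    """Forward algorithm: log sum over all paths (0,0) -> (N-1,M-1) of exp(score)."""
    alpha = [[NEG_INF] * M for _ in range(N)]
    alpha[0][0] = 0.0  # assumed: paths start at (0, 0) with score 0
    for j in range(M):
        for i in range(N):
            terms = [alpha[i - di][j - dj] + phi[i][j][op]
                     for op, (di, dj) in MOVES.items()
                     if i - di >= 0 and j - dj >= 0]
            if terms:  # empty only for the start cell
                alpha[i][j] = logsumexp(terms)
    return alpha[N - 1][M - 1]


def all_paths(N, M, cell=(0, 0)):
    """Brute force: yield every path from `cell` as a list of (cell, op) moves."""
    if cell == (N - 1, M - 1):
        yield []
        return
    for op, (di, dj) in MOVES.items():
        nxt = (cell[0] + di, cell[1] + dj)
        if nxt[0] < N and nxt[1] < M:
            for rest in all_paths(N, M, nxt):
                yield [(nxt, op)] + rest


random.seed(0)
N, M = 3, 4
# Arbitrary real-valued scores: the CRF does not need them to be log-probs.
phi = [[[random.gauss(0.0, 1.0) for _ in range(3)] for _ in range(M)] for _ in range(N)]
log_z = log_partition(phi, N, M)
scores = [sum(phi[i][j][op] for (i, j), op in p) for p in all_paths(N, M)]
probs = [math.exp(s - log_z) for s in scores]  # p(a) = exp(score(a) - log Z)
print(abs(log_z - logsumexp(scores)) < 1e-9, len(probs), round(sum(probs), 6))
# True 25 1.0
```

The forward pass touches each cell once (O(NM) work), while the number of paths grows combinatorially (25 already for a 3 x 4 lattice), which is why the normalizer is computed by dynamic programming rather than enumeration.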
Speed
I spent a ton of time trying to make this fast in PyTorch, and I think I eventually gave up. I think this JAX implementation is probably a better bet: https://github.com/spetti/SMURF
Hi,
Firstly, I just wanted to say this is a really cool library. I have been working on some CTC/alignment research, and when I saw this trick with the parallel scan and semiring, it struck me as a very elegant solution.
I know the CTC example is a bit out of date (as referenced in other issues), but I am wondering how involved it would be to fix it. I am hoping to compare the answers, partly for my own understanding and partly to see what speedups I can get from the parallel scan + custom CUDA kernels.
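For intuition on why the semiring view allows a parallel scan, a hedged sketch of my own construction in plain Python (not the library's implementation, and nothing like its CUDA kernels). Assuming the op convention from the docs (op 0 from (i, j-1), op 1 from (i-1, j-1), op 2 from (i-1, j)), a start at (0, 0) and an end at (N-1, M-1), each frame j can be folded into one N x N log-semiring matrix T_j that maps column j-1 of the forward table to column j. Because the semiring matrix product is associative, T_{M-1} ... T_1 can be combined in a balanced tree with about log2(M) sequential rounds instead of M. All function names here are hypothetical.

```python
import math
import random

NEG_INF = -math.inf


def logsumexp(xs):
    m = max(xs)
    return m if m == NEG_INF else m + math.log(sum(math.exp(x - m) for x in xs))


def log_matmul(A, B):
    """Matrix product in the log semiring: C[i][k] = logsumexp_m A[i][m] + B[m][k]."""
    return [[logsumexp([A[i][m] + B[m][k] for m in range(len(B))])
             for k in range(len(B[0]))] for i in range(len(A))]


def apply(T, a):
    """Log-semiring matrix-vector product T (x) a."""
    return [logsumexp([t + x for t, x in zip(row, a)]) for row in T]


def first_column(phi, N):
    """alpha[:, 0]: from the assumed start (0, 0), only op-2 down steps are possible."""
    a = [0.0]
    for i in range(1, N):
        a.append(a[-1] + phi[i][0][2])
    return a


def column_transition(phi, j, N):
    """N x N matrix T_j with alpha[:, j] = T_j (x) alpha[:, j-1], for j >= 1."""
    # B: op 0 from (i, j-1) and op 1 from (i-1, j-1) into column j.
    B = [[NEG_INF] * N for _ in range(N)]
    for i in range(N):
        B[i][i] = phi[i][j][0]
        if i >= 1:
            B[i][i - 1] = phi[i][j][1]
    # D: a run of op-2 down steps inside column j, from row k to row i >= k.
    D = [[NEG_INF] * N for _ in range(N)]
    for k in range(N):
        D[k][k] = 0.0
        for i in range(k + 1, N):
            D[i][k] = D[i - 1][k] + phi[i][j][2]
    return log_matmul(D, B)


def tree_reduce(mats):
    """Combine [T_1, ..., T_{M-1}] into T_{M-1} (x) ... (x) T_1 in ~log2(M) rounds."""
    while len(mats) > 1:
        # Products within a round are independent, so they could run in parallel.
        nxt = [log_matmul(mats[k + 1], mats[k]) for k in range(0, len(mats) - 1, 2)]
        if len(mats) % 2:
            nxt.append(mats[-1])
        mats = nxt
    return mats[0]


def log_z_sequential(phi, N, M):
    a = first_column(phi, N)
    for j in range(1, M):
        a = apply(column_transition(phi, j, N), a)
    return a[N - 1]


def log_z_tree(phi, N, M):
    a = first_column(phi, N)
    if M == 1:
        return a[N - 1]
    T = tree_reduce([column_transition(phi, j, N) for j in range(1, M)])
    return apply(T, a)[N - 1]


random.seed(0)
N, M = 3, 8
phi = [[[random.gauss(0.0, 1.0) for _ in range(3)] for _ in range(M)] for _ in range(N)]
print(abs(log_z_sequential(phi, N, M) - log_z_tree(phi, N, M)) < 1e-9)  # True
```

This only computes log Z via a tree reduction (the "reduce" half of a parallel scan); a full scan would also keep the prefix products, which is what you would need for per-frame marginals. Each combine costs an O(N^3) semiring matmul, so the trade is more total work for fewer sequential steps.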
Furthermore, I wonder if there is a bug in the argmax decoding shown in the CTC notebook, where it seems like one of the frames is aligned to two characters? (Unless I am misinterpreting this plot).
Would really appreciate any pointers with this if you get a chance.