harvardnlp / pytorch-struct

Fast, general, and tested differentiable structured prediction in PyTorch
http://harvardnlp.github.io/pytorch-struct
MIT License
1.1k stars 92 forks

Broken Alignment for CTC example? #127

Open angusturner opened 1 year ago

angusturner commented 1 year ago

Hi,

Firstly, just wanted to say this is a really cool library. I have been working on some CTC/alignment research, and when I saw this trick with the parallel scan and semiring, it struck me as a very elegant solution.

I know the CTC example is a bit out of date (as referenced in other issues), but I am wondering how involved it would be to fix? I am hoping to compare the answers, partly for my own understanding, and partly to see what speedups I can get from the parallel scan + custom CUDA kernels.

Furthermore, I wonder if there is a bug in the argmax decoding shown in the CTC notebook, where it seems like one of the frames is aligned to two characters? (Unless I am misinterpreting this plot).

(Screenshot: Screen Shot 2023-07-05 at 10.50.20 pm)

Would really appreciate any pointers with this if you get a chance.

angusturner commented 1 year ago

Actually, while I'm here, can I also clarify the interpretation of the dimensions referenced in the docs?

> event_shape (N x M x 3), e.g. `phi(i, j, op)`
> Ops are 0 -> (i, j-1), 1 -> (i-1, j-1), and 2 -> (i-1, j)

I am a bit confused about how to interpret this. For example, is the interpretation of phi[i, j, 0] something like "Given that we are at frame j in state i, what is the log-probability we arrived from (i, j-1)"?

srush commented 1 year ago

Oh interesting. Yes, I should update these examples for PyTorch 2. Might speed things up a lot.

> Furthermore, I wonder if there is a bug in the argmax decoding shown in the CTC notebook, where it seems like one of the frames is aligned to two characters? (Unless I am misinterpreting this plot).

I don't think it's a bug. You're right that in speech you would never want this to happen, but the way I set up the problem simply doesn't forbid it. You could do so by setting the potentials for the "down step" motion (i-1, j) to -inf.
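A minimal sketch of that masking, assuming (per the op convention quoted above) that index 2 of the last dimension is the down step; the shapes and tensor here are illustrative, not the notebook's actual potentials:

```python
import torch

# Hypothetical log-potentials for an alignment over N frames x M tokens
# (batch of 1). Per the docs' op convention, the last dim indexes:
#   0 -> came from (i, j-1), 1 -> came from (i-1, j-1), 2 -> came from (i-1, j)
N, M = 8, 5
phi = torch.randn(1, N, M, 3)

# Forbid the "down step" (i-1, j) so a single frame can never be
# aligned to two characters:
phi[..., 2] = float("-inf")
```

Any path using a down step then scores -inf and can never be the argmax.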

> For example, is the interpretation of phi[i, j, 0] something like "Given that we are at frame j in state i, what is the log-probability we arrived from i, j-1" ?

The model is specified as a CRF, which means it is globally normalized. So you can set these as log-probs if you want, but they can be any scores. The algorithm computes p(alignment) = exp(score(alignment)) / \sum_{x \in allpaths} exp(score(x)), where score(x) is the sum of the potentials along path x.
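A tiny sketch of what global normalization means (purely illustrative, not the library's internals): the scores need not be log-probabilities, because normalization happens over all paths at once.

```python
import torch

# Hypothetical unnormalized scores for three candidate alignments.
path_scores = torch.tensor([2.0, -1.0, 0.5])

# p(path) = exp(score(path)) / sum_x exp(score(x)),
# computed stably in log space with logsumexp.
log_Z = torch.logsumexp(path_scores, dim=0)
probs = (path_scores - log_Z).exp()
```

The resulting `probs` sum to 1 even though the inputs were arbitrary scores; pytorch-struct does the same thing, but sums over the exponentially many paths via dynamic programming instead of enumeration.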

Speed

I spent a ton of time trying to make this fast in PyTorch and eventually gave up. I think this one in JAX is probably a better bet: https://github.com/spetti/SMURF.