harvardnlp / pytorch-struct

Fast, general, and tested differentiable structured prediction in PyTorch
http://harvardnlp.github.io/pytorch-struct
MIT License
up sweep and down sweep #85

Closed allanj closed 3 years ago

allanj commented 3 years ago

I'm interested in the parallel scan algorithm for the linear-chain CRF.

I read the related paper in the tutorial and found that there are two steps, an up-sweep and a down-sweep, which together produce the all-prefix-sums.

I think in this case we use that algorithm to obtain Z(x) for sequences of different lengths in a batch. But I couldn't find the down-sweep code in the repo. Can you point me to it?

srush commented 3 years ago

So this is the really neat part of the library. Because it is written in PyTorch, we never code the down-sweep. It is run by autodifferentiation / backprop. All the properties that you would normally collect in that pass can be collected automatically by overloading the backward functions used in the chain rule.
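To make the point above concrete, here is a minimal sketch in plain Python (not code from this repo). It checks the identity that lets backprop stand in for a hand-written down-sweep: the gradient of log Z with respect to an edge score equals that edge's marginal. Everything here is an assumption made for illustration: the names `log_partition`, `edge_marginal_brute_force`, and `edge_marginal_from_gradient` are hypothetical, the chain starts with a uniform unnormalized score, and a central finite difference is used as a stand-in for autograd.

```python
import itertools
import math


def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def log_partition(theta):
    """Forward pass. theta[t][a][b] scores moving from state a to state b at step t."""
    C = len(theta[0])
    alpha = [0.0] * C  # assumption: every start state has unnormalized score 0
    for step in theta:
        alpha = [logsumexp([alpha[a] + step[a][b] for a in range(C)])
                 for b in range(C)]
    return logsumexp(alpha)


def edge_marginal_brute_force(theta, t, a, b):
    """P(state_t = a, state_{t+1} = b), computed by enumerating every path."""
    T, C = len(theta), len(theta[0])
    total = 0.0
    through_edge = 0.0
    for path in itertools.product(range(C), repeat=T + 1):
        weight = math.exp(sum(theta[s][path[s]][path[s + 1]] for s in range(T)))
        total += weight
        if path[t] == a and path[t + 1] == b:
            through_edge += weight
    return through_edge / total


def edge_marginal_from_gradient(theta, t, a, b, eps=1e-5):
    """d log Z / d theta[t][a][b], via a central finite difference (autograd stand-in)."""
    original = theta[t][a][b]
    theta[t][a][b] = original + eps
    up = log_partition(theta)
    theta[t][a][b] = original - eps
    down = log_partition(theta)
    theta[t][a][b] = original  # restore the caller's input
    return (up - down) / (2 * eps)
```

For any small `theta`, the two marginal functions should agree up to finite-difference error. With real autograd, one backward call on log Z would give the gradient for every edge at once, which is the quantity a hand-coded down-sweep (outside pass) would otherwise be written to collect.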

allanj commented 3 years ago

This sounds pretty exciting. It took me some time to code the down-sweep myself. I would actually like to know more of the details beyond what is in the tutorial. May I know where backward is called in this repo?

srush commented 3 years ago

Sure. I'll try to find you a good reference. I think this is helpful as a start: https://www.cs.jhu.edu/~jason/papers/eisner.spnlp16.pdf

allanj commented 3 years ago

Thanks. I actually read that before. I think I roughly understand the idea behind it, though I understand it from the hypergraph perspective. That makes it a bit difficult for me to follow the logic of the code. Any suggestions on how to get started with understanding the code of this package, besides using it?

srush commented 3 years ago

Good question. Yes, I should write something up about this.

However, if you know hypergraphs, you are already there! Think of parallel scan as turning the linear-chain CRF into a hypergraph. The hypergraph is shaped like a balanced tree. Each node (i, j) is labeled by its state at i and its state at j + 1, and a hyperedge joins together two interlocking nodes. So (0-a, 0-b) (1-b, 1-c) => (0-a, 1-c) and (2-c, 2-d) (3-d, 3-e) => (2-c, 3-e), and then (0-a, 1-c) (2-c, 3-e) => (0-a, 3-e).

Under this hypergraph, the inside algorithm is the standard parallel scan up-sweep and the outside algorithm is the down-sweep. Time is linear in the size of the hypergraph, so it is T log T (count the hyperedges above).
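The balanced-tree picture above can be sketched in plain Python. This is only an illustration, not the library's implementation: `log_matmul`, `tree_reduce`, and `sequential_reduce` are hypothetical names, each chain step is a C x C matrix of log-scores, and joining two interlocking nodes is modeled as a matrix product in the log semiring (the shared middle state is summed out). An odd leftover span is simply carried up a level, which is one possible convention among several.

```python
import math


def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def log_matmul(A, B):
    """Join two adjacent spans: out[a][c] = logsumexp over b of A[a][b] + B[b][c]."""
    inner = len(B)
    return [[logsumexp([A[a][b] + B[b][c] for b in range(inner)])
             for c in range(len(B[0]))]
            for a in range(len(A))]


def tree_reduce(chain):
    """Up-sweep: merge neighbouring spans pairwise, one tree level at a time."""
    level = list(chain)
    while len(level) > 1:
        merged = [log_matmul(level[i], level[i + 1])
                  for i in range(0, len(level) - 1, 2)]
        if len(level) % 2 == 1:
            merged.append(level[-1])  # odd span out is carried up unchanged
        level = merged
    return level[0]


def sequential_reduce(chain):
    """Left-to-right reference: the usual forward recursion over whole matrices."""
    out = chain[0]
    for step in chain[1:]:
        out = log_matmul(out, step)
    return out
```

Because `log_matmul` is associative, `tree_reduce` and `sequential_reduce` return the same root matrix up to floating-point error, and taking `logsumexp` over all of its entries gives log Z for a chain with free start and end states. The merges within one level are independent of each other, which is what allows a tensor implementation to run each level as a single batched operation.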