Closed allanj closed 3 years ago
So this is the really neat part of the library. Because it is written in PyTorch, we never code the downsweep. It is run by autodifferentiation / backprop. All the properties that you would normally collect in that pass can be automatically collected by overloading the `backward` functions used in the chain rule.
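To make the claim concrete, here is a minimal sketch of the identity that backprop exploits: the gradient of log Z with respect to each log-potential is the marginal of that edge. This is plain Python and not code from this repo. The names (`log_partition`, `grad_marginal`, `brute_force_marginal`) and the toy potentials are made up for illustration, and a finite difference stands in for autograd so the check runs without PyTorch.

```python
import copy
import itertools
import math


def logsumexp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def log_partition(log_pot):
    """Forward (inside) pass only. log_pot[t][a][b] scores tag a at t, tag b at t + 1."""
    num_tags = len(log_pot[0])
    alpha = [0.0] * num_tags
    for step in log_pot:
        alpha = [
            logsumexp([alpha[a] + step[a][b] for a in range(num_tags)])
            for b in range(num_tags)
        ]
    return logsumexp(alpha)


def grad_marginal(log_pot, t, a, b, eps=1e-5):
    """Central-difference estimate of d log Z / d log_pot[t][a][b] (stand-in for autograd)."""
    plus, minus = copy.deepcopy(log_pot), copy.deepcopy(log_pot)
    plus[t][a][b] += eps
    minus[t][a][b] -= eps
    return (log_partition(plus) - log_partition(minus)) / (2 * eps)


def brute_force_marginal(log_pot, t, a, b):
    """Edge marginal by enumerating every tag sequence (only feasible for toy sizes)."""
    num_tags, length = len(log_pot[0]), len(log_pot)
    all_scores, matching = [], []
    for path in itertools.product(range(num_tags), repeat=length + 1):
        score = sum(log_pot[i][path[i]][path[i + 1]] for i in range(length))
        all_scores.append(score)
        if path[t] == a and path[t + 1] == b:
            matching.append(score)
    return math.exp(logsumexp(matching) - logsumexp(all_scores))


# Toy chain: 3 transitions, 2 tags (arbitrary illustrative scores).
log_pot = [
    [[0.1, -0.4], [0.7, 0.2]],
    [[-0.3, 0.5], [0.0, 1.1]],
    [[0.9, -0.2], [0.3, -0.6]],
]

for t in range(len(log_pot)):
    for a in range(2):
        for b in range(2):
            g = grad_marginal(log_pot, t, a, b)
            m = brute_force_marginal(log_pot, t, a, b)
            print(t, a, b, round(g, 6), round(m, 6))
```

The two printed columns should agree up to finite-difference error. With autograd the same numbers come out of a single backward call on log Z, which is why no explicit downsweep needs to be written.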
This sounds pretty exciting; it took me some time to code the downsweep myself. I'd like to know more about the details beyond what is in the tutorial.
May I know where `backward` is called in this repo?
Sure. I'll try to find you a good reference. I think this is helpful as a start https://www.cs.jhu.edu/~jason/papers/eisner.spnlp16.pdf
Thanks. I actually read that before. I think I roughly understand the idea, though I understand it from the hypergraph perspective, which makes it a bit difficult for me to follow the logic of the code. Any suggestions for getting started on understanding the code of this package, besides using it?
Good question. Yes I should write something up about this.
However, if you know hypergraphs, you are already there! Think of parallel scan as turning a linear-chain CRF into a hypergraph. The hypergraph is shaped like a balanced tree. Each node (i, j) is labeled by its i and j + 1 states, and a hyperedge joins together a pair of interlocking nodes. So (0-a, 0-b) (1-b, 1-c) => (0-a, 1-c) and (2-c, 2-d) (3-d, 3-e) => (2-c, 3-e) and then (0-a, 1-c) (2-c, 3-e) => (0-a, 3-e).
Under this hypergraph, the inside algorithm is the standard parallel scan up-sweep and the outside algorithm is the down-sweep. Time is linear in the size of the hypergraph, so it is T log T (count the hyperedges above).
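The balanced-tree picture can be sketched in a few lines of plain Python. This is only an illustration under my own assumptions, not the library's implementation: `combine` and `tree_upsweep` are hypothetical names, each node is stored as a tags-by-tags matrix of log scores, and the merges inside one level are written as a loop even though they are independent and could run in parallel.

```python
import math


def logsumexp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def combine(left, right):
    """Hyperedge: (i-a, k-b) and (k+1-b, j-c) => (i-a, j-c), summing out the shared tag b."""
    n = len(left)
    return [
        [logsumexp([left[a][b] + right[b][c] for b in range(n)]) for c in range(n)]
        for a in range(n)
    ]


def tree_upsweep(log_pot):
    """Inside pass over the balanced tree. Returns (log Z, number of tree levels)."""
    level, depth = list(log_pot), 0
    while len(level) > 1:
        merged = [combine(level[i], level[i + 1]) for i in range(0, len(level) - 1, 2)]
        if len(level) % 2:  # odd node out is carried up unchanged
            merged.append(level[-1])
        level, depth = merged, depth + 1
    root = level[0]
    return logsumexp([score for row in root for score in row]), depth


def sequential_log_partition(log_pot):
    """Ordinary left-to-right forward algorithm, used here only as a reference."""
    n = len(log_pot[0])
    alpha = [0.0] * n
    for step in log_pot:
        alpha = [logsumexp([alpha[a] + step[a][b] for a in range(n)]) for b in range(n)]
    return logsumexp(alpha)


# Four transitions (positions 0..3 in the example above), 2 tags, arbitrary scores.
log_pot = [
    [[0.1, -0.4], [0.7, 0.2]],
    [[-0.3, 0.5], [0.0, 1.1]],
    [[0.9, -0.2], [0.3, -0.6]],
    [[0.4, 0.8], [-1.0, 0.6]],
]

log_z_tree, depth = tree_upsweep(log_pot)
print(log_z_tree, sequential_log_partition(log_pot), depth)
```

Both routes should give the same log Z. The tree version needs only two levels for four transitions, which is where the parallel speedup comes from.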
I'm interested in the parallel scan algorithm for the linear-chain CRF.
I read the related paper in the tutorial and found that there are two steps, an up sweep and a down sweep, to obtain the all-prefix-sums.
I think in this case, we use that algorithm to obtain all Z(x) with different lengths in a batch. But it seems I couldn't find the down sweep code in the repo. Can you point me to it?
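For readers who have not seen the two-step scan mentioned here, this is a small sketch of the textbook work-efficient exclusive scan (Blelloch style) on plain integers. It is not code from the repo. The function name is made up, the input length is assumed to be a power of two, and the inner loops are sequential stand-ins for steps that would run in parallel.

```python
def blelloch_exclusive_scan(values):
    """All-prefix-sums via an up sweep followed by a down sweep (exclusive scan)."""
    n = len(values)
    if n == 0 or n & (n - 1):
        raise ValueError("this sketch assumes a power-of-two length")
    tree = list(values)

    # Up sweep: build partial sums up a balanced binary tree, in place.
    stride = 1
    while stride < n:
        for i in range(2 * stride - 1, n, 2 * stride):
            tree[i] += tree[i - stride]
        stride *= 2

    # Down sweep: push prefixes back down; each node hands its prefix to the
    # left child and prefix + left subtree sum to the right child.
    tree[n - 1] = 0
    stride = n // 2
    while stride >= 1:
        for i in range(2 * stride - 1, n, 2 * stride):
            left_sum = tree[i - stride]
            tree[i - stride] = tree[i]
            tree[i] += left_sum
        stride //= 2
    return tree


print(blelloch_exclusive_scan([3, 1, 7, 0, 4, 1, 6, 3]))
# -> [0, 3, 4, 11, 11, 15, 16, 22]
```

Replacing `+` with a semiring matrix product over log-potentials gives the CRF version of the same pattern.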