tsujuifu / pytorch_graph-rel

A PyTorch implementation of GraphRel
MIT License
268 stars, 54 forks

There is a small bug in the construction of dep_bw (backward dependency parse tree) #26

Closed: jinglin-liang closed this issue 2 years ago

jinglin-liang commented 2 years ago

https://github.com/tsujuifu/pytorch_graph-rel/blob/9f0697d6fdcf2d7e013a98ed4c10f43bb86f105c/dataset.py#L42

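The row normalization at the line linked above does not actually leave dep_bw normalized row by row. dep_bw[j][i] is written while processing token i, so row j can still change after it has already been normalized. I think it would make more sense to normalize both matrices once, after the loop: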

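        for i, w in enumerate(res):
            inp_sent[i], inp_pos[i] = w.vector, self.POS[w.tag_]

            # self-loops in both directions
            dep_fw[i][i], dep_bw[i][i] = 1, 1
            # dep_fw: head -> child edge in row i; dep_bw: child -> head edge in row j
            for c in res[i].children:
                for j, t in enumerate(res):
                    if c == t:
                        dep_fw[i][j], dep_bw[j][i] = 1, 1
            # removed: normalizing here is too early, because dep_bw[j][i] above
            # writes into row j, which may already have been normalized
            # dep_fw[i], dep_bw[i] = dep_fw[i]/sum(dep_fw[i]), dep_bw[i]/sum(dep_bw[i])

        # normalize every populated row once all edges are set; rows past the
        # sentence length (if any) stay all zero and are skipped to avoid NaN
        n = len(res)
        dep_fw[:n] = dep_fw[:n] / dep_fw[:n].sum(axis=-1, keepdims=True)
        dep_bw[:n] = dep_bw[:n] / dep_bw[:n].sum(axis=-1, keepdims=True)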
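To see the difference concretely, here is a small self-contained NumPy sketch. It is not the repository's code: instead of spaCy tokens it assumes a hypothetical `heads` list (the head index of each token, -1 for the root), and `build_adjacency` and its `normalize_in_loop` flag are illustrative names. It fills both matrices the same way as the snippet above and normalizes rows either inside the loop (the old behavior) or once at the end (the proposed fix).

```python
import numpy as np


def build_adjacency(heads, max_len, normalize_in_loop):
    """Build forward/backward dependency matrices from head indices.

    heads[k] is the index of token k's head, or -1 for the root. This is a
    hypothetical stand-in for iterating over spaCy's token.children.
    """
    n = len(heads)
    dep_fw = np.zeros((max_len, max_len))
    dep_bw = np.zeros((max_len, max_len))
    for i in range(n):
        dep_fw[i][i], dep_bw[i][i] = 1, 1  # self-loops
        for j in range(n):
            if heads[j] == i:  # token j is a child of token i
                dep_fw[i][j], dep_bw[j][i] = 1, 1
        if normalize_in_loop:
            # old approach: normalize row i right away
            dep_fw[i] = dep_fw[i] / dep_fw[i].sum()
            dep_bw[i] = dep_bw[i] / dep_bw[i].sum()
    if not normalize_in_loop:
        # proposed fix: normalize every populated row after all edges exist
        dep_fw[:n] = dep_fw[:n] / dep_fw[:n].sum(axis=-1, keepdims=True)
        dep_bw[:n] = dep_bw[:n] / dep_bw[:n].sum(axis=-1, keepdims=True)
    return dep_fw, dep_bw


# token 1 is the root; tokens 0 and 2 both attach to it
heads = [1, -1, 1]
_, bw_old = build_adjacency(heads, max_len=4, normalize_in_loop=True)
_, bw_new = build_adjacency(heads, max_len=4, normalize_in_loop=False)
print(bw_old[:3].sum(axis=-1))  # [2. 1. 1.]
print(bw_new[:3].sum(axis=-1))  # [1. 1. 1.]
```

Token 0 precedes its head, so its dep_bw row gets the edge to token 1 only after it was normalized, and the row sums to 2. dep_fw is unaffected either way, since row i of dep_fw is only written during iteration i.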
tsujuifu commented 2 years ago

Thanks for pointing this out! I have updated the code here.

jinglin-liang commented 2 years ago

You're welcome. Thanks for sharing; your code has helped me a lot in my learning.