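The row normalization at line 42 of `dataset.py` does not normalize `dep_bw` on a per-row basis. The assignment `dep_bw[j][i] = 1` writes into row `j`, not row `i`. So when a child token `j` comes before its head `i`, row `j` has already been normalized during iteration `j` and then receives an extra 1 during iteration `i`, leaving it with a sum greater than 1. (`dep_fw` is not affected, because row `i` of `dep_fw` is only written during iteration `i`.) I think it would make more sense to change it to the following: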
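```python
for i, w in enumerate(res):
    inp_sent[i], inp_pos[i] = w.vector, self.POS[w.tag_]
    dep_fw[i][i], dep_bw[i][i] = 1, 1  # self-loops
    for c in res[i].children:
        for j, t in enumerate(res):
            if c == t:
                # forward edge head i -> child j; backward edge is written into row j
                dep_fw[i][j], dep_bw[j][i] = 1, 1
    # dep_fw[i], dep_bw[i] = dep_fw[i]/sum(dep_fw[i]), dep_bw[i]/sum(dep_bw[i])
    # Re-normalize every row filled so far, since an earlier row of dep_bw
    # may have just received its head entry in this iteration.
    dep_fw[:i+1] = dep_fw[:i+1] / dep_fw[:i+1].sum(axis=-1, keepdims=True)
    dep_bw[:i+1] = dep_bw[:i+1] / dep_bw[:i+1].sum(axis=-1, keepdims=True)
```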
Link to the original line: https://github.com/tsujuifu/pytorch_graph-rel/blob/9f0697d6fdcf2d7e013a98ed4c10f43bb86f105c/dataset.py#L42
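To make the problem concrete, here is a small standalone NumPy sketch. It does not use spaCy: the parse is replaced by a hypothetical `HEADS` list (the head index of each token, `-1` for the root), and `MAX_LEN` is an assumed padding length. `build` mimics the loop above with either the original line or the proposed cumulative re-normalization, and the hypothetical `build_normalize_after` shows a variant that normalizes once after the loop.

```python
import numpy as np

# Toy dependency parse (hypothetical, stands in for spaCy's res / Token.children):
# HEADS[k] is the index of token k's head, -1 marks the root.
# Tokens 0 and 3 come before their heads (1 and 4), which triggers the problem.
HEADS = [1, -1, 1, 4, 1]
MAX_LEN = 8  # assumed padding length; rows >= len(HEADS) stay all-zero


def children(heads, i):
    """Indices of the tokens whose head is token i."""
    return [j for j, h in enumerate(heads) if h == i]


def build(heads, cumulative):
    """Fill dep_fw/dep_bw like the dataset loop, normalizing inside the loop."""
    dep_fw = np.zeros((MAX_LEN, MAX_LEN))
    dep_bw = np.zeros((MAX_LEN, MAX_LEN))
    for i in range(len(heads)):
        dep_fw[i][i], dep_bw[i][i] = 1, 1  # self-loops
        for j in children(heads, i):
            dep_fw[i][j], dep_bw[j][i] = 1, 1  # dep_bw writes into row j, not row i
        if cumulative:
            # Proposed change: re-normalize every row filled so far.
            dep_fw[:i + 1] = dep_fw[:i + 1] / dep_fw[:i + 1].sum(axis=-1, keepdims=True)
            dep_bw[:i + 1] = dep_bw[:i + 1] / dep_bw[:i + 1].sum(axis=-1, keepdims=True)
        else:
            # Original line: only row i is normalized, exactly once.
            dep_fw[i], dep_bw[i] = dep_fw[i] / sum(dep_fw[i]), dep_bw[i] / sum(dep_bw[i])
    return dep_fw, dep_bw


def build_normalize_after(heads):
    """Variant: fill everything first, then normalize the non-padding rows once."""
    n = len(heads)
    dep_fw = np.zeros((MAX_LEN, MAX_LEN))
    dep_bw = np.zeros((MAX_LEN, MAX_LEN))
    for i in range(n):
        dep_fw[i][i], dep_bw[i][i] = 1, 1
        for j in children(heads, i):
            dep_fw[i][j], dep_bw[j][i] = 1, 1
    # Only the first n rows: padded rows are all-zero and would give 0/0.
    dep_fw[:n] /= dep_fw[:n].sum(axis=-1, keepdims=True)
    dep_bw[:n] /= dep_bw[:n].sum(axis=-1, keepdims=True)
    return dep_fw, dep_bw


n = len(HEADS)
_, bw_orig = build(HEADS, cumulative=False)
_, bw_fix = build(HEADS, cumulative=True)
print(bw_orig[:n].sum(axis=-1).tolist())  # [2.0, 1.0, 1.0, 2.0, 1.0]
print(bw_fix[:n].sum(axis=-1).tolist())   # [1.0, 1.0, 1.0, 1.0, 1.0]
```

With the original line, rows 0 and 3 of `dep_bw` (the tokens that precede their heads) end up summing to 2. The cumulative re-normalization yields uniform rows here because, in a dependency tree, each `dep_bw` row gains at most one entry (its head) after its own iteration. In this toy case, normalizing the first `len(res)` rows once after the loop gives the same matrices with less repeated work.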