deepakn97 / relationPrediction

ACL 2019: Learning Attention-based Embeddings for Relation Prediction in Knowledge Graphs

What is the role of "mask" in models.py #15

Closed. liuslnlp closed this issue 4 years ago.

liuslnlp commented 4 years ago
mask_indices = torch.unique(batch_inputs[:, 2]).cuda()  # tail entities appearing in the current batch
mask = torch.zeros(self.entity_embeddings.shape[0]).cuda()  # one slot per entity in the graph
mask[mask_indices] = 1.0  # mark only the in-batch entities

I want to know what the role of "mask" is; it doesn't appear in the paper.

roholazandie commented 4 years ago

I think by doing this you make sure that when you update the entity embeddings, you only update the ones that actually appeared in the sparse_gat. For the others, the mask will be zero, and they are passed through with just a multiplication by W_entities.

chauhanjatin10 commented 4 years ago

Hi @WiseDoge and @roholazandie. Thanks for showing interest in our work. @roholazandie is right. Depending on the availability of CUDA memory, one can choose to update only a few selected entities rather than the whole graph. The mask tensor serves the purpose of zeroing out the gradients for the remaining entities, which we don't wish to update in the current batch iteration. I hope this clarifies the issue. Thanks
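
For readers tracing the thread, here is a minimal self-contained sketch of how such a mask can gate the attention output so that only in-batch entities receive an update. The names out_entity, W_entities, and the shapes below are assumptions for illustration (out_entity stands in for the GAT layer's output); this is my reading of the discussion, not the repository's exact code.

import torch

# Hypothetical sizes, for illustration only.
num_entities, dim = 1000, 50

# Stand-ins for the model's parameters (assumed names).
entity_embeddings = torch.nn.Parameter(torch.randn(num_entities, dim))
W_entities = torch.nn.Parameter(torch.randn(dim, dim))

# A batch of (head, relation, tail) index triples.
batch_inputs = torch.randint(0, num_entities, (32, 3))

# Build the mask: 1.0 for entities seen in this batch, 0.0 elsewhere.
mask_indices = torch.unique(batch_inputs[:, 2])
mask = torch.zeros(num_entities)
mask[mask_indices] = 1.0

# Stand-in for the attention layer's output embeddings.
out_entity = torch.randn(num_entities, dim)

# Every entity gets the plain linear projection; only masked (in-batch)
# entities additionally receive the attention update, so gradients for
# the remaining entities flow only through W_entities.
entities_upgraded = entity_embeddings.mm(W_entities)
new_embeddings = entities_upgraded + mask.unsqueeze(-1) * out_entity

The unsqueeze(-1) turns the (num_entities,) mask into a (num_entities, 1) column that broadcasts across the embedding dimension, which is what zeroes the attention term, and hence its gradient, for every entity outside the batch.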