DeepGraphLearning / torchdrug

A powerful and flexible machine learning platform for drug discovery
https://torchdrug.ai/
Apache License 2.0

Question about negative examples in the KnowledgeGraphCompletion class #108

Closed: AlexHex7 closed this issue 2 years ago

AlexHex7 commented 2 years ago

Hi,

In the _strict_negative method of KnowledgeGraphCompletion, if 'A-->B' and 'B-->C' (A, B and C are entities, --> is a relation) are samples in the training set (i.e. self.fact_graph), while 'A-->C' is a sample in the validation set, then I think 'A-->C' will be regarded as a negative sample during training. Is that a problem?

@torch.no_grad()
def _strict_negative(self, pos_h_index, pos_t_index, pos_r_index):
    batch_size = len(pos_h_index)
    any = -torch.ones_like(pos_h_index)

    pattern = torch.stack([pos_h_index, any, pos_r_index], dim=-1)
    pattern = pattern[:batch_size // 2]

    # ==================== Code I Talk About ======================
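    # look up all tails t such that (h, t, r) is already a training fact, and exclude them from the negative candidates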
    edge_index, num_t_truth = self.fact_graph.match(pattern)
    t_truth_index = self.fact_graph.edge_list[edge_index, 1]
    pos_index = functional._size_to_index(num_t_truth)
    t_mask = torch.ones(len(pattern), self.num_entity, dtype=torch.bool, device=self.device)
    t_mask[pos_index, t_truth_index] = 0
    neg_t_candidate = t_mask.nonzero()[:, 1]
    num_t_candidate = t_mask.sum(dim=-1)
    neg_t_index = functional.variadic_sample(neg_t_candidate, num_t_candidate, self.num_negative)
    # =======================================================

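    # the symmetric case: corrupt the head for the second half of the batch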
    pattern = torch.stack([any, pos_t_index, pos_r_index], dim=-1)
    pattern = pattern[batch_size // 2:]
    edge_index, num_h_truth = self.fact_graph.match(pattern)
    h_truth_index = self.fact_graph.edge_list[edge_index, 0]
    pos_index = functional._size_to_index(num_h_truth)
    h_mask = torch.ones(len(pattern), self.num_entity, dtype=torch.bool, device=self.device)
    h_mask[pos_index, h_truth_index] = 0
    neg_h_candidate = h_mask.nonzero()[:, 1]
    num_h_candidate = h_mask.sum(dim=-1)
    neg_h_index = functional.variadic_sample(neg_h_candidate, num_h_candidate, self.num_negative)

    neg_index = torch.cat([neg_t_index, neg_h_index])

    return neg_index
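
To make the scenario concrete, here is a toy sketch in plain PyTorch (not the actual fact_graph.match / variadic_sample calls above; the entity and relation ids are made up) of what I understand the masking to do. Since A-->C is not in self.fact_graph, C is never masked out and can be drawn as a "negative" tail for the query (A, ?, r):

import torch

num_entity = 3  # A=0, B=1, C=2
fact_edges = torch.tensor([[0, 1, 0],   # A --r--> B (training fact)
                           [1, 2, 0]])  # B --r--> C (training fact); A --r--> C is held out for validation

pos_h, pos_r = 0, 0  # positive query (A, ?, r)

# mask out every tail t such that (A, t, r) is already a training fact
t_mask = torch.ones(num_entity, dtype=torch.bool)
known_tails = fact_edges[(fact_edges[:, 0] == pos_h) & (fact_edges[:, 2] == pos_r), 1]
t_mask[known_tails] = False

print(t_mask.nonzero().flatten())  # tensor([0, 2]): C, the validation tail, remains a negative candidate
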
KiddoZhu commented 2 years ago

Yes. A-->C will be considered a valid negative sample, no matter whether strict_negative is set to True or False. Note that filtering out any sample from the validation / test sets is considered test data leakage. That's why we only filter samples using self.fact_graph.

AlexHex7 commented 2 years ago

> Yes. A-->C will be considered a valid negative sample, no matter whether strict_negative is set to True or False. Note that filtering out any sample from the validation / test sets is considered test data leakage. That's why we only filter samples using self.fact_graph.

@KiddoZhu Thanks for your reply. Yes! I quite agree with what you said about test data leakage.

In fact, I'm not familiar with the research area of knowledge graph reasoning. What confuses me is that the sample 'A-->C' is treated as a negative sample during training, while it is actually a positive sample in the validation set, which means the trained model may not predict it as positive during validation. Will this have a significant impact?

KiddoZhu commented 2 years ago

Not much, in my experience. The goal of knowledge graph reasoning is more about denoising / smoothing the knowledge graph with learned embeddings or logic rules. That means we can use the learned model to predict missing links, or to flag some existing links as wrong -- the former is more meaningful in applications, so it is the only setting we evaluate. Besides, it is impossible to avoid using validation links as negative samples unless you know the validation links -- which would itself be data leakage.
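
For reference, here is a rough sketch of the standard filtered ranking protocol used at evaluation time (a simplification, not necessarily the exact code path in this repo): every triple known from train / valid / test is removed from the candidate set when ranking a test query, so a held-out positive like A-->C does not compete as a candidate when other positives are ranked, even though it may have been sampled as a negative during training.

import torch

def filtered_rank(scores, true_tail, known_tails):
    # scores: (num_entity,) model scores for the query (h, ?, r); higher = more plausible
    # known_tails: every t such that (h, t, r) appears in train / valid / test
    mask = torch.ones_like(scores, dtype=torch.bool)
    mask[known_tails] = False  # drop all other known positives from the candidate set
    mask[true_tail] = True     # keep the triple currently being evaluated
    target = scores[true_tail]
    # rank = 1 + number of remaining candidates scored strictly higher than the target
    return 1 + (scores[mask] > target).sum().item()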

AlexHex7 commented 2 years ago

> Not much, in my experience. The goal of knowledge graph reasoning is more about denoising / smoothing the knowledge graph with learned embeddings or logic rules. That means we can use the learned model to predict missing links, or to flag some existing links as wrong -- the former is more meaningful in applications, so it is the only setting we evaluate. Besides, it is impossible to avoid using validation links as negative samples unless you know the validation links -- which would itself be data leakage.

@KiddoZhu I see. Thanks!