zhiweihu1103 / QE-TEMP

[IJCAI2022] Type-aware Embeddings for Multi-Hop Reasoning over Knowledge Graphs

Inductive testing in QE-TEMP #3

Closed: leiloong closed this issue 2 years ago

leiloong commented 2 years ago

I tried the Ind-FB15k-237-V2 dataset with the gqe_temp model. The run configuration:

    2022-09-27 19:00:00 INFO Geo: vec
    2022-09-27 19:00:00 INFO Data Path: ../data-operation/Ind-FB15k-237-V2
    2022-09-27 19:00:00 INFO #entity: 4268
    2022-09-27 19:00:00 INFO #relation: 400
    2022-09-27 19:00:00 INFO #max steps: 1000000
    2022-09-27 19:00:00 INFO Evaluate unoins using: DNF

However, the test results are unacceptable: validation looks reasonable, but every test metric is close to zero.

    2022-09-28 08:38:47 INFO Evaluating the model... (0/1109)
    2022-09-28 08:39:03 INFO Evaluating the model... (1000/1109)
    2022-09-28 08:39:05 INFO Valid 1p MRR at step 120000: 0.517244
    2022-09-28 08:39:05 INFO Valid 1p HITS1 at step 120000: 0.405023
    2022-09-28 08:39:05 INFO Valid 1p HITS3 at step 120000: 0.583778
    2022-09-28 08:39:05 INFO Valid 1p HITS10 at step 120000: 0.725092
    2022-09-28 08:39:05 INFO Valid 1p num_queries at step 120000: 1738.000000
    2022-09-28 08:39:05 INFO Valid 2p MRR at step 120000: 0.279630
    2022-09-28 08:39:05 INFO Valid 2p HITS1 at step 120000: 0.187167
    2022-09-28 08:39:05 INFO Valid 2p HITS3 at step 120000: 0.308840
    2022-09-28 08:39:05 INFO Valid 2p HITS10 at step 120000: 0.470859
    2022-09-28 08:39:05 INFO Valid 2p num_queries at step 120000: 2000.000000
    2022-09-28 08:39:05 INFO Valid 3p MRR at step 120000: 0.277094
    2022-09-28 08:39:05 INFO Valid 3p HITS1 at step 120000: 0.189872
    2022-09-28 08:39:05 INFO Valid 3p HITS3 at step 120000: 0.300459
    2022-09-28 08:39:05 INFO Valid 3p HITS10 at step 120000: 0.455469
    2022-09-28 08:39:05 INFO Valid 3p num_queries at step 120000: 2000.000000
    2022-09-28 08:39:05 INFO Valid 2i MRR at step 120000: 0.532246
    2022-09-28 08:39:05 INFO Valid 2i HITS1 at step 120000: 0.415528
    2022-09-28 08:39:05 INFO Valid 2i HITS3 at step 120000: 0.599195
    2022-09-28 08:39:05 INFO Valid 2i HITS10 at step 120000: 0.749521
    2022-09-28 08:39:05 INFO Valid 2i num_queries at step 120000: 2000.000000
    2022-09-28 08:39:05 INFO Valid 3i MRR at step 120000: 0.694548
    2022-09-28 08:39:05 INFO Valid 3i HITS1 at step 120000: 0.590300
    2022-09-28 08:39:05 INFO Valid 3i HITS3 at step 120000: 0.770055
    2022-09-28 08:39:05 INFO Valid 3i HITS10 at step 120000: 0.873482
    2022-09-28 08:39:05 INFO Valid 3i num_queries at step 120000: 2000.000000
    2022-09-28 08:39:05 INFO Valid pi MRR at step 120000: 0.406914
    2022-09-28 08:39:05 INFO Valid pi HITS1 at step 120000: 0.312425
    2022-09-28 08:39:05 INFO Valid pi HITS3 at step 120000: 0.440405
    2022-09-28 08:39:05 INFO Valid pi HITS10 at step 120000: 0.586286
    2022-09-28 08:39:05 INFO Valid pi num_queries at step 120000: 2000.000000
    2022-09-28 08:39:05 INFO Valid ip MRR at step 120000: 0.303818
    2022-09-28 08:39:05 INFO Valid ip HITS1 at step 120000: 0.244090
    2022-09-28 08:39:05 INFO Valid ip HITS3 at step 120000: 0.316942
    2022-09-28 08:39:05 INFO Valid ip HITS10 at step 120000: 0.419875
    2022-09-28 08:39:05 INFO Valid ip num_queries at step 120000: 2000.000000
    2022-09-28 08:39:05 INFO Valid 2u-DNF MRR at step 120000: 0.118458
    2022-09-28 08:39:05 INFO Valid 2u-DNF HITS1 at step 120000: 0.054386
    2022-09-28 08:39:05 INFO Valid 2u-DNF HITS3 at step 120000: 0.123657
    2022-09-28 08:39:05 INFO Valid 2u-DNF HITS10 at step 120000: 0.238801
    2022-09-28 08:39:05 INFO Valid 2u-DNF num_queries at step 120000: 2000.000000
    2022-09-28 08:39:05 INFO Valid up-DNF MRR at step 120000: 0.156437
    2022-09-28 08:39:05 INFO Valid up-DNF HITS1 at step 120000: 0.079643
    2022-09-28 08:39:05 INFO Valid up-DNF HITS3 at step 120000: 0.176843
    2022-09-28 08:39:05 INFO Valid up-DNF HITS10 at step 120000: 0.298912
    2022-09-28 08:39:05 INFO Valid up-DNF num_queries at step 120000: 2000.000000
    2022-09-28 08:39:05 INFO Valid average MRR at step 120000: 0.365154
    2022-09-28 08:39:05 INFO Valid average HITS1 at step 120000: 0.275382
    2022-09-28 08:39:05 INFO Valid average HITS3 at step 120000: 0.402242
    2022-09-28 08:39:05 INFO Valid average HITS10 at step 120000: 0.535366
    2022-09-28 08:39:05 INFO Evaluating on Test Dataset...
    2022-09-28 08:39:07 INFO Evaluating the model... (0/550)
    2022-09-28 08:39:17 INFO Test 1p MRR at step 120000: 0.000296
    2022-09-28 08:39:17 INFO Test 1p HITS1 at step 120000: 0.000000
    2022-09-28 08:39:17 INFO Test 1p HITS3 at step 120000: 0.000000
    2022-09-28 08:39:17 INFO Test 1p HITS10 at step 120000: 0.000000
    2022-09-28 08:39:17 INFO Test 1p num_queries at step 120000: 791.000000
    2022-09-28 08:39:17 INFO Test 2p MRR at step 120000: 0.000296
    2022-09-28 08:39:17 INFO Test 2p HITS1 at step 120000: 0.000000
    2022-09-28 08:39:17 INFO Test 2p HITS3 at step 120000: 0.000000
    2022-09-28 08:39:17 INFO Test 2p HITS10 at step 120000: 0.000000
    2022-09-28 08:39:17 INFO Test 2p num_queries at step 120000: 1000.000000
    2022-09-28 08:39:17 INFO Test 3p MRR at step 120000: 0.000295
    2022-09-28 08:39:17 INFO Test 3p HITS1 at step 120000: 0.000000
    2022-09-28 08:39:17 INFO Test 3p HITS3 at step 120000: 0.000000
    2022-09-28 08:39:17 INFO Test 3p HITS10 at step 120000: 0.000000
    2022-09-28 08:39:17 INFO Test 3p num_queries at step 120000: 1000.000000
    2022-09-28 08:39:17 INFO Test 2i MRR at step 120000: 0.000299
    2022-09-28 08:39:17 INFO Test 2i HITS1 at step 120000: 0.000000
    2022-09-28 08:39:17 INFO Test 2i HITS3 at step 120000: 0.000000
    2022-09-28 08:39:17 INFO Test 2i HITS10 at step 120000: 0.000000
    2022-09-28 08:39:17 INFO Test 2i num_queries at step 120000: 1000.000000
    2022-09-28 08:39:17 INFO Test 3i MRR at step 120000: 0.000300
    2022-09-28 08:39:17 INFO Test 3i HITS1 at step 120000: 0.000000
    2022-09-28 08:39:17 INFO Test 3i HITS3 at step 120000: 0.000000
    2022-09-28 08:39:17 INFO Test 3i HITS10 at step 120000: 0.000000
    2022-09-28 08:39:17 INFO Test 3i num_queries at step 120000: 1000.000000
    2022-09-28 08:39:17 INFO Test pi MRR at step 120000: 0.000299
    2022-09-28 08:39:17 INFO Test pi HITS1 at step 120000: 0.000000
    2022-09-28 08:39:17 INFO Test pi HITS3 at step 120000: 0.000000
    2022-09-28 08:39:17 INFO Test pi HITS10 at step 120000: 0.000000
    2022-09-28 08:39:17 INFO Test pi num_queries at step 120000: 1000.000000
    2022-09-28 08:39:17 INFO Test ip MRR at step 120000: 0.000297
    2022-09-28 08:39:17 INFO Test ip HITS1 at step 120000: 0.000000
    2022-09-28 08:39:17 INFO Test ip HITS3 at step 120000: 0.000000
    2022-09-28 08:39:17 INFO Test ip HITS10 at step 120000: 0.000000
    2022-09-28 08:39:17 INFO Test ip num_queries at step 120000: 1000.000000
    2022-09-28 08:39:17 INFO Test 2u-DNF MRR at step 120000: 0.000299
    2022-09-28 08:39:17 INFO Test 2u-DNF HITS1 at step 120000: 0.000000
    2022-09-28 08:39:17 INFO Test 2u-DNF HITS3 at step 120000: 0.000000
    2022-09-28 08:39:17 INFO Test 2u-DNF HITS10 at step 120000: 0.000000
    2022-09-28 08:39:17 INFO Test 2u-DNF num_queries at step 120000: 1000.000000
    2022-09-28 08:39:17 INFO Test up-DNF MRR at step 120000: 0.000293
    2022-09-28 08:39:17 INFO Test up-DNF HITS1 at step 120000: 0.000000
    2022-09-28 08:39:17 INFO Test up-DNF HITS3 at step 120000: 0.000000
    2022-09-28 08:39:17 INFO Test up-DNF HITS10 at step 120000: 0.000000
    2022-09-28 08:39:17 INFO Test up-DNF num_queries at step 120000: 1000.000000
    2022-09-28 08:39:17 INFO Test average MRR at step 120000: 0.000297
    2022-09-28 08:39:17 INFO Test average HITS1 at step 120000: 0.000000
    2022-09-28 08:39:17 INFO Test average HITS3 at step 120000: 0.000000
    2022-09-28 08:39:17 INFO Test average HITS10 at step 120000: 0.000000

Are there any settings I should change in the script besides the data file paths?

zhiweihu1103 commented 2 years ago

For the inductive test, a small modification to the code is required. Taking GQE as an example, some code needs to be added to forward_vec. I will post my original code here; you can adapt it for the other models, such as Q2B, BetaE, and LogicE. The main change is that positive_embedding and negative_embedding are no longer looked up in entity_embedding but computed from each entity's type information.

    def forward_vec(self, positive_sample, negative_sample, subsampling_weight, batch_queries_dict, batch_idxs_dict):
        all_center_embeddings, all_idxs = [], []
        all_union_center_embeddings, all_union_idxs = [], []
        for query_structure in batch_queries_dict:
            if 'u' in self.query_name_dict[query_structure]:
                center_embedding, _ = self.embed_query_vec(
                    self.transform_union_query(batch_queries_dict[query_structure], query_structure),
                    self.transform_union_structure(query_structure), 0)
                all_union_center_embeddings.append(center_embedding)
                all_union_idxs.extend(batch_idxs_dict[query_structure])
            else:
                center_embedding, _ = self.embed_query_vec(batch_queries_dict[query_structure], query_structure, 0)
                all_center_embeddings.append(center_embedding)
                all_idxs.extend(batch_idxs_dict[query_structure])

        if len(all_center_embeddings) > 0:
            all_center_embeddings = torch.cat(all_center_embeddings, dim=0).unsqueeze(1)
        if len(all_union_center_embeddings) > 0:
            all_union_center_embeddings = torch.cat(all_union_center_embeddings, dim=0).unsqueeze(1)
            all_union_center_embeddings = all_union_center_embeddings.view(all_union_center_embeddings.shape[0]//2, 2, 1, -1)

        if subsampling_weight is not None:
            subsampling_weight = subsampling_weight[all_idxs+all_union_idxs]

        if positive_sample is not None:
            if len(all_center_embeddings) > 0:
                positive_sample_regular = positive_sample[all_idxs]

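                # Transductive version (per-entity lookup), disabled for the inductive test:
                # positive_embedding = torch.index_select(self.entity_embedding, dim=0, index=positive_sample_regular)
                # Inductive version: represent each answer entity by its types.
                # type_id: (batch_size, num_types) type ids of each positive entity
                type_id = torch.index_select(self.entity2types, dim=0, index=positive_sample_regular)
                # gather type embeddings -> (batch_size, num_types, dim), then move the types axis first
                entity_neighbor_type_embedding = torch.index_select(self.type_embedding, dim=0, index=type_id.view(-1)).view(type_id.shape[0], type_id.shape[1], -1)
                entity_neighbor_type_embedding = torch.transpose(entity_neighbor_type_embedding, 0, 1)
                # center_net aggregates over the types axis -> (batch_size, dim);
                # self.type_mode must be 'ent' or 'ent_rel', otherwise positive_embedding is never assigned
                if self.type_mode == 'ent' or self.type_mode == 'ent_rel':
                    positive_embedding = self.center_net(entity_neighbor_type_embedding)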

                positive_embedding = positive_embedding.unsqueeze(1)
                positive_logit = self.cal_logit_vec(positive_embedding, all_center_embeddings)
            else:
                positive_logit = torch.Tensor([]).to(self.entity_embedding.device)

            if len(all_union_center_embeddings) > 0:
                positive_sample_union = positive_sample[all_union_idxs]

                # positive_embedding = torch.index_select(self.entity_embedding, dim=0, index=positive_sample_union).unsqueeze(1)
                type_id = torch.index_select(self.entity2types, dim=0, index=positive_sample_union)
                entity_neighbor_type_embedding = torch.index_select(self.type_embedding, dim=0, index=type_id.view(-1)).view(type_id.shape[0], type_id.shape[1], -1)
                entity_neighbor_type_embedding = torch.transpose(entity_neighbor_type_embedding, 0, 1)
                if self.type_mode == 'ent' or self.type_mode == 'ent_rel':
                    positive_embedding = self.center_net(entity_neighbor_type_embedding)
                positive_embedding = positive_embedding.unsqueeze(1).unsqueeze(1)

                positive_union_logit = self.cal_logit_vec(positive_embedding, all_union_center_embeddings)
                positive_union_logit = torch.max(positive_union_logit, dim=1)[0]
            else:
                positive_union_logit = torch.Tensor([]).to(self.entity_embedding.device)
            positive_logit = torch.cat([positive_logit, positive_union_logit], dim=0)
        else:
            positive_logit = None

        if negative_sample is not None:
            if len(all_center_embeddings) > 0:
                negative_sample_regular = negative_sample[all_idxs]
                batch_size, negative_size = negative_sample_regular.shape

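                # Transductive version (per-entity lookup), disabled for the inductive test:
                # negative_embedding = torch.index_select(self.entity_embedding, dim=0, index=negative_sample_regular.view(-1)).view(batch_size, negative_size, -1)
                # type ids of every candidate: (batch_size * negative_size, num_types)
                type_id = torch.index_select(self.entity2types, dim=0, index=negative_sample_regular.view(-1))
                # (batch_size, negative_size, num_types, dim) -> (num_types, batch_size, negative_size, dim);
                # this tensor grows with batch_size * negative_size, so a smaller batch reduces memory
                entity_neighbor_type_embedding = torch.index_select(self.type_embedding, dim=0, index=type_id.view(-1)).view(batch_size, negative_size, type_id.shape[1], -1)
                entity_neighbor_type_embedding = torch.transpose(entity_neighbor_type_embedding, 0, 1)
                entity_neighbor_type_embedding = torch.transpose(entity_neighbor_type_embedding, 0, 2)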
                if self.type_mode == 'ent' or self.type_mode == 'ent_rel':
                    negative_embedding = self.center_net(entity_neighbor_type_embedding)
                negative_embedding = negative_embedding.view(batch_size, negative_size,-1)

                negative_logit = self.cal_logit_vec(negative_embedding, all_center_embeddings)
            else:
                negative_logit = torch.Tensor([]).to(self.entity_embedding.device)

            if len(all_union_center_embeddings) > 0:
                negative_sample_union = negative_sample[all_union_idxs]
                batch_size, negative_size = negative_sample_union.shape

                # negative_embedding = torch.index_select(self.entity_embedding, dim=0, index=negative_sample_union.view(-1)).view(batch_size, negative_size, -1)
                type_id = torch.index_select(self.entity2types, dim=0, index=negative_sample_union.view(-1))
                entity_neighbor_type_embedding = torch.index_select(self.type_embedding, dim=0, index=type_id.view(-1)).view(batch_size, negative_size, type_id.shape[1], -1)
                entity_neighbor_type_embedding = torch.transpose(entity_neighbor_type_embedding, 0, 1)
                entity_neighbor_type_embedding = torch.transpose(entity_neighbor_type_embedding, 0, 2)
                if self.type_mode == 'ent' or self.type_mode == 'ent_rel':
                    negative_embedding = self.center_net(entity_neighbor_type_embedding)
                negative_embedding = negative_embedding.view(batch_size, negative_size, -1).unsqueeze(1)

                negative_union_logit = self.cal_logit_vec(negative_embedding, all_union_center_embeddings)
                negative_union_logit = torch.max(negative_union_logit, dim=1)[0]
            else:
                negative_union_logit = torch.Tensor([]).to(self.entity_embedding.device)
            negative_logit = torch.cat([negative_logit, negative_union_logit], dim=0)
        else:
            negative_logit = None

        return positive_logit, negative_logit, subsampling_weight, all_idxs+all_union_idxs
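
Stripped of batching and union bookkeeping, the change above comes down to one idea: an answer candidate is not looked up as a per-entity vector; instead the embeddings of its types are gathered and pooled. The dependency-free sketch below illustrates this with toy data. ENTITY2TYPES, TYPE_EMBEDDING and ATTN_WEIGHT are hypothetical stand-ins for self.entity2types, self.type_embedding and the parameters of self.center_net, and the pooling (a softmax over one scalar score per type) is a simplified assumption, not necessarily what center_net computes.

```python
import math

# Hypothetical stand-in for self.entity2types: each entity maps to a fixed-length
# list of type ids (padded by repetition so all lists have the same length).
ENTITY2TYPES = {
    0: [0, 1],
    1: [1, 2],
    2: [2, 2],
}

# Hypothetical stand-in for self.type_embedding (3 types, dimension 3).
TYPE_EMBEDDING = [
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
]

# Hypothetical attention parameters: one weight per embedding dimension.
ATTN_WEIGHT = [0.5, -0.2, 0.1]


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def type_based_entity_embedding(entity_id, entity2types=ENTITY2TYPES):
    """Embed an entity from its types only, via softmax-weighted pooling.

    No per-entity parameters are used, so an entity that never appeared in
    training still gets an embedding as long as its types are known.
    """
    type_vecs = [TYPE_EMBEDDING[t] for t in entity2types[entity_id]]
    # One attention score per type, normalised over the entity's types.
    scores = [sum(w * x for w, x in zip(ATTN_WEIGHT, vec)) for vec in type_vecs]
    weights = softmax(scores)
    dim = len(type_vecs[0])
    # Weighted sum over the types axis (the axis center_net reduces in forward_vec).
    return [sum(w * vec[d] for w, vec in zip(weights, type_vecs)) for d in range(dim)]


# An "unseen" entity: only its types are known, yet it can still be embedded.
unseen = {**ENTITY2TYPES, 99: [0, 2]}
print(type_based_entity_embedding(99, unseen))
```

The posted forward_vec does the same for whole batches with torch.index_select, transposing the types axis to dim 0 before calling center_net.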

leiloong commented 2 years ago

Thanks a lot for your very quick reply. After adding self.type_mode = 'ent' to the KGReasoning class, the earlier errors are gone, but now I get a different error:

    RuntimeError: CUDA out of memory. Tried to allocate 6.25 GiB (GPU 0; 31.75 GiB total capacity; 27.23 GiB already allocated; 3.07 GiB free; 27.26 GiB reserved in total by PyTorch)

How much GPU memory does inductive testing need? I will look into the cause soon. Also, do you plan to update models.py in this GitHub repository soon?

zhiweihu1103 commented 2 years ago

For the inductive test, the batch size needs to be smaller, because building each entity's representation from its types in forward_vec takes a lot of GPU memory. I use batch_size=128 on a 32 GB GPU; you can try that. I may need two to three weeks to update the code, as I am currently rushing to meet a deadline.
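
For the negative candidates, forward_vec materialises a tensor of shape (num_types, batch_size, negative_size, dim) before center_net reduces the types axis, so its size grows linearly with batch_size. The back-of-the-envelope sketch below assumes float32 values; num_types and dim are hypothetical, and negative_size=4268 (the #entity from the log) assumes that each test query is scored against every entity, as in BetaE-style evaluation loops. Real peak usage is higher, since center_net's intermediate activations and the logits also need memory.

```python
# Back-of-the-envelope size of the gathered type-embedding tensor in forward_vec.
GIB = 1024 ** 3


def type_tensor_bytes(batch_size, negative_size, num_types, dim, bytes_per_value=4):
    """Bytes for a (num_types, batch_size, negative_size, dim) tensor.

    bytes_per_value=4 assumes float32. Intermediate activations inside
    center_net are not counted, so the real peak is higher.
    """
    return num_types * batch_size * negative_size * dim * bytes_per_value


# Hypothetical settings: num_types and dim are illustrative guesses, and
# negative_size=4268 assumes every query is scored against all entities.
for batch_size in (512, 256, 128):
    size = type_tensor_bytes(batch_size, negative_size=4268, num_types=8, dim=400)
    print(f"batch_size={batch_size}: {size / GIB:.2f} GiB")
```

Because the size is linear in batch_size, halving the batch halves this tensor, which is why lowering batch_size is the first thing to try after a CUDA out-of-memory error here.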

leiloong commented 2 years ago

OK, thanks a lot. I think I understand the idea behind the code: not only the entities in the queries but also the answer entities need to be augmented with type information before scoring. I will check it later.
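
Why the answers must be augmented too shows up in how candidates are scored. Assuming cal_logit_vec follows the usual GQE convention of the BetaE code base (a margin gamma minus the L1 distance between the entity and query vectors), here is a minimal sketch of ranking type-based answer embeddings. It also covers the DNF union case, keeping the best score over the two branches like the torch.max(..., dim=1) step in forward_vec. The gamma value and all vectors are hypothetical.

```python
def gqe_logit(entity_vec, query_vec, gamma=12.0):
    """GQE-style score: gamma minus the L1 distance (higher means a better match)."""
    return gamma - sum(abs(e - q) for e, q in zip(entity_vec, query_vec))


def union_logit(entity_vec, branch_query_vecs, gamma=12.0):
    """DNF union: score against each branch query and keep the best score."""
    return max(gqe_logit(entity_vec, q, gamma) for q in branch_query_vecs)


def rank_candidates(candidate_vecs, query_vec, gamma=12.0):
    """Return candidate ids sorted from best to worst score.

    candidate_vecs should hold type-based embeddings, so unseen test
    entities are scored in the same space as the query embedding.
    """
    scores = {eid: gqe_logit(vec, query_vec, gamma) for eid, vec in candidate_vecs.items()}
    return sorted(scores, key=scores.get, reverse=True)


# Hypothetical type-based embeddings of three candidate answers.
candidates = {
    10: [0.6, 0.4, 0.0],
    11: [0.0, 0.1, 0.9],
    12: [0.0, 0.0, 1.0],
}
query = [0.0, 0.0, 1.0]
print(rank_candidates(candidates, query))  # [12, 11, 10]
print(union_logit(candidates[10], [[1.0, 0.0, 0.0], query]))
```

If the candidates were instead taken from an untrained per-entity table, as happens with unseen test entities when the original lookup is used, their distances to the query would be essentially arbitrary, which matches the near-zero test metrics in the log.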