leapxcheng / RawNP

Official code for RawNP (ECML-PKDD 2023)

RuntimeError #3

Status: open. zhqihang opened this issue 1 year ago.

zhqihang commented 1 year ago

Hello, what is causing this error?

RuntimeError: mat1 and mat2 shapes cannot be multiplied (495x20 and 100x100)
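For context: PyTorch's `nn.Linear(in_features, out_features)` multiplies its input by the transposed weight, which has shape `in_features x out_features`. So this error says a batch of 495 vectors with 20 features reached a layer built for 100 input features. The 100x100 suggests `embed_dim` is 100 here, but that is an assumption. The minimal sketch below reproduces the mismatch and then derives `in_features` from the data instead of hard-coding it. `walk_feats` and `embed_dim` are hypothetical names, not taken from the RawNP code.

```python
import torch
import torch.nn as nn

embed_dim = 100  # assumption: inferred from the 100x100 in the error message

# Hypothetical batch of 495 feature vectors of width 20 (mat1 in the error)
walk_feats = torch.randn(495, 20)

# A layer hard-coded to expect 100 input features cannot consume 20-wide inputs
bad_enc = nn.Linear(100, embed_dim)
try:
    bad_enc(walk_feats)
    mismatch_raised = False
except RuntimeError as err:
    mismatch_raised = True
    print("shape mismatch:", err)

# Derive in_features from the data so the layer always matches its input width
enc = nn.Linear(walk_feats.size(-1), embed_dim)
out = enc(walk_feats)
print(out.shape)  # torch.Size([495, 100])
```

Recent PyTorch versions also offer `nn.LazyLinear(embed_dim)`, which infers `in_features` on the first forward pass. Changing the hard-coded width only hides the symptom if the 20-wide features themselves are unexpected, so it is worth checking where that dimension comes from.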

leapxcheng commented 1 year ago

Could you provide more details about the problem?

zhqihang commented 1 year ago

Hello! I changed line 279 of model.py from self.enc_rw = nn.Linear(100, embed_dim) to self.enc_rw = nn.Linear(20, embed_dim), but now I get a new error:

Traceback (most recent call last):
  File "main.py", line 108, in <module>
    trainer.train()
  File "main.py", line 56, in train
    results = self.model.eval_one_time(eval_type='valid')
  File "/mnt/nas/qihang/RawNP/neuralprocess.py", line 483, in eval_one_time
    ranks, ranks_1, ranks_2 = utils.calc_induc_mrr(test_task_pool, self.args, self.random_walk2, self.arw_encoder, self.decoder, z, unseen_entity, unseen_entity_embedding, self.embed.entity_embedding.weight, self.embed.relation_embedding, test_triplets, self.all_triplets, self.use_cuda, score_function=self.args.score_function)
  File "/mnt/nas/qihang/RawNP/utils.py", line 166, in calc_induc_mrr
    arw = rw_func(perturb_entity_index)
  File "/mnt/nas/qihang/RawNP/neuralprocess.py", line 253, in random_walk2
    walks = self.anonymous_walk[entities]
RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu)
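This error means the lookup table `self.anonymous_walk` is on the CPU while the index tensor `entities` is on the GPU. PyTorch requires indices to be on the CPU or on the same device as the tensor being indexed. The minimal sketch below shows two common remedies, assuming `anonymous_walk` is currently a plain tensor attribute (not confirmed from the RawNP source). `WalkTable` and `lookup` are hypothetical names.

1. Register the table as a buffer so that `model.to(device)` moves it together with the parameters.
2. At lookup time, move the indices to the table's device.

```python
import torch
import torch.nn as nn


class WalkTable(nn.Module):
    """Hypothetical stand-in for the module that holds the anonymous-walk table."""

    def __init__(self, anonymous_walk: torch.Tensor):
        super().__init__()
        # A registered buffer follows .to(device)/.cuda() like the parameters do;
        # a plain tensor attribute would stay behind on the CPU.
        self.register_buffer("anonymous_walk", anonymous_walk)

    def lookup(self, entities: torch.Tensor) -> torch.Tensor:
        # Index on the table's device, then hand the result back on the caller's device.
        walks = self.anonymous_walk[entities.to(self.anonymous_walk.device)]
        return walks.to(entities.device)


device = "cuda" if torch.cuda.is_available() else "cpu"
table = WalkTable(torch.arange(12, dtype=torch.float32).reshape(4, 3)).to(device)
entities = torch.tensor([2, 0], device=device)
walks = table.lookup(entities)
print(walks)  # rows 2 and 0 of the table
```

With the buffer approach, the table is transferred once when the model is moved, rather than on every lookup. Passing `persistent=False` to `register_buffer` keeps the table out of saved checkpoints, if that matters for existing state dicts.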

zhqihang commented 1 year ago

After adding a .to('cuda') call to the line walks = self.anonymous_walk[entities], the code seems to run, but something still looks wrong:

use cuda
load data from ./Dataset/raw_data/FB15k-237
num_entity: 14541
num_relation: 237
num_train_triples: 272115
num_valid_triples: 17535
num_test_triples: 20466
0%| | 0/25001 [00:00<?, ?it/s]
-------------------------------------Valid---------------------------------------
100%|█████████████████████████████████████████████| 1000/1000 [1:08:54<00:00, 4.13s/it]
4134.3834109306335████████████████████████████████| 1000/1000 [1:08:54<00:00, 2.53s/it]
0%| | 0/25001 [1:09:02<?, ?it/s]
Traceback (most recent call last):
  File "main.py", line 108, in <module>
    trainer.train()
  File "main.py", line 59, in train
    mrr = results['total_mrr']
KeyError: 'total_mrr'
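The final `KeyError` means the dictionary returned by `eval_one_time(eval_type='valid')` has no `'total_mrr'` entry. The key read in `main.py` therefore does not match what the evaluation function returns. The thread does not show which keys are actually produced. The minimal sketch below is a lookup helper that fails with the available keys listed, which makes this kind of naming mismatch quick to diagnose. `get_metric` and the sample keys and values are hypothetical placeholders, not RawNP output.

```python
from typing import Any, Mapping


def get_metric(results: Mapping[str, Any], key: str) -> Any:
    """Return results[key], or raise a KeyError listing the keys that are present."""
    if key not in results:
        available = ", ".join(sorted(map(str, results)))
        raise KeyError(f"{key!r} not in evaluation results; available keys: {available}")
    return results[key]


# Hypothetical evaluation output whose key names differ from what the caller expects
results = {"mrr": 0.31, "hits@10": 0.52}

message = ""
try:
    get_metric(results, "total_mrr")
except KeyError as err:
    message = str(err)
print(message)

mrr = get_metric(results, "mrr")
print(mrr)  # 0.31
```

Printing `results.keys()` once just before line 59 gives the same information without changing any code paths.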