torchkge-team / torchkge

TorchKGE: Knowledge Graph embedding in Python and PyTorch.

KeyError when running data_redundancy.duplicates on WN18 dataset #265

Open galadrielbriere opened 2 months ago

galadrielbriere commented 2 months ago

Description

I encountered a KeyError when running the data_redundancy.duplicates function on the WN18 dataset. The function works correctly with other datasets (e.g., FB13), but fails with the same error on WN18 and on my own KG.

Here is the code that triggers the error:

from torchkge.utils.datasets import load_wn18
from torchkge.utils import data_redundancy

kg_train, kg_val, kg_test = load_wn18()
dups, reverse_dups = data_redundancy.duplicates(kg_train, kg_val, kg_test)

And the error:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[5], line 5
      2 from  torchkge.utils import data_redundancy
      4 kg_train, kg_val, kg_test = load_wn18()
----> 5 dups, reverse_dups = data_redundancy.duplicates(kg_train, kg_val, kg_test)

File ~/anaconda3/envs/benchmark/lib/python3.9/site-packages/torchkge/utils/data_redundancy.py:150, in duplicates(kg_tr, kg_val, kg_te, theta1, theta2, verbose, counts, reverses)
    147 iter_ = list(combinations(range(1345), 2))
    149 for r1, r2 in tqdm(iter_):
--> 150     a = len(T[r1].intersection(T[r2])) / lengths[r1]
    151     b = len(T[r1].intersection(T[r2])) / lengths[r2]
    153     if a > theta1 and b > theta2:

KeyError: 18
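The failure can be reproduced without torchkge. The sketch below assumes only what the traceback implies: T is a dict keyed by relation index that holds entries for the relations present in the dataset (0 to 17 for WN18's 18 relations). T, N_REL and first_missing_key are hypothetical stand-ins for illustration, not torchkge internals.

```python
from itertools import combinations

# Hypothetical stand-in for the per-relation sets built inside duplicates():
# one entry per relation index, 0..17 for WN18's 18 relations.
N_REL = 18
T = {r: set() for r in range(N_REL)}


def first_missing_key(T, upper):
    """Walk relation pairs like the loop at line 150 and return the first
    relation index that T has no entry for, or None if every lookup succeeds."""
    for r1, r2 in combinations(range(upper), 2):
        try:
            T[r1].intersection(T[r2])
        except KeyError as err:
            return err.args[0]
    return None


print(first_missing_key(T, 1345))   # hard-coded bound: prints 18
print(first_missing_key(T, N_REL))  # bound taken from the data: prints None
```

Since combinations yields (0, 1), (0, 2), ... first, the lookup T[18] in the pair (0, 18) is the first one to fail, which matches the KeyError: 18 above.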

Potential solution

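The loop over relation pairs (line 147 of data_redundancy.py) uses a hard-coded bound of 1345, which matches the number of relations in FB15k, whereas WN18 has only 18 relations (indices 0 to 17). As soon as the loop reaches the pair (0, 18), T[18] does not exist and KeyError: 18 is raised. To my understanding, this can be fixed by taking the bound from the dataset, i.e. changing:

iter_ = list(combinations(range(1345), 2))

to:

iter_ = list(combinations(range(kg_tr.n_rel), 2))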
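A self-contained sketch of the corrected pairwise loop follows. Only the lines visible in the traceback are known; the function name find_duplicate_relations, the toy data, and what happens when both thresholds are exceeded (here the pair is simply collected) are assumptions for illustration, not torchkge's actual code. The point is that the bound comes from the number of relations (kg_tr.n_rel in the proposed fix) rather than the constant 1345.

```python
from itertools import combinations


def find_duplicate_relations(T, lengths, n_rel, theta1, theta2):
    """Hypothetical sketch of the pairwise loop in duplicates().

    T[r] is the set of (head, tail) pairs for relation r and lengths[r] its
    fact count. The loop bound is n_rel rather than a hard-coded 1345, so
    every lookup stays within the relations that actually exist.
    """
    pairs = []
    for r1, r2 in combinations(range(n_rel), 2):
        shared = len(T[r1].intersection(T[r2]))
        a = shared / lengths[r1]  # share of r1's facts also held by r2
        b = shared / lengths[r2]  # share of r2's facts also held by r1
        if a > theta1 and b > theta2:
            pairs.append((r1, r2))  # assumption: matching pairs are collected
    return pairs


# Toy graph with three relations: 0 and 1 link the same entity pairs.
T = {
    0: {(0, 1), (2, 3)},
    1: {(0, 1), (2, 3)},
    2: {(4, 5)},
}
lengths = {r: len(facts) for r, facts in T.items()}
print(find_duplicate_relations(T, lengths, n_rel=len(T), theta1=0.8, theta2=0.8))
# prints [(0, 1)]
```

Because n_rel is read from the graph, the same loop works unchanged for WN18, FB15k, or a custom KG.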