awslabs / dgl-ke

High performance, easy-to-use, and scalable package for learning large-scale knowledge graph embeddings.
https://dglke.dgl.ai/doc/
Apache License 2.0

Force dtype to int64 to ensure that we don't index with non-long tensor #258

Open TobiasMadsenQiagen opened 2 years ago


In the triplet data loaders (utils.py:load_triplet_data and utils.py:load_raw_triplet_data), the imported data should be explicitly cast to int64 so that the resulting torch tensors are always long. Otherwise, calling predict can fail because torch rejects an index tensor that is not of type long:

line 186, in __call__
return self.emb[idx].to(self.device)
IndexError: tensors used as indices must be long, byte or bool tensors
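A minimal sketch of the proposed fix, assuming tab-separated integer (head, relation, tail) ids. `load_triplets` is a hypothetical helper, not dgl-ke's actual load_triplet_data or load_raw_triplet_data; it only shows the idea of passing `dtype=np.int64` so the resulting arrays (and any tensor built from them with torch.from_numpy, which keeps the NumPy dtype) are long on every platform.

```python
import io
import numpy as np


def load_triplets(fileobj, sep="\t"):
    """Hypothetical loader: parse integer (head, rel, tail) triples.

    Not dgl-ke's real implementation; it illustrates forcing int64 so the
    dtype never depends on what np.asarray infers on a given platform.
    """
    heads, rels, tails = [], [], []
    for line in fileobj:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        h, r, t = line.split(sep)
        heads.append(int(h))
        rels.append(int(r))
        tails.append(int(t))
    # Force int64: without dtype=, np.asarray may infer int32 on Windows
    # (NumPy 1.x), and torch.from_numpy would then give an int32 tensor,
    # which is the situation that produced the IndexError above.
    return (np.asarray(heads, dtype=np.int64),
            np.asarray(rels, dtype=np.int64),
            np.asarray(tails, dtype=np.int64))


if __name__ == "__main__":
    data = io.StringIO("0\t0\t1\n1\t1\t2\n")
    h, r, t = load_triplets(data)
    print(h.dtype, h, r, t)  # int64 [0 1] [0 1] [1 2]
```

Casting at load time keeps the fix in one place, rather than converting to long at every indexing site.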

np.asarray infers the dtype from its input. On the Windows system we tested on, it infers int32 whenever all input integers fit in int32 (i.e. are at most 2^31-1); with NumPy 1.x, the default integer type on Windows is the 32-bit C long, whereas on macOS and Linux it is 64-bit. We did not observe the problem on macOS or Ubuntu. We tested with dglke 0.1.2.
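The dtype inference described above can be checked with a short NumPy snippet. It is a sketch: the first printed dtype depends on the platform and NumPy version, so no expected value is given for it. `ensure_int64` is a hypothetical helper name used for illustration.

```python
import numpy as np

# Platform-dependent: with NumPy 1.x on Windows this is int32 (C long);
# on Linux and macOS it is int64.
inferred = np.asarray([0, 1, 2])
print("inferred:", inferred.dtype)

# Values that do not fit in int32 are promoted to int64 on every platform,
# which is why the problem only appears when all ids are small.
print("large:", np.asarray([2**31]).dtype)

# Requesting the dtype explicitly removes the platform dependence.
explicit = np.asarray([0, 1, 2], dtype=np.int64)
print("explicit:", explicit.dtype)


def ensure_int64(arr):
    """Hypothetical helper: cast to int64, without copying if already int64."""
    return np.asarray(arr).astype(np.int64, copy=False)
```

Using `copy=False` means arrays that are already int64 (the usual case on Linux and macOS) pass through unchanged, so the cast costs nothing there.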