JinheonBaek / RGCN

PyTorch implementation of an RGCN link prediction model

A one-hot error occurs when running main.py. Why? #13

Open viviliuwqhduhnwqihwqwudceygysjiwuwnn opened 2 months ago

viviliuwqhduhnwqihwqwudceygysjiwuwnn commented 2 months ago

Namespace(dropout=0.2, evaluate_every=500, gpu=-1, grad_norm=1.0, graph_batch_size=30000, graph_split_size=0.5, lr=0.01, n_bases=4, n_epochs=10000, negative_sample=1, regularization=0.01)
load data from ./data/FB15k-237
num_entity: 14541
num_relation: 237
num_train_triples: 272115
num_valid_triples: 17535
num_test_triples: 20466
Traceback (most recent call last):
  File "D:/RGCN-master/main.py", line 121, in <module>
    main(args)
  File "D:/RGCN-master/main.py", line 48, in main
    test_graph = build_test_graph(len(entity2id), len(relation2id), train_triplets)
  File "D:\RGCN-master\utils.py", line 159, in build_test_graph
    data.edge_norm = edge_normalization(edge_type, edge_index, num_nodes, num_rels)
  File "D:\RGCN-master\utils.py", line 91, in edge_normalization
    one_hot = F.one_hot(edge_type, num_classes = 2 * num_relation).to(torch.float)
RuntimeError: one_hot is only applicable to index tensor.

Process finished with exit code 1
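A minimal reproduction of the error, offered as a sketch. It assumes the relation column reaches `torch.from_numpy` as an `int32` array. `torch.from_numpy` keeps the NumPy dtype, and on Windows (the `D:\` paths in the traceback point there) NumPy versions before 2.0 create `int32` integer arrays by default, while `F.one_hot` only accepts `int64` (`torch.long`) index tensors. The helper `try_one_hot` and the toy arrays are hypothetical and only mimic the failing line in `edge_normalization`. The exact error wording differs between PyTorch versions.

```python
import numpy as np
import torch
import torch.nn.functional as F


def try_one_hot(edge_type_np: np.ndarray, num_classes: int):
    """Hypothetical helper mimicking the failing call in edge_normalization.

    Converts with torch.from_numpy (which keeps the NumPy dtype), then one-hot
    encodes. Returns the float one-hot tensor, or the RuntimeError message if
    PyTorch rejects the index dtype.
    """
    edge_type = torch.from_numpy(edge_type_np)
    try:
        return F.one_hot(edge_type, num_classes=num_classes).to(torch.float)
    except RuntimeError as err:
        return str(err)


# int32 edge types, as NumPy produces by default on Windows before NumPy 2.0
rel_int32 = np.array([0, 3, 1], dtype=np.int32)
print(torch.from_numpy(rel_int32).dtype)        # torch.int32
print(try_one_hot(rel_int32, num_classes=4))    # PyTorch's error message (wording varies by version)

# The same values as int64 are valid indices for one_hot
rel_int64 = rel_int32.astype(np.int64)
print(try_one_hot(rel_int64, num_classes=4).shape)  # torch.Size([3, 4])
```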

xujunche commented 1 month ago

You can change the first few lines of the function build_test_graph from

    src = torch.from_numpy(src)
    rel = torch.from_numpy(rel)
    dst = torch.from_numpy(dst)

to

    src = torch.from_numpy(src).to(torch.long)
    rel = torch.from_numpy(rel).to(torch.long)
    dst = torch.from_numpy(dst).to(torch.long)
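The suggested change works because `.to(torch.long)` converts whatever integer dtype NumPy produced into the `int64` index type that `F.one_hot` requires. Below is a minimal sketch of the same idea. It assumes `train_triplets` is an `(N, 3)` integer array of `[head, relation, tail]` rows. `split_triplets` is a hypothetical helper rather than the repository's code, and the toy triplets are made up for illustration.

```python
import numpy as np
import torch
import torch.nn.functional as F


def split_triplets(triplets: np.ndarray):
    """Hypothetical helper: split (N, 3) [head, relation, tail] rows into int64 tensors.

    Mirrors the suggested fix: casting with .to(torch.long) makes the tensors
    valid index tensors for F.one_hot, whatever integer dtype NumPy used.
    """
    src, rel, dst = triplets.transpose()
    src = torch.from_numpy(src).to(torch.long)
    rel = torch.from_numpy(rel).to(torch.long)
    dst = torch.from_numpy(dst).to(torch.long)
    return src, rel, dst


# Toy int32 triplets, i.e. the dtype that triggered the error
train_triplets = np.array([[0, 0, 1], [1, 2, 2], [2, 1, 0]], dtype=np.int32)
num_rels = 3

src, rel, dst = split_triplets(train_triplets)
# Same call shape as in the traceback: num_classes = 2 * number of relations
one_hot = F.one_hot(rel, num_classes=2 * num_rels).to(torch.float)
print(rel.dtype, one_hot.shape)  # torch.int64 torch.Size([3, 6])
```

Casting the NumPy array first, with `arr.astype(np.int64)` before `torch.from_numpy`, should be an equivalent alternative.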