JiaangL closed this issue 2 years ago.
I can now run the code by reimplementing `torchdrug.layers.functional.generalized_rspmm` myself in plain Python (PyTorch) instead of CUDA. Training takes about 1.5 hours per epoch on an RTX 3090 with fb15k237-v1, and the model converges after one or two epochs. My code is below; any suggestions are welcome.
```python
import torch


def my_generalized_rspmm(sparse, relation, input, sum='add', mul='mul'):
    # Pure-PyTorch (no CUDA kernel) stand-in for generalized_rspmm.
    # Assumes `sparse` is a 3-D sparse COO tensor whose indices are
    # (output node, input node, relation id). Edge weights
    # (sparse._values()) are ignored, i.e. assumed to be all ones.
    if mul != 'mul':
        raise NotImplementedError(mul)
    if sum not in ('add', 'max', 'min'):
        raise NotImplementedError(sum)
    sparse_indices = sparse._indices()

    # Group the (input node, relation) pairs by output node.
    sparse_dict = dict()
    for i in range(sparse_indices.shape[-1]):
        key = sparse_indices[0, i].item()
        value = sparse_indices[1:, i].tolist()
        sparse_dict.setdefault(key, []).append(value)

    # Allocate on the input's device instead of relying on a global `device`.
    # Rows with no incoming entries stay zero.
    output = torch.zeros(sparse.shape[0], relation.shape[-1],
                         dtype=input.dtype, device=input.device)
    for key, pairs in sparse_dict.items():
        # One message per (input node, relation) pair feeding this row.
        messages = torch.stack([torch.mul(input[n], relation[r]) for n, r in pairs])
        if sum == 'add':
            output[key] = messages.sum(dim=0)
        elif sum == 'max':
            # Reduce over the messages only, so the zero init cannot clamp the result.
            output[key] = messages.max(dim=0).values
        else:  # sum == 'min'
            output[key] = messages.min(dim=0).values
    return output
```
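The per-entry Python loop (with `.item()` and `.tolist()` calls) is a likely reason an epoch takes this long, so one possible suggestion is to vectorize it. Below is a minimal sketch, assuming the same index layout as the loop version (output node, input node, relation id), all edge weights equal to one, `mul='mul'` only, and PyTorch >= 1.12 for `Tensor.scatter_reduce` with `include_self`. `vectorized_rspmm` is a hypothetical name; this is not torchdrug's own implementation.

```python
import torch


def vectorized_rspmm(sparse, relation, input, sum="add", mul="mul"):
    # Hypothetical vectorized variant of my_generalized_rspmm.
    # Assumes indices are (output node, input node, relation id),
    # edge weights are all ones, and PyTorch >= 1.12 (scatter_reduce).
    if mul != "mul":
        raise NotImplementedError(mul)
    out_node, in_node, rel = sparse._indices()
    # Gather every message at once: one row per nonzero entry.
    messages = input[in_node] * relation[rel]
    output = torch.zeros(sparse.shape[0], messages.shape[-1],
                         dtype=messages.dtype, device=messages.device)
    if sum == "add":
        # Sum the messages that belong to each output row.
        return output.index_add(0, out_node, messages)
    if sum in ("max", "min"):
        index = out_node.unsqueeze(-1).expand_as(messages)
        reduce = "amax" if sum == "max" else "amin"
        # include_self=False keeps the zero init out of the reduction;
        # rows that receive no message keep their initial zero.
        return output.scatter_reduce(0, index, messages, reduce=reduce,
                                     include_self=False)
    raise NotImplementedError(sum)
```

All messages are built with a single gather and then reduced per output row, so the work stays on the GPU instead of going through Python-level loops.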
Hi! I followed the instructions to install the packages, but I'm now getting an ImportError when reproducing the results. The error is as follows. I also tried
`rm -r ~/.cache/torch_extensions/*`
as suggested in the README, but that caused more errors. I'm using torch 1.11 + CUDA 11.3 and torchdrug 0.1.2.
Do you know how to deal with this? Any help is appreciated! By the way, in other issues I noticed that an environment.yml would be released. Where can I find it? Thanks!
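When reporting an ImportError like this, exact versions make it easier for others to reproduce. Below is a minimal sketch using only the standard library (plus `torch`, if it can be imported); `environment_report` is a hypothetical helper, and the default package names are assumptions based on the versions mentioned above.

```python
import importlib.metadata
import platform


def environment_report(packages=("torch", "torchdrug")):
    # Hypothetical helper: collect the versions usually asked for in such reports.
    report = {"python": platform.python_version()}
    for name in packages:
        try:
            report[name] = importlib.metadata.version(name)
        except importlib.metadata.PackageNotFoundError:
            report[name] = "not installed"
    try:
        import torch
        report["torch.version.cuda"] = torch.version.cuda  # None for CPU-only builds
        report["cuda available"] = torch.cuda.is_available()
    except ImportError:
        report["torch.version.cuda"] = "torch not importable"
    return report


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```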