rgemulla opened this issue 4 years ago.
This would be really helpful! I've written some quick, back-of-the-envelope code to emulate this behavior:
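```python
import torch
import kge.model

dataset_name = 'fb15k-237'
model_name = 'rescal'
dataset_and_model = dataset_name + '-' + model_name
model = kge.model.KgeModel.load_from_checkpoint(
    './local/best/' + model_name + '/' + dataset_and_model + '.pt')
# validation triples as an n x 3 tensor; take the subject, predicate and object columns
s = model.dataset.split('valid').select(1, 0)
p = model.dataset.split('valid').select(1, 1)
o = model.dataset.split('valid').select(1, 2)
# scores of every candidate object for each (s, p) pair: shape n x num_entities
scores = model.score_sp(s, p)
# raw (unfiltered), 0-based rank: candidates scored strictly higher than the true object
raw_ranks = [torch.sum(score_array > score_array[o[i]], dtype=torch.long) for (i, score_array) in enumerate(scores)]
# candidates tied with the true object (this count includes the true object itself)
num_ties = [torch.sum(score_array == score_array[o[i]], dtype=torch.long) for (i, score_array) in enumerate(scores)]
# break ties by adding half the tie count, rounded down
final_ranks = [raw_rank + num_ties[j] // 2 for (j, raw_rank) in enumerate(raw_ranks)]
```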
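To make the tie-handling rule in the snippet concrete, here is a minimal, dependency-free sketch on a hand-made score matrix. It uses plain Python lists instead of PyTorch tensors and a LibKGE model, and the function name `rounded_mean_ranks` is hypothetical; it only mirrors the arithmetic of the snippet (strictly greater count plus half the tie count, rounded down).

```python
from typing import List, Sequence


def rounded_mean_ranks(scores: Sequence[Sequence[float]],
                       true_idx: Sequence[int]) -> List[int]:
    """Hypothetical helper: 0-based rank of each true entity, splitting ties.

    scores[i][e] is the score of candidate entity e for query i;
    true_idx[i] is the index of the correct entity for query i.
    """
    ranks = []
    for row, t in zip(scores, true_idx):
        true_score = row[t]
        # candidates scored strictly higher than the true entity
        num_greater = sum(1 for x in row if x > true_score)
        # candidates with exactly the same score (includes the true entity itself)
        num_ties = sum(1 for x in row if x == true_score)
        # place the true entity in the middle of its tie group, rounded down
        ranks.append(num_greater + num_ties // 2)
    return ranks


scores = [
    [0.9, 0.5, 0.1, 0.3],  # query 0: true entity 2 is beaten by three others
    [0.2, 0.7, 0.7, 0.7],  # query 1: true entity 1 shares the top score with two others
    [0.4, 0.4, 0.8, 0.1],  # query 2: true entity 2 has the unique top score
]
true_idx = [2, 1, 2]
print(rounded_mean_ranks(scores, true_idx))  # prints [3, 1, 0]
```

Ties are detected with exact float equality, as in the original snippet. The resulting ranks are 0-based; add 1 to get the 1-based ranks usually used for metrics such as MRR.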
Update: the first part of this is implemented as part of #94. The CLI still needs to be updated.
Right now, this is only possible with the folder structure used during training.