ethanfetaya / NRI

Neural relational inference for interacting systems - pytorch
MIT License

About edge_accuracy() in utils.py #35

Closed: chocolates closed this issue 2 years ago

chocolates commented 2 years ago

First, thanks a lot for sharing this great repo. I have two questions about how the relation prediction accuracy is computed:

  1. Suppose the model is already trained and we only want to evaluate it. The reported accuracy changes with the batch-size parameter, even though it should not, because the model itself does not change. The effect is most visible when the number of test examples is small. The likely reason is that not every batch contains batch-size examples: if num_test_example % batch-size != 0, the last batch is smaller, yet its accuracy is averaged with the same weight as the full batches. I think it would be better for edge_accuracy() in utils.py to return both the batch's mean accuracy and the number of examples in that batch, so that the main script can compute the overall average as a size-weighted sum divided by the total number of examples.
  2. If I understand correctly, we (or you) do not care about the "absolute" class label, since the latent edge types can be learned in any order. This is closer to clustering than to classification. So for the two-relation case, shouldn't the accuracy be max(acc, 1.0-acc)? Also, do you have any ideas for computing the accuracy when there are more than two relation types? The current edge_accuracy() function seems suitable only for the two-relation case.
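
A minimal sketch of the size-weighted averaging proposed in point 1. It assumes NumPy arrays in place of PyTorch tensors, predictions shaped [batch, num_edges, num_edge_types] and integer targets shaped [batch, num_edges]; `edge_accuracy_with_count` and `evaluate` are hypothetical names for illustration, not functions from utils.py.

```python
import numpy as np


def edge_accuracy_with_count(preds, target):
    """Hypothetical variant of edge_accuracy(): returns (mean accuracy, #edges).

    preds:  array of shape [batch, num_edges, num_edge_types] (logits or probs)
    target: integer array of shape [batch, num_edges]
    """
    pred_labels = preds.argmax(axis=-1)      # most likely edge type per edge
    n = target.size                          # edges in this (possibly short) batch
    acc = float((pred_labels == target).sum()) / n
    return acc, n


def evaluate(batches):
    """Return (size-weighted accuracy, naive mean of per-batch accuracies)."""
    total_correct = 0.0
    total_edges = 0
    naive_accs = []                          # what a plain per-batch mean would use
    for preds, target in batches:
        acc, n = edge_accuracy_with_count(preds, target)
        total_correct += acc * n             # undo the per-batch division
        total_edges += n
        naive_accs.append(acc)
    return total_correct / total_edges, float(np.mean(naive_accs))
```

With one full batch that is entirely correct (4 edges) and a short final batch that is entirely wrong (2 edges), the weighted result is 4/6 ≈ 0.667, matching accuracy over the whole test set at once, while averaging the per-batch means gives 0.5.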
tkipf commented 2 years ago

Thanks a lot for your comments/questions!

Regarding 1): You're right, the way we accumulate metrics doesn't correctly account for incomplete batches. This should only make a minor difference in the evaluation scores, but please feel free to submit a pull request.

Re: 2) That's correct. For multiple relation cases, you could either iterate over all possible permutations (but this quickly becomes impractical) or use a permutation-invariant clustering similarity score for evaluation such as ARI (https://scikit-learn.org/stable/modules/generated/sklearn.metrics.adjusted_rand_score.html) or AMI (https://scikit-learn.org/stable/modules/generated/sklearn.metrics.adjusted_mutual_info_score.html).
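
A minimal sketch of the brute-force permutation option, assuming predicted and ground-truth edge types are available as integer label arrays. `permutation_invariant_accuracy` is a hypothetical helper, not part of utils.py. It scores every relabeling of the predicted edge types and keeps the best one, so for two edge types it reduces to max(acc, 1.0-acc).

```python
import itertools

import numpy as np


def permutation_invariant_accuracy(pred_labels, target, num_edge_types):
    """Best accuracy over all relabelings of the predicted edge types.

    Brute force over num_edge_types! permutations, so this is only
    practical for a small number of edge types.
    """
    pred_labels = np.asarray(pred_labels)
    target = np.asarray(target)
    best = 0.0
    for perm in itertools.permutations(range(num_edge_types)):
        # perm[k] is the ground-truth label assigned to predicted label k
        relabeled = np.asarray(perm)[pred_labels]
        best = max(best, float((relabeled == target).mean()))
    return best
```

For K edge types the loop runs K! times, which is why this approach quickly becomes impractical. For the clustering scores, scikit-learn's `adjusted_rand_score(labels_true, labels_pred)` and `adjusted_mutual_info_score(labels_true, labels_pred)` expect 1-D label arrays, so the per-edge labels would need to be flattened first; both scores are unaffected by how the predicted labels are numbered.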

chocolates commented 2 years ago

Thanks very much for the confirmation/comments!