viz27 closed this issue 4 years ago.
Thank you for pointing it out!
I've fixed the problem in utils.accuracy(), and I also changed line 149 from
acc_test = accuracy(output[[target_node]], labels[target_node])
to
acc_test = (output.argmax(1)[target_node] == labels[target_node])
I think this is a better way to check whether the target node is misclassified.
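A minimal sketch of the revised check on toy data, assuming PyTorch; the output, labels and target_node values are made up for illustration. Because only the target node's prediction is compared, the result is a single boolean, so acc_test can only be 0 or 1.

```python
import torch

# Toy model output for 3 nodes and 2 classes (only the argmax matters here).
output = torch.tensor([[0.9, 0.1],
                       [0.2, 0.8],
                       [0.6, 0.4]])
labels = torch.tensor([0, 0, 0])
target_node = 1  # hypothetical target node chosen for illustration

# Predicted class for every node, then compare only the target node.
correct = output.argmax(1)[target_node] == labels[target_node]
acc_test = int(correct)  # 1 if correctly classified, 0 if misclassified

print(acc_test)  # 0: node 1 is predicted as class 1 but labelled 0
```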
Thanks again.
Thank you! This library has been a great help for my project.
Thanks, I'm glad it helps :)
In line 149, accuracy is calculated as
acc_test = accuracy(output[[target_node]], labels[target_node])
but labels[target_node] is an integer here. The accuracy() function in utils.py expects a list for its labels argument. If an integer, say n, is passed instead, the line
labels = torch.LongTensor(labels)
creates an uninitialized tensor with n elements instead of a tensor holding the single label n. As a result, acc_test often holds a fractional value instead of the expected 0 or 1. Either accuracy should be calculated as
acc_test = accuracy(output[[target_node]], [labels[target_node]])
or the accuracy() function should be modified to accept an integer in its labels parameter.
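A minimal sketch of the pitfall and of the second option, assuming PyTorch. The accuracy() below is a hypothetical version written for illustration, not the fix that was committed to utils.py; it swaps torch.LongTensor for torch.as_tensor, which treats a scalar as a value rather than a size.

```python
import torch

# Pitfall: given a Python int n, the legacy constructor torch.LongTensor(n)
# allocates an uninitialized tensor with n elements, not a tensor holding n.
n = 5
print(torch.LongTensor(n).shape)    # torch.Size([5])
print(torch.LongTensor([n]).shape)  # torch.Size([1])


def accuracy(output, labels):
    """Hypothetical accuracy() that also accepts a single integer label.

    output: tensor of shape (num_nodes, num_classes)
    labels: int, list of ints, 0-d tensor, or 1-D tensor of class indices
    """
    # torch.as_tensor keeps a scalar as a value; reshape(-1) turns an int
    # or a 0-d tensor (e.g. labels[target_node]) into a 1-element tensor.
    labels = torch.as_tensor(labels, dtype=torch.long).reshape(-1)
    preds = output.argmax(1).type_as(labels)
    correct = preds.eq(labels).double().sum()
    return correct / len(labels)


# Usage: a single target node now yields exactly 0 or 1.
output = torch.tensor([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]])
labels = torch.tensor([0, 0, 0])
target_node = 2  # hypothetical target node
print(float(accuracy(output[[target_node]], labels[target_node])))  # 1.0
```

Normalizing the labels once at the top of the function means callers can pass either form, so both calls on line 149 would behave the same way.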