DSE-MSU / DeepRobust

A PyTorch adversarial library for attack and defense methods on images and graphs
MIT License

Issue with test_nettack.py #18

Closed: viz27 closed this issue 4 years ago

viz27 commented 4 years ago

On line 149, accuracy is calculated as acc_test = accuracy(output[[target_node]], labels[target_node]), but labels[target_node] is a single integer here. The accuracy() function in utils.py expects a list for its labels argument. If an integer n is passed instead, labels = torch.LongTensor(labels) creates an uninitialized tensor of size n instead of a tensor holding a single item. The prediction for the target node is then compared against n arbitrary values, so in many cases acc_test holds a fractional value instead of the expected 0 or 1. Either accuracy should be computed as acc_test = accuracy(output[[target_node]], [labels[target_node]]), or the accuracy() function should be modified to accept an integer label.
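The size-versus-value pitfall can be reproduced in a few lines. The sketch below is a standalone illustration, not DeepRobust code: label is a hypothetical value, and accuracy_scalar_safe is a hypothetical helper showing the second option (accepting an integer label), not the fix that was actually committed to utils.accuracy().

```python
import torch

# Hypothetical label of a single target node, as a plain Python int.
label = 3

# The legacy constructor treats an int as a *size*: this allocates an
# uninitialized tensor with 3 entries, not a tensor containing the value 3.
wrong = torch.LongTensor(label)

# Wrapping the int in a list gives a 1-element tensor holding the label.
right = torch.LongTensor([label])

print(wrong.shape)                  # torch.Size([3])
print(right.shape, right.tolist())  # torch.Size([1]) [3]


def accuracy_scalar_safe(output, labels):
    """Hypothetical accuracy helper that also accepts a single integer label.

    output: [num_nodes, num_classes] scores; labels: int, list, array or tensor.
    """
    # as_tensor turns an int into a 0-d tensor; reshape(-1) makes it shape (1,).
    labels = torch.as_tensor(labels, dtype=torch.long).reshape(-1)
    preds = output.argmax(dim=1)  # predicted class per node
    return preds.eq(labels).double().mean().item()
```

With output = torch.tensor([[0.1, 0.9, 0.0]]), accuracy_scalar_safe(output, 1) returns 1.0 and accuracy_scalar_safe(output, 2) returns 0.0, so a single-node check can only ever yield 0 or 1.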

ChandlerBang commented 4 years ago

Thank you for pointing it out!

I've fixed the problem in utils.accuracy(), and I also changed line 149 from acc_test = accuracy(output[[target_node]], labels[target_node]) to acc_test = (output.argmax(1)[target_node] == labels[target_node]). I think this is a better way to check whether the target node is misclassified.
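The revised check on line 149 compares the predicted class of the target node directly with its label. A minimal sketch of that comparison, assuming output is a [num_nodes, num_classes] tensor of scores and labels is a 1-D LongTensor (the function name and toy data are hypothetical):

```python
import torch


def target_correctly_classified(output: torch.Tensor,
                                labels: torch.Tensor,
                                target_node: int) -> bool:
    """Return True if the prediction for target_node matches its label."""
    pred = output.argmax(1)[target_node]     # predicted class of the target node
    return bool(pred == labels[target_node])  # 0-d bool tensor -> Python bool


# Hypothetical toy data: 3 nodes, 2 classes.
output = torch.tensor([[0.2, 0.8],
                       [0.7, 0.3],
                       [0.4, 0.6]])
labels = torch.tensor([1, 0, 0])

print(target_correctly_classified(output, labels, 0))  # True
print(target_correctly_classified(output, labels, 2))  # False
```

Because only one node is compared, the result is always exactly true or false, which rules out the fractional values described in the original report.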

Thanks again.

viz27 commented 4 years ago

Thank you! This library has been a great help for my project.

ChandlerBang commented 4 years ago

Thanks, I am glad it helps : )