DSE-MSU / DeepRobust

A PyTorch adversarial library for attack and defense methods on images and graphs
MIT License
994 stars, 192 forks

GCNSVD predict() #47

Closed: kinkunchan closed this issue 3 years ago

kinkunchan commented 3 years ago

Because GCNSVD preprocesses the input adjacency matrix with truncatedSVD(), we cannot directly use the predict() function inherited from GCN: for a newly passed adjacency matrix, the inherited version skips that preprocessing.
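
For context on what that preprocessing does: for an n × n adjacency matrix A, GCNSVD keeps only the top-k singular triplets, A_k = U_k diag(s_k) V_k^T. The NumPy sketch below shows this for a dense matrix. `truncated_svd_dense` is a hypothetical name used for illustration, not DeepRobust's truncatedSVD().

```python
import numpy as np


def truncated_svd_dense(adj: np.ndarray, k: int) -> np.ndarray:
    """Return the rank-k approximation U_k @ diag(s_k) @ Vt_k of a dense matrix.

    Hypothetical helper for illustration; not DeepRobust's truncatedSVD().
    """
    # np.linalg.svd returns singular values sorted in descending order.
    u, s, vt = np.linalg.svd(adj)
    # Keep only the k largest singular triplets.
    return u[:, :k] @ np.diag(s[:k]) @ vt[:k, :]


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Random symmetric 0/1 adjacency matrix without self-loops.
    upper = np.triu(rng.integers(0, 2, size=(8, 8)), 1)
    adj = (upper + upper.T).astype(float)

    low_rank = truncated_svd_dense(adj, k=3)
    print("rank before:", np.linalg.matrix_rank(adj))
    print("rank after:", np.linalg.matrix_rank(low_rank))
```

A_k is dense and real-valued rather than a 0/1 matrix. If the raw adj is fed to the inherited GCN.predict(), it normalizes a different graph from the one GCNSVD was trained on.
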

Here is my own implementation of an overridden predict():

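    # Method of GCNSVD; assumes `from deeprobust.graph import utils` at module level.
    def predict(self, features=None, adj=None):
        self.eval()
        if features is None and adj is None:
            # No inputs: reuse the features and normalized adjacency cached by fit().
            return self.forward(self.features, self.adj_norm)
        else:
            # Apply the same rank-k SVD preprocessing that fit() used.
            modified_adj = self.truncatedSVD(adj, k=self.k)
            # Labels are not needed for prediction, so convert only features and adj.
            features, modified_adj = utils.to_tensor(features, modified_adj, device=self.device)
            if utils.is_sparse_tensor(modified_adj):
                modified_adj = utils.normalize_adj_tensor(modified_adj, sparse=True)
            else:
                modified_adj = utils.normalize_adj_tensor(modified_adj)
            return self.forward(features, modified_adj)
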
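
In both branches the SVD output is then passed through utils.normalize_adj_tensor(). The sketch below is a NumPy version of the standard GCN renormalization D^{-1/2} (A + I) D^{-1/2}, where D is the degree matrix of A + I. It assumes this is what the dense branch computes. `normalize_adj_dense` is a hypothetical helper, not part of DeepRobust.

```python
import numpy as np


def normalize_adj_dense(adj: np.ndarray) -> np.ndarray:
    """Symmetric GCN normalization D^-1/2 (A + I) D^-1/2 for a dense matrix.

    Hypothetical helper; assumed to match the dense case of
    utils.normalize_adj_tensor(), not copied from it.
    """
    mx = adj + np.eye(adj.shape[0])  # add self-loops
    rowsum = mx.sum(axis=1)
    with np.errstate(divide="ignore", invalid="ignore"):
        d_inv_sqrt = np.power(rowsum, -0.5)
    # Zero row sums give inf and negative ones give nan; set those scalings to 0.
    d_inv_sqrt[~np.isfinite(d_inv_sqrt)] = 0.0
    d_mat = np.diag(d_inv_sqrt)
    return d_mat @ mx @ d_mat
```
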
ChandlerBang commented 3 years ago

Thanks for the suggestion.

Typically, self.predict() is called by self.test(), where the arguments features and adj are not specified. If you want to pass a (features, adj) pair to the function, you can apply the SVD preprocessing outside predict(), as follows:

prediction1 = model.predict()
processed_adj = model.truncatedSVD(perturbed_adj, k=20)
# prediction2 matches prediction1 if features, perturbed_adj and k=20 are what fit() was called with
prediction2 = model.predict(features, processed_adj)

The same approach works for GCNJaccard. I think I'll keep the original version of predict(). Thanks again for the suggestion, and feel free to let me know if you have other concerns.
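
GCNJaccard's preprocessing differs from GCNSVD's. It removes edges whose endpoints have dissimilar binary features, measured by Jaccard similarity. The following is a minimal NumPy sketch under that assumption. `drop_dissimilar_edges_dense` and its `threshold` default are hypothetical and are not DeepRobust's signatures.

```python
import numpy as np


def jaccard(x: np.ndarray, y: np.ndarray) -> float:
    """Jaccard similarity |x AND y| / |x OR y| of two boolean vectors."""
    union = np.count_nonzero(x | y)
    if union == 0:
        return 0.0
    return np.count_nonzero(x & y) / union


def drop_dissimilar_edges_dense(features: np.ndarray, adj: np.ndarray,
                                threshold: float = 0.01) -> np.ndarray:
    """Hypothetical sketch: remove edges whose endpoints' binary features
    have Jaccard similarity below `threshold`. Returns a new matrix."""
    feats = features.astype(bool)
    out = adj.copy()
    # Visit each undirected edge once via the strict upper triangle.
    rows, cols = np.nonzero(np.triu(adj, 1))
    for i, j in zip(rows, cols):
        if jaccard(feats[i], feats[j]) < threshold:
            out[i, j] = out[j, i] = 0
    return out
```

The filtered matrix would then play the same role as `processed_adj` in the snippet above and be passed to `predict()` together with `features`.
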

jiong-zhu commented 3 years ago

@ChandlerBang I see your point, but I would still vote for overriding predict() in GCNSVD to keep the interface consistent across models. Otherwise, users who call predict() on GCNSVD with new inputs may think they are getting GCNSVD predictions, when they are actually getting plain GCN results on an adjacency matrix that was never SVD-processed.
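
The concern can be illustrated with a toy pair of classes. All names are hypothetical, and a single fixed propagation step stands in for the GCN. The base predict() only normalizes whatever adjacency it receives. As a result, a subclass that preprocesses in fit() but inherits predict() silently skips the SVD for new inputs. Overriding predict() to apply the same rank-k SVD restores consistency with the no-argument call.

```python
import numpy as np


def normalize(adj: np.ndarray) -> np.ndarray:
    """D^-1/2 (A + I) D^-1/2, with non-finite scalings set to 0."""
    mx = adj + np.eye(adj.shape[0])
    with np.errstate(divide="ignore", invalid="ignore"):
        d = np.power(mx.sum(axis=1), -0.5)
    d[~np.isfinite(d)] = 0.0
    return d[:, None] * mx * d[None, :]


def rank_k(adj: np.ndarray, k: int) -> np.ndarray:
    """Rank-k approximation from the top-k singular triplets."""
    u, s, vt = np.linalg.svd(adj)
    return u[:, :k] @ np.diag(s[:k]) @ vt[:k, :]


class ToyGCN:
    """Stand-in for GCN: one propagation step A_norm @ X @ W with a fixed W."""

    def __init__(self, weight: np.ndarray):
        self.weight = weight

    def fit(self, features, adj):
        # Cache inputs the way predict() expects to find them.
        self.features, self.adj_norm = features, normalize(adj)

    def forward(self, features, adj_norm):
        return adj_norm @ features @ self.weight

    def predict(self, features=None, adj=None):
        if features is None and adj is None:
            return self.forward(self.features, self.adj_norm)
        return self.forward(features, normalize(adj))


class ToyGCNSVDInherited(ToyGCN):
    """Preprocesses in fit() but inherits predict() unchanged."""

    def __init__(self, weight, k):
        super().__init__(weight)
        self.k = k

    def fit(self, features, adj):
        super().fit(features, rank_k(adj, self.k))


class ToyGCNSVD(ToyGCNSVDInherited):
    """Overrides predict() so new inputs get the same SVD preprocessing."""

    def predict(self, features=None, adj=None):
        if features is None and adj is None:
            return super().predict()
        return super().predict(features, rank_k(adj, self.k))
```

With the override, predict(features, adj) on the training graph reproduces predict(). Without it, the same call normalizes the raw adjacency, so the caller must pass an already SVD-processed matrix.
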

ChandlerBang commented 3 years ago

@jiong-zhu Yeah, that makes sense. I've updated the code as you suggested to avoid confusion. See details in commit https://github.com/DSE-MSU/DeepRobust/commit/000f86d124f75ce3b21e47588a79ffe24fb94bf7. Thank you for the advice, Jiong and Mark!

jiong-zhu commented 3 years ago

Thank you!