EdisonLeeeee / GreatX

A graph reliability toolbox based on PyTorch and PyTorch Geometric (PyG).
MIT License

SG Attack example cannot run as expected on cuda #10

Closed: beiyanpiki closed this issue 1 year ago

beiyanpiki commented 1 year ago

Hello, I got an error when running SG Attack's example code on a CUDA device:

Traceback (most recent call last):
  File "src/test.py", line 50, in <module>
    attacker.attack(target)
  File "/greatx/attack/targeted/sg_attack.py", line 212, in attack
    subgraph = self.get_subgraph(target, target_label, best_wrong_label)
  File "/greatx/attack/targeted/sg_attack.py", line 124, in get_subgraph
    self.label == best_wrong_label)[0].cpu().numpy()
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:3 and cpu!

I found that self.label is on the CUDA device, but best_wrong_label is on the CPU: https://github.com/EdisonLeeeee/GreatX/blob/73eac351fdae842dbd74967622bd0e573194c765/greatx/attack/targeted/sg_attack.py#L123-L124
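A minimal sketch of a device-safe version of the comparison at lines 123-124, assuming PyTorch is installed. The helper name nodes_with_label is hypothetical (not part of GreatX's API), and this is not necessarily what the later fix commit does. It moves best_wrong_label onto the same device as label before comparing, so the comparison works whether or not .cpu() was applied earlier.

```python
import torch


def nodes_with_label(label: torch.Tensor, wrong_label: torch.Tensor):
    """Hypothetical helper: indices of nodes whose label equals wrong_label.

    wrong_label is first moved to label's device, so the comparison works
    even if wrong_label was moved to the CPU beforehand.
    """
    wrong_label = wrong_label.to(label.device)
    mask = label == wrong_label  # both operands now share one device
    return torch.where(mask)[0].cpu().numpy()


# Runs on the GPU when available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
label = torch.tensor([0, 2, 1, 2, 0], device=device)
best_wrong_label = torch.tensor([2])  # kept on the CPU, as after .cpu()

print(nodes_with_label(label, best_wrong_label))  # [1 3]
```

Since .to() is a no-op when the tensor is already on the target device, the helper costs nothing in the CPU-only case.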

After removing the .cpu() call at the end of line 94, everything runs without errors:

https://github.com/EdisonLeeeee/GreatX/blob/73eac351fdae842dbd74967622bd0e573194c765/greatx/attack/targeted/sg_attack.py#L94-L96

I noticed that a commit added .cpu() to the end of line 94, so I don't know whether this is a bug or intended behavior 🤨

EdisonLeeeee commented 1 year ago

Thanks for catching this! Fixed in 49675e30ed38b2518f7a1b85b667de5f4a0c9a7c

> I found there is a commit that adds .cpu end of line 94, so I dont know it's a bug or something else

That is intended behavior, to save GPU memory :)
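If the goal is to avoid holding extra tensors on the GPU, one option is to reduce the result to a plain Python number, which belongs to no device and compares cleanly with tensors on CPU or CUDA. This is a sketch under assumptions: it assumes PyTorch, and it assumes the "best wrong label" is the highest-scoring class other than the true label. best_wrong_label and nodes_with_label are hypothetical helpers, not GreatX functions.

```python
import torch


def best_wrong_label(logit: torch.Tensor, true_label: int) -> int:
    # Hypothetical: highest-scoring class other than the true label.
    masked = logit.detach().clone()
    masked[true_label] = float("-inf")
    # .item() returns a Python number, so no tensor is kept on any device.
    return int(masked.argmax().item())


def nodes_with_label(label: torch.Tensor, cls: int):
    # Comparing a tensor with a Python int works on CPU and CUDA alike.
    return torch.where(label == cls)[0].cpu().numpy()


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logit = torch.tensor([0.1, 0.7, 0.2], device=device)
label = torch.tensor([0, 2, 1, 2, 0], device=device)

wrong = best_wrong_label(logit, true_label=1)
print(wrong)  # 2
print(nodes_with_label(label, wrong))  # [1 3]
```

Note that .item() only works on single-element tensors; when several labels are compared at once, moving them with .to(label.device) is the more general choice.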