Peter-obi closed this issue 1 year ago.
You need to move PGExplainer to the device as well, e.g.:
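explainer = Explainer(
    model=model,
    # PGExplainer has trainable parameters of its own, so it must be moved too
    algorithm=PGExplainer(epochs=10, lr=0.003).to(device),
    # ... remaining Explainer arguments unchanged ...
)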
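The reason the extra `.to(device)` matters: to our understanding, PGExplainer trains its own small network to predict edge masks, so unlike a purely post-hoc algorithm it owns parameters, and those are created on the CPU unless moved. The sketch below illustrates the same rule in plain PyTorch using a hypothetical `ToyEdgeExplainer` module (not PyG's class): calling `.to(device)` on a module moves all of its registered parameters, so they end up on the same device as the GNN embeddings.

```python
import torch
from torch import nn


class ToyEdgeExplainer(nn.Module):
    """Hypothetical stand-in for an explainer that owns a trainable MLP.

    This is NOT PyG's PGExplainer; it only mimics the fact that such an
    explainer has parameters living on whatever device the module is on.
    """

    def __init__(self, emb_dim: int = 8):
        super().__init__()
        # Scores each edge from the concatenated embeddings of its endpoints.
        self.mlp = nn.Sequential(nn.Linear(2 * emb_dim, 16), nn.ReLU(), nn.Linear(16, 1))

    def forward(self, node_emb: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        src, dst = edge_index  # rows of the 2 x num_edges tensor
        edge_emb = torch.cat([node_emb[src], node_emb[dst]], dim=-1)
        # One mask value in [0, 1] per edge.
        return torch.sigmoid(self.mlp(edge_emb)).view(-1)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

node_emb = torch.randn(4, 8, device=device)  # e.g. embeddings from the trained GCN
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]], device=device)

# Without .to(device), the MLP weights would stay on the CPU and, on a GPU
# machine, the forward pass would fail with a device-mismatch RuntimeError.
explainer = ToyEdgeExplainer().to(device)
edge_mask = explainer(node_emb, edge_index)
print(edge_mask.shape, edge_mask.device)
```

The same applies to any helper module with parameters that sits next to the model: moving the inputs and the model is not enough if the explainer itself was built on the CPU.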
Corresponding test: https://github.com/pyg-team/pytorch_geometric/pull/6624
Closing this issue for now; feel free to re-open if you still have doubts.
Thank you! It worked.
🐛 Describe the bug
I trained a GCN model on Google Colab for multi-class classification and was trying to get explanations for the edges using PGExplainer. All the code runs well while training the model on the GPU, but when I run PGExplainer, some of the tensors end up on the CPU despite adding '.to(device)' to all inputs. Kindly find the code below.
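When an error says some tensors are on the CPU even though every input was moved, it can help to list the device of every parameter, buffer, and input before running the explainer. The helper below is a hypothetical debugging sketch (`find_device_mismatches` is not a PyTorch or PyG function). It compares only the device type, so `cuda` and `cuda:0` count as the same device.

```python
import torch
from torch import nn


def find_device_mismatches(expected, modules=None, tensors=None):
    """Hypothetical helper: list parameters/buffers/tensors not on `expected`.

    `expected` may be a string such as "cuda" or a torch.device.
    Only the device type is compared, so "cuda" matches "cuda:0".
    """
    expected = torch.device(expected)
    mismatches = []
    for mod_name, module in (modules or {}).items():
        # Check both trainable parameters and registered buffers.
        for name, t in list(module.named_parameters()) + list(module.named_buffers()):
            if t.device.type != expected.type:
                mismatches.append(f"{mod_name}.{name} on {t.device}")
    for name, t in (tensors or {}).items():
        if t.device.type != expected.type:
            mismatches.append(f"{name} on {t.device}")
    return mismatches


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Linear(4, 2).to(device)  # moved, like the trained GCN
side_net = nn.Linear(4, 1)          # forgotten: stays on the CPU
x = torch.randn(3, 4, device=device)

# On a GPU machine this reports side_net.weight and side_net.bias;
# on a CPU-only machine everything matches and the list is empty.
print(find_device_mismatches(device, modules={"model": model, "side_net": side_net}, tensors={"x": x}))
```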
Implement and run PGExplainer
Full traceback
Environment
Installed via (conda, pip, source): source
torch-scatter version: 2.1.0