pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License
21.39k stars 3.67k forks source link

PGExplainer error - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat2 in method wrapper_mm) #6620

Closed Peter-obi closed 1 year ago

Peter-obi commented 1 year ago

🐛 Describe the bug

I trained a GCN model on Goggle collab for multi-class classification and was trying to get explanations for the edge values using the PGExplainer. All the code runs well while training the model on GPU but when I want to run the PGExplainer, for some reason some of the tensors are on the cpu despite adding the '.to(device)' to all inputs. Kindly find code below

from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool
import torch.nn.functional as F
import torch.nn as nn

class GCN(nn.Module):
    def __init__(self, input_dim, hidden_channels):
        super(GCN, self).__init__()
        torch.manual_seed(42)
        self.conv1 = GCNConv(input_dim, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.lin = nn.Linear(hidden_channels, 3)

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = self.conv3(x, edge_index)
        x = gnn.global_mean_pool(x, batch)
        x = self.lin(x)
        return x
device = torch.device("cuda")
# Define criterion and optimizer
model = GCN(input_dim = 3, hidden_channels = 256)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
num_epochs = 10
model = model.to(device)
model = model.double()
# Train and evaluate the model
for epoch in range(num_epochs):
    model.train()
    correct = 0
    for data in train_dataloader:
        optimizer.zero_grad()
        out = model(data.x.to(device), data.edge_index.to(device).long(), data.batch.to(device))
        loss = criterion(out, data.y.to(device))
        loss.backward()
        optimizer.step()
        pred = out.argmax(dim=1)
        correct += int((pred == data.y.to(device)).sum())
        #print(correct)
    train_acc = correct / len(train_dataloader.dataset)
    model.eval()
    correct = 0
    for data in val_dataloader:
        out = model(data.x.to(device), data.edge_index.to(device).long(), data.batch.to(device))
        pred = out.argmax(dim=1)
        correct += int((pred == data.y.to(device)).sum())
    val_acc = correct / len(val_dataloader.dataset)
    print(f'Epoch: {epoch}, Train Accuracy: {train_acc:.4f}, Validation Accuracy: {val_acc:.4f}')

Implement and run PGExplainer

from torch_geometric.explain import Explainer, PGExplainer
explainer = Explainer(
    model=model,
    algorithm=PGExplainer(epochs=10, lr=0.003),
    explanation_type='phenomenon',
    edge_mask_type='object',
    model_config=dict(
        mode='multiclass_classification',
        task_level='graph',
        return_type='raw',
    ),
    # Include only the top 10 most important edges:
    threshold_config=dict(threshold_type='topk', value=10),
)
# PGExplainer needs to be trained separately since it is a parametric
# explainer i.e it uses a neural network to generate explanations:
for epoch in range(10):
    for data in val_dataloader:
        # Move the data to the chosen device
        data = data.to(device)
        x = data.x
        x = x.to(device)

        edge_index = data.edge_index
        edge_index = edge_index.to(device)

        target = data.y
        target = target.to(device)

        batch = data.batch
        batch = batch.to(device)

        # Pass the data to the model's train method
        loss = explainer.algorithm.train(epoch, model, x, edge_index, target=target, batch=batch)
# Generate the explanation for a particular graph:
explanation = explainer(dataset[0].x, dataset[0].edge_index)
print(explanation.edge_mask)

Full traceback

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
[<ipython-input-37-42a83a465a91>](https://u3a1yjs3ki-496ff2e9c6d22116-0-colab.googleusercontent.com/outputframe.html?vrz=colab-20230202-060047-RC00_506569794#) in <module>
     35 
     36         # Pass the data to the model's train method
---> 37         loss = explainer.algorithm.train(epoch, model, x, edge_index, target=target, batch=batch)
     38 # Generate the explanation for a particular graph:
     39 explanation = explainer(dataset[0].x, dataset[0].edge_index)

4 frames
[/usr/local/lib/python3.8/dist-packages/torch_geometric/explain/algorithm/pg_explainer.py](https://u3a1yjs3ki-496ff2e9c6d22116-0-colab.googleusercontent.com/outputframe.html?vrz=colab-20230202-060047-RC00_506569794#) in train(self, epoch, model, x, edge_index, target, index, **kwargs)
    127 
    128         inputs = self._get_inputs(z, edge_index, index)
--> 129         logits = self.mlp(inputs).view(-1)
    130         edge_mask = self._concrete_sample(logits, temperature)
    131         set_masks(model, edge_mask, edge_index, apply_sigmoid=True)

[/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py](https://u3a1yjs3ki-496ff2e9c6d22116-0-colab.googleusercontent.com/outputframe.html?vrz=colab-20230202-060047-RC00_506569794#) in _call_impl(self, *input, **kwargs)
   1192         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194             return forward_call(*input, **kwargs)
   1195         # Do not call functions when jit is used
   1196         full_backward_hooks, non_full_backward_hooks = [], []

[/usr/local/lib/python3.8/dist-packages/torch/nn/modules/container.py](https://u3a1yjs3ki-496ff2e9c6d22116-0-colab.googleusercontent.com/outputframe.html?vrz=colab-20230202-060047-RC00_506569794#) in forward(self, input)
    202     def forward(self, input):
    203         for module in self:
--> 204             input = module(input)
    205         return input
    206 

[/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py](https://u3a1yjs3ki-496ff2e9c6d22116-0-colab.googleusercontent.com/outputframe.html?vrz=colab-20230202-060047-RC00_506569794#) in _call_impl(self, *input, **kwargs)
   1210             input = bw_hook.setup_input_hook(input)
   1211 
-> 1212         result = forward_call(*input, **kwargs)
   1213         if _global_forward_hooks or self._forward_hooks:
   1214             for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):

[/usr/local/lib/python3.8/dist-packages/torch_geometric/nn/dense/linear.py](https://u3a1yjs3ki-496ff2e9c6d22116-0-colab.googleusercontent.com/outputframe.html?vrz=colab-20230202-060047-RC00_506569794#) in forward(self, x)
    129             x (torch.Tensor): The input features.
    130         """
--> 131         return F.linear(x, self.weight, self.bias)
    132 
    133     @torch.no_grad()

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat1 in method wrapper_addmm)

Environment

rusty1s commented 1 year ago

You need to move PGExplainer to the device as well, e.g.:


explainer = Explainer(
    model=model,
    algorithm=PGExplainer(epochs=10, lr=0.003).to(device),
rusty1s commented 1 year ago

Corresponding test: https://github.com/pyg-team/pytorch_geometric/pull/6624 Closing this issue for now, feel free to re-open if you still have doubts.

Peter-obi commented 1 year ago

Thank you! It worked