pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

GCN-based Pooling Layer #2269

Open hkim716 opened 3 years ago

hkim716 commented 3 years ago

Hi Matt,

I have a question about designing a simple pooling layer from some GCNConv layers. I would like to create a supervised regression model. I have many graphs in the dataset, and each graph has x=[100, 1] as node features and y=[2, 1] as its label.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.nn as pyg_nn

class MyModel(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(MyModel, self).__init__()
        self.conv1 = pyg_nn.GCNConv(in_channels, 8)
        self.conv2 = pyg_nn.GCNConv(8, 16)
        self.conv3 = pyg_nn.GCNConv(16, 8)
        self.conv4 = pyg_nn.GCNConv(8, out_channels)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = x.float()
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        emb = x  # keep intermediate node embeddings for inspection
        x = F.relu(x)
        x = self.conv3(x, edge_index)
        x = F.relu(x)
        x = self.conv4(x, edge_index)
        # Assumes every graph has exactly 100 nodes:
        # [num_nodes, out_channels] -> [num_graphs, 100, out_channels]
        x = x.view(len(batch) // 100, 100, 2)

        # Average over the node dimension -> [num_graphs, out_channels]
        return emb, torch.mean(x, dim=1)

    def loss(self, pred, label):
        return F.mse_loss(pred, label.to(torch.float32))

From my understanding, when I use in_channels=100 and out_channels=2, torch.mean(x, dim=1) will perform like a pooling layer that reduces the 100 node outputs of each graph down to a single 2-dimensional prediction. Is that right? Could it be considered mean pooling?
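For example, on a mini-batch of two 100-node graphs I would expect the reshape-and-mean to match PyG's built-in mean pooling (a quick sketch; the batch vector is constructed by hand here, whereas PyG's Batch normally builds it):

import torch
from torch_geometric.nn import global_mean_pool

x = torch.randn(200, 2)                          # 2 graphs x 100 nodes, 2 channels
batch = torch.arange(2).repeat_interleave(100)   # node-to-graph assignment,
                                                 # assuming nodes of each graph are contiguous

manual = x.view(2, 100, 2).mean(dim=1)           # reshape-and-mean, as in forward()
pooled = global_mean_pool(x, batch)              # PyG's built-in mean pooling

assert torch.allclose(manual, pooled)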

rusty1s commented 3 years ago

Yes, indeed. This is exactly what is happening. One typically also applies a final linear or non-linear transformation on top of the pooled features, e.g., via an MLP. To make use of mini-batching capabilities (in particular for graphs of varying size), you need to swap out the torch.mean call for our torch_geometric.nn.global_mean_pool functionality.
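A minimal sketch of what that could look like (the hidden sizes of the MLP are arbitrary choices for illustration, not a prescribed architecture):

import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.nn as pyg_nn
from torch_geometric.nn import global_mean_pool

class MyPooledModel(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = pyg_nn.GCNConv(in_channels, 8)
        self.conv2 = pyg_nn.GCNConv(8, 16)
        self.conv3 = pyg_nn.GCNConv(16, 8)
        # Final MLP on top of the pooled graph embedding.
        self.mlp = nn.Sequential(
            nn.Linear(8, 8),
            nn.ReLU(),
            nn.Linear(8, out_channels),
        )

    def forward(self, data):
        x, edge_index, batch = data.x.float(), data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = F.relu(self.conv3(x, edge_index))
        # Averages node features per graph, independent of graph size:
        x = global_mean_pool(x, batch)  # [num_graphs, 8]
        return self.mlp(x)              # [num_graphs, out_channels]

Unlike the reshape-based version, this works for graphs with differing numbers of nodes, since global_mean_pool uses the batch vector rather than a fixed node count.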