rampasek / GraphGPS

Recipe for a General, Powerful, Scalable Graph Transformer
MIT License

Questions about the output channel #28

Closed HelloWorldLTY closed 1 year ago

HelloWorldLTY commented 1 year ago

Hi, I noticed that there is a PyG implementation of this layer: https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.GPSConv.html#torch_geometric.nn.conv.GPSConv. I wonder why it does not take an output channel size? I think the paper mentions that the model updates the dimensions of the embeddings.


Thanks a lot.

rampasek commented 1 year ago

Hello Tianyu,

Thanks for pointing out the PyG implementation, I was not aware of it! However, you will have to raise any issues regarding PyG in the PyG GitHub repo.

What do you actually mean by "I wonder why it does not have the size of output channel? I think in the paper it mentioned that the model will update the dimensions of embeddings."? From a quick look at their code, what stands out to me is that they do not support MPNNs that update edge embeddings, unlike the official GPS implementation here.
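To illustrate the point, here is a minimal sketch (not code from either repository; it assumes PyG's GPSConv forwards extra keyword arguments such as edge_attr to the wrapped MPNN): edge features can go into the local MPNN, but only updated node embeddings come out.

import torch
import torch.nn as nn
from torch_geometric.nn import GINEConv, GPSConv

hidden = 64
# Local MPNN that consumes edge features (GINEConv adds edge_attr to messages).
local_mpnn = GINEConv(nn.Sequential(
    nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, hidden)))
layer = GPSConv(hidden, local_mpnn, heads=4)

x = torch.randn(10, hidden)                 # node embeddings
edge_index = torch.randint(0, 10, (2, 40))  # random toy graph
edge_attr = torch.randn(40, hidden)         # edge embeddings

# edge_attr is consumed by the local MPNN, but only updated node embeddings
# are returned; the edge embeddings themselves are never updated.
x_new = layer(x, edge_index, edge_attr=edge_attr)
print(x_new.shape)  # torch.Size([10, 64])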

Best, Ladislav

HelloWorldLTY commented 1 year ago

OK, thanks a lot. I will post this question in the PyG repo. If I have any updates, I will let you know!

HelloWorldLTY commented 1 year ago

Hi Ladislav, if my understanding is correct, could we use GPSConv like this:

import torch
import torch.nn as nn
from torch_geometric.nn import GINConv, GPSConv
from torch_geometric.data import Data

class GPS(nn.Module):

    def __init__(self, input_dim, output_dim, hidden_dim) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        # Project raw node features to the hidden dimension used by GPSConv.
        self.emb = nn.Linear(self.input_dim, self.hidden_dim)
        # GPSConv keeps the channel size fixed at hidden_dim; the local MPNN
        # (a GINConv here) is just one possible choice.
        self.conv = GPSConv(
            self.hidden_dim,
            GINConv(nn.Sequential(
                nn.Linear(self.hidden_dim, self.hidden_dim),
                nn.ReLU(),
                nn.Linear(self.hidden_dim, self.hidden_dim),
            )),
        )
        # Map the hidden representation to the desired output dimension.
        self.readout = nn.Linear(self.hidden_dim, self.output_dim)

    def forward(self, data: Data) -> torch.Tensor:
        x, edge_index = data.x, data.edge_index
        x = self.emb(x)
        x = self.conv(x, edge_index)
        x = self.readout(x)
        return x

That way, I could update the embedding dimensions directly.
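As a quick sanity check (the dimensions here are made up), the wrapper above maps input_dim-dimensional node features to output_dim-dimensional outputs:

import torch
from torch_geometric.data import Data

# Toy graph with 10 nodes, 16 input features, and 30 random edges.
model = GPS(input_dim=16, output_dim=8, hidden_dim=64)
data = Data(x=torch.randn(10, 16),
            edge_index=torch.randint(0, 10, (2, 30)))

out = model(data)
print(out.shape)  # torch.Size([10, 8])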