pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

Using PyG for Factor Graphs #2012

Open legohyl opened 3 years ago

legohyl commented 3 years ago

❓ Questions & Help

Hi there, I was wondering if we could use PyG for Factor Graphs? I'm trying to integrate some form of belief propagation on factor graphs together with GNNs, and would like to use the same interface for the belief propagation. I've structured the data so that it looks like a bipartite graph from networkx, but can't really figure out how to set up the MessagePassing interface for it.

Here's what it looks like: [image: bipartite factor graph with red variable nodes and blue factor nodes]

The red nodes are the variable nodes, with a single-valued feature on them. Blue nodes are factor nodes, whose outgoing messages have the form m_{fac_i -> var_j} = -min[2]_{k \in N(i) \ j} m_{var_k -> fac_i}. This means that my factor basically returns the 2nd-smallest value in the neighborhood, where the neighborhood does not include the message coming from the variable I'm sending to (variable j in this case).

I understand that at the moment, PyG only has mean, max, and sum aggregations. Hence, min in this case can be turned into max pretty easily. But the problem I see is that I can't return the 2nd-largest value. Is there a solution to this, e.g. creating a custom aggregation function?

Secondly, I'm having a bit of an issue trying to execute the variable-to-factor and factor-to-variable messages. At the moment, I've represented the data point as a type of BipartiteData, where x_s are my variable features, and x_t is just identity values (since my factors don't have features, I only want the messages): m_{var_i -> fac_j} = sum_{k \in N(i) \ j} m_{fac_k -> var_i}

Here's my implementation:

import torch
import torch_geometric as pyg

class BipartiteData(pyg.data.Data):
    def __init__(self, edge_index, x_s, x_t):
        super(BipartiteData, self).__init__()
        self.edge_index = edge_index
        # Reversed edge index for the factor-to-variable direction:
        self.edge_index_T = torch.stack([edge_index[1], edge_index[0]], dim=0)
        self.x_s = x_s
        self.x_t = x_t
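
For illustration, a tiny instance of this class could look like the following (the sizes and values here are just made up):

# 3 variable nodes with a single-valued feature each, 2 factor nodes.
x_s = torch.tensor([[0.3], [1.2], [0.7]])
x_t = torch.eye(2)  # identity placeholder, since my factors carry no features
# Row 0 holds the variable indices, row 1 the factor indices they connect to.
edge_index = torch.tensor([[0, 1, 1, 2],
                           [0, 0, 1, 1]])
data = BipartiteData(edge_index, x_s, x_t)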

So, I'm thinking of defining 2 message passing classes, a Var2Fac and Fac2Var class to pass the messages accordingly. But I'm still a bit lost on how I can do this.

rusty1s commented 3 years ago

Hi, and thanks for your questions.

  1. You can create a custom aggregation function by overriding:

    def aggregate(self, inputs, index):
        ...

    but from a standpoint of efficiency, it is quite hard to retrieve the 2nd-largest value. This would require us to support some kind of scatter_top_k implementation, which is not supported at the moment. Your best bet, in case efficiency is not a problem, is to implement it manually in pure Python/PyTorch, e.g.:

    def aggregate(self, inputs, index, dim_size):
        out = torch.zeros(dim_size, inputs.size(1))
        for i in range(dim_size):
            out[i] = inputs[index == i].sort(dim=0)[0][1]
        return out
  2. You only need to define two message passing classes if messages between different node types should be treated differently. Otherwise, something like this will work:

    conv1 = SAGEConv((num_var_features, num_fac_features), 256)
    conv2 = SAGEConv((num_fac_features, num_var_features), 256)
    new_x_t = conv1((x_s, x_t), edge_index)
    new_x_s = conv2((x_t, x_s), edge_index_T)

    I hope that this is useful to you.

CarloLucibello commented 3 years ago

Using https://pytorch.org/docs/stable/generated/torch.topk.html instead of sorting could help
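
For instance, the loop in the aggregate sketch above could be rewritten with torch.topk instead of a full sort (an untested sketch, following the same aggregate signature as before):

def aggregate(self, inputs, index, dim_size):
    out = torch.zeros(dim_size, inputs.size(1))
    for i in range(dim_size):
        # topk with largest=False returns the k smallest values per feature;
        # index [1] then picks the 2nd smallest.
        out[i] = torch.topk(inputs[index == i], k=2, dim=0, largest=False)[0][1]
    return out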

legohyl commented 3 years ago

hey @rusty1s , sorry for pinging on an old question, but I was wondering if there's any tutorial on how to construct a MessagePassing layer for bipartite graphs? I'm just trying out a simple sum on one set of variables first, but I run into a fair number of errors.

class Var2Fac(MessagePassing):
    def __init__(self):
        super(Var2Fac, self).__init__(aggr='add')

    def forward(self, x_n, x_m, edge_index, size):
        return self.propagate(edge_index, x=(x_n, x_m), size=size)

    def message(self, x_j):
        return x_j

Running a simple call like

conv1 = Var2Fac()
conv1.forward(x_n=data.x_s, x_m=data.x_t, edge_index=data.edge_index, size=(n,m))

gives me this error


---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-94-4b5a06462be0> in <module>
      1 conv1 = Var2Fac()
----> 2 conv1.forward(x_n=data.x_s, x_m=data.x_t, edge_index=data.edge_index, size=(n,m))

<ipython-input-92-25895c357148> in forward(self, x_n, x_m, edge_index, size)
      5 
      6     def forward(self, x_n, x_m, edge_index, size):
----> 7         return self.propagate(edge_index, x=(x_n, x_m), size=size)
      8 
      9     def message(self, x_j):

/opt/conda/lib/python3.8/site-packages/torch_geometric/nn/conv/message_passing.py in propagate(self, edge_index, size, **kwargs)
    231         # Otherwise, run both functions in separation.
    232         elif isinstance(edge_index, Tensor) or not self.fuse:
--> 233             coll_dict = self.__collect__(self.__user_args__, edge_index, size,
    234                                          kwargs)
    235 

/opt/conda/lib/python3.8/site-packages/torch_geometric/nn/conv/message_passing.py in __collect__(self, args, edge_index, size, kwargs)
    150                     assert len(data) == 2
    151                     if isinstance(data[1 - dim], Tensor):
--> 152                         self.__set_size__(size, 1 - dim, data[1 - dim])
    153                     data = data[dim]
    154 

/opt/conda/lib/python3.8/site-packages/torch_geometric/nn/conv/message_passing.py in __set_size__(self, size, dim, src)
    117         if the_size is None:
    118             size[dim] = src.size(self.node_dim)
--> 119         elif the_size != src.size(self.node_dim):
    120             raise ValueError(
    121                 (f'Encountered tensor with size {src.size(self.node_dim)} in '

IndexError: Dimension out of range (expected to be in range of [-1, 0], but got -2)

rusty1s commented 3 years ago

Are your node features one-dimensional? This should be fixable by reshaping x_n and x_m to shapes [-1, 1].
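
For example (a minimal sketch, assuming the features are stored as 1-D tensors):

x_n = data.x_s.view(-1, 1)  # [num_variable_nodes] -> [num_variable_nodes, 1]
x_m = data.x_t.view(-1, 1)  # [num_factor_nodes]   -> [num_factor_nodes, 1]
out = conv1.forward(x_n=x_n, x_m=x_m, edge_index=data.edge_index, size=(n, m))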

legohyl commented 3 years ago

Yes, I spotted that just a while back! Sorry for bothering you with this. I'll continue experimenting and see how to make this work; the pure torch implementation seems to be rather slow, so I'm checking whether there's a way PyG can do it efficiently.

legohyl commented 3 years ago

Referring to your suggestion in the code snippet below, what is dim_size? I was thinking of implementing some form of torch.topk functionality instead. Could you walk me through what the parameters here are? My current understanding is that inputs are the neighboring vector inputs, and index has the same number of rows as inputs, where it's basically an integer indicating which node each row belongs to. What is passed as dim_size here?

def aggregate(self, inputs, index, dim_size):
    out = torch.zeros(dim_size, inputs.size(1))
    for i in range(dim_size):
        out[i] = inputs[index == i].sort(dim=0)[0][1]
    return out

sasan73 commented 3 years ago

Hello, and thanks for the great library.

I want to build two graph convolution layers that flow messages from each set of nodes to the other. My implementation looks something like this:

class BipartiteGCN(MessagePassing):
  def __init__(self, in_channel, out_channel):
    super(BipartiteGCN, self).__init__(aggr = 'add')
    self.lin = nn.Linear(in_channel, out_channel)

  def forward(self, x, edge_index, N, M):

    x_first, x_second = x
    x_first = self.lin(x_first)

    row, col = edge_index
    deg_u = degree(row)
    deg_i = degree(col)
    new_deg_u = deg_u.pow(-.5)
    new_deg_i = deg_i.pow(-.5)
    norm = new_deg_u[row] * new_deg_i[col]
    return self.propagate(edge_index, x=(x_first, x_second), norm = norm, size = (N, M))

  def message(self, x_j, norm):

    return norm.view(-1, 1) * x_j

gconv_u = BipartiteGCN(x_i.shape[1], 256)
gconv_i = BipartiteGCN(x_u.shape[1], 256)

new_x_u = gconv_u((x_i, x_u), edge_index[torch.tensor([1, 0])], x_i.shape[0], x_u.shape[0])
new_x_i = gconv_i((x_u, x_i), edge_index, x_u.shape[0], x_i.shape[0])

Is the implementation correct?

rusty1s commented 3 years ago

Yes, this looks absolutely correct. Note that you may want to integrate central root information into your GNN, i.e., x_second is never really used in your layer.

Edit: You can also drop the size argument as it is not necessary to determine the shape of your graph. This information is already present in x.
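
For example, one way to bring x_second back in is to add a separate root transformation, roughly in the spirit of the root weight in SAGEConv (an untested sketch; the class name and channel arguments are only for illustration, and the degree normalization from the snippet above is omitted for brevity):

import torch.nn as nn
from torch_geometric.nn import MessagePassing

class BipartiteGCNWithRoot(MessagePassing):
    def __init__(self, in_channel, root_channel, out_channel):
        super().__init__(aggr='add')
        self.lin = nn.Linear(in_channel, out_channel)         # neighbor features
        self.lin_root = nn.Linear(root_channel, out_channel)  # central/root features

    def forward(self, x, edge_index):
        x_first, x_second = x
        # x_second is passed along so that the target size is inferred from it.
        out = self.propagate(edge_index, x=(self.lin(x_first), x_second))
        return out + self.lin_root(x_second)  # mix in the root information

    def message(self, x_j):
        return x_j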

legohyl commented 3 years ago

hey @rusty1s , sorry to ping on this again, but it seems that if I define 2 separate convolutions, when Fac2Var passes a message back to the variable nodes, it doesn't take into account the original message that was sent to it, even though that message should be excluded. My simple implementation is as follows:

# Construct var to fac and fac to var layers
class Var2Fac(MessagePassing):
    def __init__(self):
        super(Var2Fac, self).__init__(aggr='add')

    def forward(self, x_n, x_m, edge_index, size):
        return self.propagate(edge_index, x=(x_n, x_m), size=size)

    def message(self, x_j):
        return x_j

class Fac2Var(MessagePassing):
    def __init__(self):
        super(Fac2Var, self).__init__(aggr=None)

    def forward(self, x_n, x_m, edge_index, size):
        return self.propagate(edge_index, x=(x_n, x_m), size=size)

    def aggregate(self, inputs, index, dim_size):
        out = torch.zeros(dim_size, inputs.size(1))
        for i in range(dim_size):
            out[i] = torch.topk(inputs[index == i], k=2, dim=0, largest=False)[0][1, :]

        return out

Messages are then passed as

conv1 = Fac2Var()
new_x_m = conv1.forward(x_n=x_s, x_m=x_t, edge_index=data.edge_index, size=(n,m))

conv2 = Var2Fac()
new_x_n = conv2.forward(x_n=x_t, x_m=x_s, edge_index=data.edge_index_T, size=(m,n))

# conv1 = Fac2Var()
new2_x_m = conv1.forward(x_n=new_x_n, x_m=new_x_m, edge_index=data.edge_index, size=(n,m))

# conv2 = Var2Fac()
new2_x_n = conv2.forward(x_n=new_x_m, x_m=new_x_n, edge_index=data.edge_index_T, size=(m,n))

The features I get are not the same as those from a manual implementation in pure torch. Basically, what I hope to achieve is this: given some nodes x_s and some nodes x_t, the features in x_t get updated by taking the 2nd-highest incoming message from the x_s nodes. Now, given the way Fac2Var is implemented, I end up taking the 2nd-highest message amongst all neighboring x_s nodes. Instead, suppose I wish to send back a message to node i from x_s; I should be taking the 2nd-highest message sent from all x_s excluding the message sent from node i. I'm not really sure how to implement this part. It seems that the layer has to be aware of neighbors, as well as neighbors of neighbors?

rusty1s commented 3 years ago

If you want to filter out some messages, how about dropping those messages in the for loop in aggregate? Is this what you are trying to achieve?

legohyl commented 3 years ago

Hmm, are the messages in the aggregate function aware of the neighbors? Because I wish to find the 2nd largest message amongst the neighborhood for each node.

rusty1s commented 3 years ago

But you already do this, don't you?

def aggregate(self, inputs, index, dim_size):
        out = torch.zeros(dim_size, inputs.size(1))
        for i in range(dim_size):
            out[i] = torch.topk(inputs[index == i], k=2, dim=0, largest=False)[0][1, :]

This will select, for each neighborhood of node i, the second highest features across the neighborhood.

legohyl commented 3 years ago

Yes, that's true. But this 2nd largest message needs to depend on the sending node. For example, if a factor A receives messages from nodes X, Y, Z, the message that factor A sends to node X is the 2nd largest message between {Y,Z}. Likewise, the message that factor A sends to node Y is the 2nd largest message between {X,Z}.

rusty1s commented 3 years ago

I see. This complicates things a lot, and this is something that isn't well covered by the PyG MessagePassing interface.

I guess this functionality needs to be implemented in message directly. Sadly, I don't know of any good way to parallelize it, which might therefore require using Python loops.

This may look like the following (untested):

def message(self, x_i, x_j, edge_index_i, edge_index_j):
    outs = []
    for e in range(x_i.size(0)):
        # All edges with the same target, excluding any edge whose source is
        # the node this message is being sent back to:
        mask = (edge_index_i == edge_index_i[e]) & (edge_index_j != edge_index_j[e])
        # Take the 2nd-smallest value per feature among the remaining messages.
        out = x_j[mask].topk(k=2, dim=0, largest=False)[0][1]
        outs.append(out)
    return torch.stack(outs, dim=0)

legohyl commented 3 years ago

This is interesting! I'll check this out and try it later on. It's quite cool how the variables aren't explicitly declared but can still be accessed by name.

CarloLucibello commented 3 years ago

This paper https://arxiv.org/pdf/1705.08415.pdf defines an auxiliary graph whose nodes are the oriented edges of the original graph and whose adjacency matrix is the non-backtracking matrix.
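
For reference, a naive construction of that auxiliary graph might look like the sketch below (the helper name is made up, and the double loop is only meant to illustrate the idea, not to be efficient). On such a graph, a standard neighborhood aggregation automatically excludes the edge a message arrived on, which is exactly the "exclude the sender" behavior discussed above.

import torch

def non_backtracking_edge_index(edge_index):
    # Each directed edge (i -> j) of the original graph becomes a node of the
    # auxiliary graph; (i -> j) connects to (j -> k) only if k != i.
    src, dst = edge_index
    num_edges = edge_index.size(1)
    rows, cols = [], []
    for e1 in range(num_edges):
        for e2 in range(num_edges):
            if dst[e1] == src[e2] and src[e1] != dst[e2]:
                rows.append(e1)
                cols.append(e2)
    return torch.tensor([rows, cols], dtype=torch.long)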