pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License
20.95k stars 3.61k forks source link

MessagePassing bugs when processing bipartite graph #9583

Open liulizhi1996 opened 1 month ago

liulizhi1996 commented 1 month ago

🐛 Describe the bug

I found a strange bug in the `MessagePassing` module when processing a bipartite graph.

Run code snippet in Google Colab notebook

Here is my code:

import torch
import torch.nn as nn
from torch_geometric.datasets import AmazonBook
from torch_geometric.nn.conv import MessagePassing

class GCN(MessagePassing):
    """Minimal message-passing layer with no parameters of its own.

    Each message is the ``x_j`` value handed to :meth:`message`, returned
    unchanged. Aggregation defaults to ``'add'`` unless the caller passes a
    different ``aggr`` keyword.
    """

    def __init__(self, **kwargs):
        """Forward ``kwargs`` to the base class, defaulting ``aggr`` to ``'add'``."""
        # setdefault keeps a caller-supplied 'aggr' and only fills it in
        # when the key is absent.
        kwargs.setdefault('aggr', 'add')
        super().__init__(**kwargs)

    def forward(self, x_src, x_dst, edge_index):
        """Run propagation over ``edge_index`` and return its result.

        The two feature arguments are handed to the base class as the pair
        ``(x_src, x_dst)`` under the keyword ``x``.
        """
        # NOTE(review): presumably x_src / x_dst are the feature matrices of
        # the two node sets of a bipartite graph - confirm against callers.
        out = self.propagate(edge_index, x=(x_src, x_dst))
        return out

    def message(self, x_j):
        """Return ``x_j`` as-is (identity message)."""
        # NOTE(review): x_j is presumably gathered per edge from the ``x``
        # pair by the base class - confirm against the library docs.
        return x_j

# Load the AmazonBook dataset rooted at the current directory and take its
# first element.
dataset = AmazonBook(root='.')
data = dataset[0]
# Node counts of the 'user' and 'book' node types.
num_user, num_item = data['user']['num_nodes'], data['book']['num_nodes']
# Join the 'edge_index' and 'edge_label_index' tensors of the
# ('user', 'rates', 'book') relation along dim 1 into one edge tensor.
# NOTE(review): presumably these are the dataset's own train/test edges,
# merged here so they can be re-split below - confirm.
edge_index = torch.cat([data['user', 'rates', 'book']['edge_index'],
                        data['user', 'rates', 'book']['edge_label_index']], dim=1)

def _setdiff1d(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Ensure x and y are 1-dimensional
    if x.dim() != 1 or y.dim() != 1:
        raise ValueError("Both x and y must be 1-dimensional tensors.")
    # Find unique elements in x
    unique_x = torch.unique(x)
    # Create a mask of elements in unique_x that are not in y
    mask = torch.isin(unique_x, y, invert=True)
    # Return the elements in unique_x that are not in y
    result = unique_x[mask]
    return result

def split_data(edges, split_ratio=0.8):
    """Randomly split edges into train and test sets without cold-start test edges.

    Args:
        edges: Tensor whose columns are edges; row 0 holds the user index and
            row 1 the item index of each edge.
        split_ratio: Fraction of the edges assigned to the training set.

    Returns:
        A ``(train_edges, test_edges)`` pair. Test edges whose user, and then
        whose item, never appears in the training edges are removed.
    """
    total = edges.size(1)
    shuffled = torch.randperm(total)
    cut = int(total * split_ratio)
    train_edges = edges[:, shuffled[:cut]]
    test_edges = edges[:, shuffled[cut:]]

    # Row 0 (users) is filtered first, then row 1 (items) on the already
    # filtered test edges.
    for row in (0, 1):
        unseen = _setdiff1d(test_edges[row], train_edges[row])
        if unseen.numel() > 0:
            keep = ~torch.isin(test_edges[row], unseen)
            test_edges = test_edges[:, keep]

    return train_edges, test_edges

# Keep only the training portion of the split; the test portion is discarded.
edge_index, _ = split_data(edge_index)

print(num_user, num_item)
print(edge_index.shape)

# Embedding tables with 64 columns: one row per user and one row per item.
x_user = nn.Embedding(num_user, 64)
x_item = nn.Embedding(num_item, 64)

# Same layer and inputs, run once with each flow direction.
# NOTE(review): judging by the names, this pass is meant to produce one
# output row per user (items -> users) - confirm.
model_i2u = GCN(flow='target_to_source')
h_user = model_i2u(x_user.weight, x_item.weight, edge_index)
print(h_user.shape)

# NOTE(review): judging by the names, this pass is meant to produce one
# output row per item (users -> items) - confirm.
model_u2i = GCN(flow='source_to_target')
h_item = model_u2i(x_user.weight, x_item.weight, edge_index)
print(h_item.shape)

If I run the above code in a notebook cell, everything runs fine and no errors are reported.

Run Python files in terminal

If I split the above code into two Python script files, a strange error is raised.

File gcn.py:

from torch_geometric.nn.conv import MessagePassing

class GCN(MessagePassing):
    """Minimal message-passing layer with no parameters of its own.

    Each message is the ``x_j`` value handed to :meth:`message`, returned
    unchanged. Aggregation defaults to ``'add'`` unless the caller passes a
    different ``aggr`` keyword.
    """

    def __init__(self, **kwargs):
        """Forward ``kwargs`` to the base class, defaulting ``aggr`` to ``'add'``."""
        # setdefault keeps a caller-supplied 'aggr' and only fills it in
        # when the key is absent.
        kwargs.setdefault('aggr', 'add')
        super().__init__(**kwargs)

    def forward(self, x_src, x_dst, edge_index):
        """Run propagation over ``edge_index`` and return its result.

        The two feature arguments are handed to the base class as the pair
        ``(x_src, x_dst)`` under the keyword ``x``.
        """
        # NOTE(review): presumably x_src / x_dst are the feature matrices of
        # the two node sets of a bipartite graph - confirm against callers.
        out = self.propagate(edge_index, x=(x_src, x_dst))
        return out

    def message(self, x_j):
        """Return ``x_j`` as-is (identity message)."""
        # NOTE(review): x_j is presumably gathered per edge from the ``x``
        # pair by the base class - confirm against the library docs.
        return x_j

File main.py:

import torch
import torch.nn as nn
from torch_geometric.datasets import AmazonBook

from gcn import GCN

# Load the AmazonBook dataset rooted at the current directory and take its
# first element.
dataset = AmazonBook(root='.')
data = dataset[0]
# Node counts of the 'user' and 'book' node types.
num_user, num_item = data['user']['num_nodes'], data['book']['num_nodes']
# Join the 'edge_index' and 'edge_label_index' tensors of the
# ('user', 'rates', 'book') relation along dim 1 into one edge tensor.
# NOTE(review): presumably these are the dataset's own train/test edges,
# merged here so they can be re-split below - confirm.
edge_index = torch.cat([data['user', 'rates', 'book']['edge_index'],
                        data['user', 'rates', 'book']['edge_label_index']], dim=1)

def _setdiff1d(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Ensure x and y are 1-dimensional
    if x.dim() != 1 or y.dim() != 1:
        raise ValueError("Both x and y must be 1-dimensional tensors.")
    # Find unique elements in x
    unique_x = torch.unique(x)
    # Create a mask of elements in unique_x that are not in y
    mask = torch.isin(unique_x, y, invert=True)
    # Return the elements in unique_x that are not in y
    result = unique_x[mask]
    return result

def split_data(edges, split_ratio=0.8):
    """Shuffle the edges, split them into train/test, and drop cold-start test edges.

    Args:
        edges: Tensor whose columns are edges; row 0 holds the user index and
            row 1 the item index of each edge.
        split_ratio: Fraction of the edges that goes to the training set.

    Returns:
        A ``(train_edges, test_edges)`` pair. A test edge is dropped when its
        user does not occur in the training edges; the same check is then
        applied to the items of the remaining test edges.
    """
    edge_count = edges.size(1)
    order = torch.randperm(edge_count)
    n_train = int(edge_count * split_ratio)
    train_edges, test_edges = edges[:, order[:n_train]], edges[:, order[n_train:]]

    # Users live in row 0 and items in row 1; users are filtered before items.
    for row in (0, 1):
        cold_start = _setdiff1d(test_edges[row], train_edges[row])
        if cold_start.numel() > 0:
            warm = ~torch.isin(test_edges[row], cold_start)
            test_edges = test_edges[:, warm]

    return train_edges, test_edges

# Keep only the training portion of the split; the test portion is discarded.
edge_index, _ = split_data(edge_index)

print(num_user, num_item)
print(edge_index.shape)

# Embedding tables with 64 columns: one row per user and one row per item.
x_user = nn.Embedding(num_user, 64)
x_item = nn.Embedding(num_item, 64)

# Same layer and inputs, run once with each flow direction.
# NOTE(review): judging by the names, this pass is meant to produce one
# output row per user (items -> users) - confirm. When run as a script,
# the call on `model_i2u` below is the one that appears in the reported
# traceback.
model_i2u = GCN(flow='target_to_source')
h_user = model_i2u(x_user.weight, x_item.weight, edge_index)
print(h_user.shape)

# NOTE(review): judging by the names, this pass is meant to produce one
# output row per item (users -> items) - confirm.
model_u2i = GCN(flow='source_to_target')
h_item = model_u2i(x_user.weight, x_item.weight, edge_index)
print(h_item.shape)

Now I run the command `python main.py` in the terminal, and the output is as follows:

52643 91599
torch.Size([2, 2387286])
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch_geometric/nn/conv/message_passing.py", line 317, in _index_select_safe
    return src.index_select(self.node_dim, index)
IndexError: index out of range in self

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/content/main.py", line 58, in <module>
    h_user = model_i2u(x_user.weight, x_item.weight, edge_index)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/content/gcn.py", line 10, in forward
    out = self.propagate(edge_index, x=(x_src, x_dst))
  File "/tmp/gcn_GCN_propagate_om1q2x5u.py", line 129, in propagate
    kwargs = self.collect(
  File "/tmp/gcn_GCN_propagate_om1q2x5u.py", line 76, in collect
    x_j = self._index_select(_x_0, edge_index_j)
  File "/usr/local/lib/python3.10/dist-packages/torch_geometric/nn/conv/message_passing.py", line 313, in _index_select
    return self._index_select_safe(src, index)
  File "/usr/local/lib/python3.10/dist-packages/torch_geometric/nn/conv/message_passing.py", line 328, in _index_select_safe
    raise IndexError(
IndexError: Found indices in 'edge_index' that are larger than 52642 (got 91598). Please ensure that all indices in 'edge_index' point to valid indices in the interval [0, 52643) in your node feature matrix and try again.

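I suspect this is caused by `_set_jittable_templates()` in `MessagePassing`: the generated `propagate` template does not seem to support bipartite graphs. In the traceback, the generated `collect` computes `x_j = self._index_select(_x_0, edge_index_j)`, i.e. it indexes a tensor with only 52643 rows (the number of users) using indices as large as 91598, which are only valid for the 91599 items. Please fix it! Thanks!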

Versions

torch_geometric==2.5.3

liulizhi1996 commented 1 month ago

Moreover, this error does not occur with torch_geometric==2.4.0.