pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

None edge_attr assertion in GeneralConv #9530

Open bgeier opened 2 months ago

bgeier commented 2 months ago

🐛 Describe the bug

The GeneralConv layer raises an AssertionError when only node features (x) and an edge index (edge_index) are provided. I expect the layer to return a result when I don't provide edge_attr (its default is None).

Context: I'm writing unit tests for a model composed of many layers, and I've traced the failure back to a core torch_geometric layer that raises the assertion even for basic input data. For example:

from torch_geometric.datasets import FakeDataset
from torch_geometric.nn import GeneralConv

gnn = GeneralConv(in_channels=100, out_channels=100)

dataset = FakeDataset(
    num_graphs=32 * 4,  # 4 batches of 32
    avg_num_nodes=20,
    num_channels=100,
    num_classes=2,
    edge_dim=1,
    is_undirected=False,
)

# edge_attr is omitted, so it defaults to None
gnn(dataset[0].x, dataset[0].edge_index)

The forward pass of gnn raises an AssertionError from assert edge_attr is not None. I'm having trouble locating the file that contains this assertion. Is this a legitimate bug or user error? Any help would be much appreciated. Thanks!

Here's the traceback from the above test.

_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
/home/dev/.venvs/project/lib/python3.11/site-packages/torch/nn/modules/module.py:1532: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
/home/dev/.venvs/project/lib/python3.11/site-packages/torch/nn/modules/module.py:1541: in _call_impl
    return forward_call(*args, **kwargs)
/home/dev/.venvs/project/lib/python3.11/site-packages/torch_geometric/nn/conv/general_conv.py:155: in forward
    out = self.propagate(edge_index, x=x, size=size, edge_attr=edge_attr)
/tmp/torch_geometric.nn.conv.general_conv_GeneralConv_propagate_7jgj7uxt.py:163: in propagate
    kwargs = self.collect(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = GeneralConv(100, 100)
edge_index = tensor([[ 0,  0,  0,  0,  0,  0,  0,  1,  1,  1,  1,  1,  1,  2,  2,  2,  2,  2,
          2,  2,  3,  3,  3,  3,  3, ...  5,  6, 10, 16,  0,  3,  5,  9, 11, 14, 16,  5,  7,  8, 10, 13,
         15,  1,  4,  6, 10,  0,  7, 10, 14, 16, 17]])
x = (tensor([[-1.5406,  1.4097,  1.5205,  ..., -0.8308,  1.1799,  4.4395],
        [-0.4764,  1.4493,  1.8234,  ...,  1.57...259,  2.4132,  ...,  1.5033,  1.0380,  1.5486],
        [ 1.4862,  1.8833,  2.0412,  ...,  0.6160, -0.9966,  1.6773]]))
edge_attr = None, size = [None, None]

    def collect(
        self,
        edge_index: Union[Tensor, SparseTensor],
        x: OptPairTensor,
        edge_attr: OptTensor,
        size: List[Optional[int]],
    ) -> CollectArgs:

        i, j = (1, 0) if self.flow == 'source_to_target' else (0, 1)

        # Collect special arguments:
        if isinstance(edge_index, Tensor):
            if is_torch_sparse_tensor(edge_index):
                adj_t = edge_index
                if adj_t.layout == torch.sparse_coo:
                    edge_index_i = adj_t.indices()[0]
                    edge_index_j = adj_t.indices()[1]
                    ptr = None
                elif adj_t.layout == torch.sparse_csr:
                    ptr = adj_t.crow_indices()
                    edge_index_j = adj_t.col_indices()
                    edge_index_i = ptr2index(ptr, output_size=edge_index_j.numel())
                else:
                    raise ValueError(f"Received invalid layout '{adj_t.layout}'")
                if edge_attr is None:
                    _value = adj_t.values()
                    edge_attr = None if _value.dim() == 1 else _value

            else:
                edge_index_i = edge_index[i]
                edge_index_j = edge_index[j]
                ptr = None

        elif isinstance(edge_index, SparseTensor):
            adj_t = edge_index
            edge_index_i, edge_index_j, _value = adj_t.coo()
            ptr, _, _ = adj_t.csr()
            if edge_attr is None:
                edge_attr = None if _value is None or _value.dim() == 1 else _value

        else:
            raise NotImplementedError
>       assert edge_attr is not None
E       AssertionError

/tmp/torch_geometric.nn.conv.general_conv_GeneralConv_propagate_7jgj7uxt.py:78: AssertionError
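The last frame of the traceback points to a file under /tmp, not to anything in site-packages, which would explain why the assertion is hard to find in the installed package: the propagate/collect code for GeneralConv appears to be generated and imported from a temporary module at runtime. The stdlib-only sketch below illustrates that general technique (write source to a temporary .py file, import it, then read the failing frame's filename from the traceback). The generated source, the module name and the helper names are hypothetical; this is not PyG's actual code generator.

```python
import importlib.util
import os
import tempfile
import traceback

# Hypothetical generated source. It only mimics the shape of the failing
# check in the traceback; it is NOT PyG's template output.
GENERATED_SOURCE = '''
def collect(edge_attr=None):
    assert edge_attr is not None
    return {"edge_attr": edge_attr}
'''


def load_generated_module(source, prefix="generated_propagate_"):
    """Write `source` to a temporary .py file and import it as a module."""
    with tempfile.NamedTemporaryFile(
        mode="w", prefix=prefix, suffix=".py", delete=False
    ) as f:
        f.write(source)
        path = f.name
    spec = importlib.util.spec_from_file_location("generated_propagate", path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # code objects get co_filename == path
    return module, path


def failing_frame_file(module):
    """Call collect() without edge_attr; return the file of the failing frame."""
    try:
        module.collect()
    except AssertionError as exc:
        # The innermost frame is the generated function, not the caller.
        return traceback.extract_tb(exc.__traceback__)[-1].filename
    return None


if __name__ == "__main__":
    module, path = load_generated_module(GENERATED_SOURCE)
    print(failing_frame_file(module))  # a path inside tempfile.gettempdir()
    os.remove(path)
```

In the real traceback, the corresponding file is /tmp/torch_geometric.nn.conv.general_conv_GeneralConv_propagate_7jgj7uxt.py, with the assert at line 78.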

Versions

Collecting environment information...
PyTorch version: 2.3.1
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (aarch64)
GCC version: Could not collect
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.35

Python version: 3.11.9 (main, Apr 6 2024, 17:59:24) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-6.6.32-linuxkit-aarch64-with-glibc2.35
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: aarch64
CPU op-mode(s): 64-bit
Byte Order: Little Endian
CPU(s): 14
On-line CPU(s) list: 0-13
Vendor ID: Apple
Model: 0
Thread(s) per core: 1
Core(s) per cluster: 14
Socket(s): -
Cluster(s): 1
Stepping: 0x0
BogoMIPS: 48.00
Flags: fp asimd evtstrm aes pmull sha1 sha2 crc32 atomics fphp asimdhp cpuid asimdrdm jscvt fcma lrcpc dcpop sha3 asimddp sha512 asimdfhm dit uscat ilrcpc flagm ssbs sb paca pacg dcpodp flagm2 frint
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; __user pointer sanitization
Vulnerability Spectre v2: Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] torch==2.3.1
[pip3] torch_geometric==2.5.3
[pip3] torchmetrics==1.4.0.post0
[conda] Could not collect

bgeier commented 2 months ago

I can get my test to pass if I pass an arbitrary edge_attr tensor to the layer and set in_edge_channels=0. For example, the following revision works:

import torch
from torch_geometric.data import Data
from torch_geometric.datasets import FakeDataset
from torch_geometric.nn import GeneralConv
from torch_geometric.transforms import ToSparseTensor


def mock_transform(data: Data) -> Data:
    # Add a SparseTensor adjacency (data.adj_t) while keeping edge_index.
    data = ToSparseTensor(remove_edge_index=False)(data)
    # Placeholder edge features of shape [num_edges, 1]; since `high` is
    # exclusive, randint_like(..., low=0, high=1) yields all zeros.
    data.edge_attr = torch.randint_like(
        torch.zeros(size=(data.edge_index.size(1), 1)), low=0, high=1
    )
    return data


gnn = GeneralConv(in_channels=100, out_channels=100, in_edge_channels=0)

dataset = FakeDataset(
    num_graphs=32 * 4,  # 4 batches of 32
    avg_num_nodes=20,
    num_channels=100,
    num_classes=2,
    edge_dim=1,
    is_undirected=False,
    transform=mock_transform,
)

gnn(dataset[0].x, dataset[0].adj_t, edge_attr=dataset[0].edge_attr)

Setting in_edge_channels to 0 or 1 gives a passing test. Removing edge_attr=dataset[0].edge_attr while in_edge_channels is 0 or None raises the assertion.

This seems like a bug. edge_attr should be allowed to be None. I also don't like that I have to pass an edge_attr when the layer shouldn't use it (i.e., in_edge_channels=0).
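The behaviour asked for here, edge features that are simply skipped when absent, is a common pattern in message-passing layers. Below is a minimal plain-PyTorch sketch of it. TinyEdgeOptionalConv is a hypothetical layer that only borrows the parameter names in_edge_channels and edge_attr from this thread; it uses sum aggregation over edges flowing from edge_index[0] to edge_index[1] and is not GeneralConv's implementation.

```python
from typing import Optional

import torch
from torch import Tensor


class TinyEdgeOptionalConv(torch.nn.Module):
    """Hypothetical sum-aggregation layer with an optional edge term."""

    def __init__(self, in_channels: int, out_channels: int,
                 in_edge_channels: Optional[int] = None):
        super().__init__()
        self.lin_msg = torch.nn.Linear(in_channels, out_channels)
        # Only build an edge projection when edge features are expected.
        self.lin_edge = (torch.nn.Linear(in_edge_channels, out_channels)
                         if in_edge_channels is not None else None)

    def forward(self, x: Tensor, edge_index: Tensor,
                edge_attr: Optional[Tensor] = None) -> Tensor:
        src, dst = edge_index[0], edge_index[1]  # messages flow src -> dst
        msg = self.lin_msg(x[src])  # [num_edges, out_channels]
        if edge_attr is not None:
            if self.lin_edge is None:
                raise ValueError("edge_attr was given but in_edge_channels is None")
            msg = msg + self.lin_edge(edge_attr)
        # Sum incoming messages per target node; nodes without in-edges stay 0.
        out = x.new_zeros(x.size(0), msg.size(-1))
        return out.index_add(0, dst, msg)


if __name__ == "__main__":
    x = torch.randn(4, 3)
    edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
    print(TinyEdgeOptionalConv(3, 8)(x, edge_index).shape)  # no edge_attr
```

Here a missing edge_attr is a normal input rather than an error; an error is raised only for the inconsistent case where edge features are passed to a layer that was built without an edge projection.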

rusty1s commented 1 month ago

I tested this here and it runs fine for us: https://github.com/pyg-team/pytorch_geometric/pull/9607. Not sure what causes this on your end.