Open bgeier opened 2 months ago
I can get my test to pass if I pass an arbitrary edge_attr tensor to the layer and set in_edge_channels=0. For example, the following revision works:
import torch
from torch_geometric.data import Data
from torch_geometric.datasets import FakeDataset
from torch_geometric.nn import GeneralConv
from torch_geometric.transforms import ToSparseTensor


def mock_transform(data: Data) -> Data:
    # Add the sparse adjacency (adj_t) while keeping edge_index on the object.
    data = ToSparseTensor(remove_edge_index=False)(data)
    # Placeholder edge features: with low=0 and high=1 every entry is 0.
    data.edge_attr = torch.randint_like(
        torch.zeros(size=(data.edge_index.size()[1], 1)), low=0, high=1
    )
    return data


gnn = GeneralConv(in_channels=100, out_channels=100, in_edge_channels=0)
dataset = FakeDataset(
    num_graphs=32 * 4,  # 4 batches of 32
    avg_num_nodes=20,
    num_channels=100,
    num_classes=2,
    edge_dim=1,
    is_undirected=False,
    transform=mock_transform,
)
# Forward pass with the placeholder edge_attr passed explicitly.
gnn(dataset[0].x, dataset[0].adj_t, edge_attr=dataset[0].edge_attr)
Setting in_edge_channels to 0 or 1 gives a passing test. Removing edge_attr=dataset[0].edge_attr, with in_edge_channels set to 0 or None, raises the assertion.
This seems like a bug. edge_attr should be allowed to be None. I also don't like that I have to pass an edge_attr when the layer shouldn't use it (i.e., in_edge_channels=0).
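The workaround described above amounts to supplying an all-zero edge feature tensor with one row per edge. A minimal sketch of that idea, using only torch, is shown here; `placeholder_edge_attr` is a hypothetical helper name and is not part of torch_geometric.

```python
import torch
from torch import Tensor


def placeholder_edge_attr(edge_index: Tensor, dim: int = 1) -> Tensor:
    """Return an all-zero edge feature tensor with one row per edge.

    Hypothetical helper: `edge_index` is assumed to have shape (2, num_edges).
    """
    num_edges = edge_index.size(1)
    return torch.zeros((num_edges, dim), dtype=torch.float, device=edge_index.device)


# Toy graph with three directed edges: 0->1, 1->2, 2->0.
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
edge_attr = placeholder_edge_attr(edge_index)
print(edge_attr.shape)
```

Inside a transform such as mock_transform, this would replace the randint_like call; whether the layer should need such a placeholder at all is the open question in this issue.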
I tested this here and it runs fine for us: https://github.com/pyg-team/pytorch_geometric/pull/9607. Not sure what causes this on your end.
🐛 Describe the bug
The GeneralConv layer is raising an assertion when only a node array (x) and adjacency (edge_index) are provided. I expect the layer to return a result when I don't provide edge_attr (default is None).

Context: I'm writing unit tests for a model that is composed of many layers. I've worked back to a core torch_geometric layer that is raising an assert. I get the assertion when providing a basic data input. For example,
The forward pass to gnn is raising an AssertionError via assert edge_attr is not None. I'm having trouble locating the file asserting the error. Is this a legit bug or user error? Any help would be much appreciated! Thanks!

Here's the traceback from the above test.
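For locating the file that raises an assertion, one generic option is to inspect the innermost frame of the exception's traceback with the standard library. The sketch below is not specific to torch_geometric: `failing_call` is a hypothetical stand-in for the layer's forward pass, and `raising_frame` is an illustrative helper name.

```python
import traceback


def raising_frame(exc: BaseException) -> traceback.FrameSummary:
    """Return the innermost traceback frame, i.e. where `exc` was raised."""
    frames = traceback.extract_tb(exc.__traceback__)
    return frames[-1]


def failing_call(edge_attr=None):
    # Hypothetical stand-in for the forward pass that asserts on edge_attr.
    assert edge_attr is not None


try:
    failing_call()
except AssertionError as exc:
    frame = raising_frame(exc)
    # filename and lineno point at the source file and line of the assert.
    print(f"{frame.filename}:{frame.lineno} in {frame.name}")
```

Wrapping the real gnn(...) call in the same try/except should print the path of the module that contains the failing assert.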
Versions
Collecting environment information...
PyTorch version: 2.3.1
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (aarch64)
GCC version: Could not collect
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.35

Python version: 3.11.9 (main, Apr 6 2024, 17:59:24) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-6.6.32-linuxkit-aarch64-with-glibc2.35
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: aarch64
CPU op-mode(s): 64-bit
Byte Order: Little Endian
CPU(s): 14
On-line CPU(s) list: 0-13
Vendor ID: Apple
Model: 0
Thread(s) per core: 1
Core(s) per cluster: 14
Socket(s): -
Cluster(s): 1
Stepping: 0x0
BogoMIPS: 48.00
Flags: fp asimd evtstrm aes pmull sha1 sha2 crc32 atomics fphp asimdhp cpuid asimdrdm jscvt fcma lrcpc dcpop sha3 asimddp sha512 asimdfhm dit uscat ilrcpc flagm ssbs sb paca pacg dcpodp flagm2 frint
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; __user pointer sanitization
Vulnerability Spectre v2: Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] torch==2.3.1
[pip3] torch_geometric==2.5.3
[pip3] torchmetrics==1.4.0.post0
[conda] Could not collect