lightaime / deep_gcns_torch

Pytorch Repo for DeepGCNs (ICCV'2019 Oral, TPAMI'2021), DeeperGCN (arXiv'2020) and GNN1000(ICML'2021): https://www.deepgcns.org

DilatedKnnGraph batch size in "matrix" mode #95

Open zademn opened 2 years ago

zademn commented 2 years ago

DilatedKnnGraph silently accepts a badly shaped batch parameter when using "matrix" mode, because only the batch size is derived from batch and its shape is never checked. Maybe an assertion / error should be raised, as in PyTorch Geometric's knn_graph (see the shape-check sketch after the snippet below).

import torch
from torch_geometric.nn import knn_graph
from gcn_lib.sparse.torch_edge import DilatedKnnGraph

p = torch.rand((256, 3))
t = torch.cat([p, p])
batch = torch.cat([torch.ones(p.shape[0]) * i for i in range(2)]).type(torch.long) # correctly shaped batch: 256 zeros, then 256 ones
batch2 = torch.tensor([0, 0, 1])
batch3 = torch.tensor([0, 1, 0])
dknn = DilatedKnnGraph(k = 3, dilation = 1)

f0 = knn_graph(t, k = 3, batch = batch, loop = True) 

f1 = dknn(t, batch = batch2) # Maybe this should raise a shape error?
torch.all(f1 == f0) # True

f2 = dknn(t, batch = batch3) # Weird behaviour since [-1] is used to compute batch_size
torch.all(f2 == f0) # False 
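
For illustration, a minimal sketch of the kind of shape check I have in mind (the validate_batch name is just an example, not code from this repo):

import torch

def validate_batch(x: torch.Tensor, batch: torch.Tensor) -> None:
    # Hypothetical check: every node in x must have exactly one batch index.
    if batch.numel() != x.shape[0]:
        raise ValueError(
            f"batch has {batch.numel()} entries but x has {x.shape[0]} nodes")

With something like this at the top of forward, the batch2 / batch3 calls above would fail loudly instead of silently reusing only the batch size.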

Also, at https://github.com/lightaime/deep_gcns_torch/blob/751382aa2d25e25a2792c133cc99f8cfddae0657/gcn_lib/sparse/torch_edge.py#L78 the batch size is taken from the last element of batch. Maybe batch_size should be calculated from batch.max() instead, so it does not depend on batch being sorted.
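
Roughly what I mean (a sketch, not the repo's exact code; it assumes batch indices start at 0):

import torch

def infer_batch_size(batch: torch.Tensor) -> int:
    # Number of graphs = largest index + 1; works even if batch is not
    # sorted, unlike reading batch[-1].
    return int(batch.max().item()) + 1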

A possible future issue (although this could arguably be considered misuse by the user): if the batch parameter is not sorted, e.g. [0, 1, 0, 1] instead of [0, 0, 1, 1], I'm not sure the reshape step will work (see the sketch below).
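
A possible way to guard against that (again just a sketch with made-up names):

import torch

def ensure_sorted_batch(x: torch.Tensor, batch: torch.Tensor):
    # If the batch indices are not non-decreasing, sort the nodes so each
    # point cloud is contiguous before reshaping to
    # (batch_size, num_points, num_dims).
    if not bool(torch.all(batch[:-1] <= batch[1:])):
        perm = torch.argsort(batch)
        x, batch = x[perm], batch[perm]
    return x, batch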