`DilatedKnnGraph` silently accepts a badly shaped `batch` argument in "matrix" mode, because `batch` is only used to compute `batch_size`. Maybe an assertion or error should be raised instead, as in PyTorch Geometric's `knn_graph`.
```python
import torch
from torch_geometric.nn import knn_graph
from gcn_lib.sparse.torch_edge import DilatedKnnGraph

p = torch.rand((256, 3))
t = torch.cat([p, p])  # two copies of the same 256-point cloud (512 points)
batch = torch.cat([torch.ones(p.shape[0]) * i for i in range(2)]).type(torch.long)  # normal batch
batch2 = torch.tensor([0, 0, 1])  # wrong length: 3 labels for 512 points
batch3 = torch.tensor([0, 1, 0])  # wrong length, and the last label is not the largest
dknn = DilatedKnnGraph(k=3, dilation=1)

f0 = knn_graph(t, k=3, batch=batch, loop=True)
f1 = dknn(t, batch=batch2)  # Maybe this should raise a shape error?
torch.all(f1 == f0)  # True
f2 = dknn(t, batch=batch3)  # Wrong result: batch_size is computed from batch[-1], which is 0 here
torch.all(f2 == f0)  # False
```
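A guard of the kind suggested above could look like the sketch below. It only needs PyTorch; `check_batch_shape` is a hypothetical helper, not part of deep_gcns_torch or PyTorch Geometric, and the checks PyG actually performs may differ.

```python
import torch


def check_batch_shape(x: torch.Tensor, batch: torch.Tensor) -> None:
    """Hypothetical guard: fail loudly if `batch` does not label every point of `x`."""
    if batch.dim() != 1:
        raise ValueError(f"batch must be 1-D, got shape {tuple(batch.shape)}")
    if batch.numel() != x.size(0):
        raise ValueError(
            f"batch has {batch.numel()} entries but x has {x.size(0)} points"
        )


# Same shapes as in the report: two copies of a 256-point cloud.
p = torch.rand((256, 3))
t = torch.cat([p, p])
batch = torch.arange(2).repeat_interleave(p.shape[0])  # 256 zeros, then 256 ones
check_batch_shape(t, batch)  # passes

try:
    check_batch_shape(t, torch.tensor([0, 0, 1]))  # batch2 from the report
except ValueError as err:
    print(err)  # batch has 3 entries but x has 512 points
```

With a check like this at the top of `forward`, both `batch2` and `batch3` would be rejected instead of producing a result.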
Possible future issue (though we can probably treat this as user misuse): if the user passes the batch labels out of order (for example `[0, 1, 0, 1]` instead of `[0, 0, 1, 1]`), I'm not sure the reshape step will work.
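Assuming the "matrix" path reshapes `x` into `(batch_size, -1, dims)` by position, an interleaved batch would mix points from different clouds. The toy sketch below shows this with plain PyTorch, plus two possible mitigations: a sortedness check (`is_sorted` is a hypothetical helper) or sorting the points by cloud first with a stable sort.

```python
import torch

# Four points with 1-D "coordinates" so the grouping is easy to read.
x = torch.tensor([[0.0], [10.0], [1.0], [11.0]])
batch = torch.tensor([0, 1, 0, 1])  # interleaved: cloud 0 = rows 0, 2; cloud 1 = rows 1, 3

# A reshape like x.view(batch_size, -1, dims) groups rows by position, not by label.
batch_size = int(batch.max()) + 1
grouped = x.view(batch_size, -1, x.shape[-1])
print(grouped.squeeze(-1).tolist())  # [[0.0, 10.0], [1.0, 11.0]]  (clouds are mixed)


def is_sorted(batch: torch.Tensor) -> bool:
    """Hypothetical guard: True if batch labels are non-decreasing."""
    return bool(torch.all(batch[1:] >= batch[:-1]))


# Alternative: sort points by cloud before reshaping. An edge (i, j) computed in
# sorted order corresponds to (perm[i], perm[j]) in the original point order.
sorted_batch, perm = torch.sort(batch, stable=True)
regrouped = x[perm].view(batch_size, -1, x.shape[-1])
print(regrouped.squeeze(-1).tolist())  # [[0.0, 1.0], [10.0, 11.0]]
```

Rejecting unsorted input is simpler; sorting is friendlier to callers but requires mapping the edge indices back through `perm`.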
Also, at https://github.com/lightaime/deep_gcns_torch/blob/751382aa2d25e25a2792c133cc99f8cfddae0657/gcn_lib/sparse/torch_edge.py#L78, maybe `batch_size` should be calculated with `.max()` (i.e. `batch.max() + 1`) rather than from the last element, `batch[-1]`.
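The difference between the two ways of counting clouds is easy to see on `batch3` from the report. The sketch below uses plain PyTorch; `clouds_per_batch` is a hypothetical helper that also checks, under the assumption that the "matrix" path needs a dense `(batch_size, N, dims)` reshape, that every cloud has the same number of points.

```python
import torch


def clouds_per_batch(batch: torch.Tensor) -> int:
    """Hypothetical helper: count clouds from the largest label and check sizes."""
    batch_size = int(batch.max()) + 1
    counts = torch.bincount(batch, minlength=batch_size)
    # A dense (batch_size, N, dims) reshape only makes sense if all clouds have N points.
    if not torch.all(counts == counts[0]):
        raise ValueError(f"unequal cloud sizes {counts.tolist()}")
    return batch_size


batch3 = torch.tensor([0, 1, 0])  # from the report

size_from_last = int(batch3[-1]) + 1  # last label: gives 1
size_from_max = int(batch3.max()) + 1  # largest label: gives 2
print(size_from_last, size_from_max)  # 1 2
```

Using `.max()` alone gives the right count for `batch3`, but it does not make an unsorted batch safe to reshape, so it would likely need to be paired with a sortedness or shape check.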