Closed tvercaut closed 1 month ago
For the record, the initial workaround:
def my_gather_mm(a, b, idx_b):
    """Gather-matmul workaround built on nested tensors.

    Mimics https://docs.dgl.ai/generated/dgl.ops.gather_mm.html: row ``j`` of
    the result is ``a[j] @ b[idx_b[j]]``.

    Args:
        a: Tensor of shape ``(N, D1)``.
        b: Tensor of shape ``(R, D1, D2)`` holding one matrix per relation.
        idx_b: 1-D integer tensor of length ``N``; ``idx_b[j]`` selects the
            matrix of ``b`` applied to row ``j`` of ``a``.

    Returns:
        Tensor of shape ``(N, D2)`` with the dtype and device of the matmul
        result. Rows whose ``idx_b`` entry is outside ``range(R)`` are never
        written and therefore hold uninitialized values.

    Raises:
        AssertionError: If the shape of ``a`` is not ``(N, D1)``.
    """
    R, D1, D2 = b.shape
    N = idx_b.shape[0]
    # Sanity check sizes
    assert a.shape[0] == N and a.shape[1] == D1

    # Positions of the rows of `a` that use relation i, for every i.
    # Squeeze only dim 1: a bare squeeze() would turn a single match into a
    # 0-d index instead of a length-1 vector.
    rows_per_relation = [
        torch.nonzero(idx_b == i).squeeze(1) for i in range(R)
    ]

    # Ideally the conversions below to nested tensor would be handled
    # without for loops and without copy.
    nested_a = torch.nested.as_nested_tensor(
        [torch.index_select(a, dim=0, index=rows) for rows in rows_per_relation]
    )
    # Original row position of every row in the relation-grouped order.
    src_idx_reshuffled = torch.cat(rows_per_relation)
    # b[i] is already (D1, D2); do not squeeze it, otherwise D1 == 1 or
    # D2 == 1 would drop a matrix dimension and break the matmul.
    nested_b = torch.nested.as_nested_tensor([b[i] for i in range(R)])

    # The actual gather matmul computation
    nested_ab = torch.matmul(nested_a, nested_b)

    # Convert back to tensors; again, ideally this would be handled natively
    # with no copy.
    ab_segmented = torch.cat(nested_ab.unbind(), dim=0)
    # Allocate with the dtype/device of the product so the scatter below
    # also works for non-float32 inputs.
    ab = ab_segmented.new_empty((N, D2))
    ab[src_idx_reshuffled] = ab_segmented
    return ab
can be simplified a bit:
def my_gather_mm(a, b, idx_b):
    """Gather-matmul workaround built on nested tensors (mask-indexing form).

    Mimics https://docs.dgl.ai/generated/dgl.ops.gather_mm.html: row ``j`` of
    the result is ``a[j] @ b[idx_b[j]]``.

    Args:
        a: Tensor of shape ``(N, D1)``.
        b: Tensor of shape ``(R, D1, D2)`` holding one matrix per relation.
        idx_b: 1-D integer tensor of length ``N``; ``idx_b[j]`` selects the
            matrix of ``b`` applied to row ``j`` of ``a``.

    Returns:
        Tensor of shape ``(N, D2)`` with the dtype and device of the matmul
        result. Rows whose ``idx_b`` entry is outside ``range(R)`` are never
        written and therefore hold uninitialized values.

    Raises:
        AssertionError: If the shape of ``a`` is not ``(N, D1)``.
    """
    R, D1, D2 = b.shape
    N = idx_b.shape[0]
    # Sanity check sizes
    assert a.shape[0] == N and a.shape[1] == D1
    torchdevice = a.device
    src_idx = torch.arange(N, device=torchdevice)

    # One boolean row mask per relation, computed once and reused below.
    relation_masks = [idx_b == i for i in range(R)]

    # Ideally the conversions below to nested tensor would be handled
    # without for loops and without copy.
    nested_a = torch.nested.as_nested_tensor(
        [a[mask, :] for mask in relation_masks]
    )
    # Original row position of every row in the relation-grouped order.
    src_idx_reshuffled = torch.cat([src_idx[mask] for mask in relation_masks])
    # b[i] is already (D1, D2); do not squeeze it, otherwise D1 == 1 or
    # D2 == 1 would drop a matrix dimension and break the matmul.
    nested_b = torch.nested.as_nested_tensor([b[i] for i in range(R)])

    # The actual gather matmul computation
    nested_ab = torch.matmul(nested_a, nested_b)

    # Convert back to tensors; again, ideally this would be handled natively
    # with no copy.
    ab_segmented = torch.cat(nested_ab.unbind(), dim=0)
    # Allocate with the dtype/device of the product so the scatter below
    # also works for non-float32 inputs.
    ab = ab_segmented.new_empty((N, D2))
    ab[src_idx_reshuffled] = ab_segmented
    return ab
`segment_mm` and `gather_mm` are borderline in scope for this repo, but they would still be useful additions. DGL provides some nice functions, but installing it is non-trivial:
https://docs.dgl.ai/generated/dgl.ops.gather_mm.html
https://docs.dgl.ai/generated/dgl.ops.segment_mm.html
https://pyg-lib.readthedocs.io/en/latest/modules/ops.html#pyg_lib.ops.segment_matmul
A workaround for `gather_mm` is found here and would be a good starting point for inclusion in torchsparsegradutils (it just needs a unit test): https://github.com/pytorch/pytorch/issues/136747