cai4cai / torchsparsegradutils

A collection of utility functions to work with PyTorch sparse tensors
Apache License 2.0

Consider adding a pure PyTorch function for `segment_mm` and `gather_mm` #56

Closed by tvercaut 1 month ago

tvercaut commented 1 month ago

segment_mm and gather_mm are borderline in scope for this repo, but they would still be useful additions. DGL provides nice implementations, but installing DGL is non-trivial:
https://docs.dgl.ai/generated/dgl.ops.gather_mm.html
https://docs.dgl.ai/generated/dgl.ops.segment_mm.html
pyg-lib offers a related operation:
https://pyg-lib.readthedocs.io/en/latest/modules/ops.html#pyg_lib.ops.segment_matmul
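For context, the semantics of the two operations (per the DGL docs above) can be written as naive dense PyTorch baselines. `naive_gather_mm` and `naive_segment_mm` are hypothetical names; these versions are memory- or speed-inefficient, but they make the semantics explicit and could double as test oracles:

```python
import torch

def naive_gather_mm(a, b, idx_b):
    # out[i] = a[i] @ b[idx_b[i]]; note b[idx_b] materialises one
    # (D1, D2) matrix per row of a, so this is simple but memory-hungry.
    return torch.einsum('nd,nde->ne', a, b[idx_b])

def naive_segment_mm(a, b, seglen_a):
    # Rows of a are split into len(seglen_a) consecutive segments;
    # segment i (of length seglen_a[i]) is multiplied by b[i].
    segments = torch.split(a, seglen_a.tolist(), dim=0)
    return torch.cat([seg @ b[i] for i, seg in enumerate(segments)], dim=0)
```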

A workaround for gather_mm that would be a good starting point for inclusion in torchsparsegradutils (it just needs a unit test) can be found here: https://github.com/pytorch/pytorch/issues/136747
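Such a unit test could check the workaround against a naive per-row loop on random inputs. A self-contained sketch (the nested-tensor implementation from the linked issue is inlined here so the snippet runs standalone):

```python
import torch

def gather_mm(a, b, idx_b):
    # Nested-tensor workaround from the linked pytorch issue, inlined.
    R, D1, D2 = b.shape
    N = idx_b.shape[0]
    src_idx = torch.arange(N, device=a.device)
    nested_a = torch.nested.as_nested_tensor([a[idx_b == i, :] for i in range(R)])
    src_idx_reshuffled = torch.cat([src_idx[idx_b == i] for i in range(R)])
    nested_b = torch.nested.as_nested_tensor([b[i] for i in range(R)])
    ab_segmented = torch.cat(torch.matmul(nested_a, nested_b).unbind(), dim=0)
    ab = torch.empty((N, D2), dtype=a.dtype, device=a.device)
    ab[src_idx_reshuffled] = ab_segmented
    return ab

def test_gather_mm():
    torch.manual_seed(0)
    N, R, D1, D2 = 20, 4, 3, 5
    a = torch.randn(N, D1)
    b = torch.randn(R, D1, D2)
    idx_b = torch.arange(N) % R  # every matrix of b is used at least once
    # Expected result from the per-row definition: out[i] = a[i] @ b[idx_b[i]]
    expected = torch.stack([a[i] @ b[idx_b[i]] for i in range(N)])
    assert torch.allclose(gather_mm(a, b, idx_b), expected, atol=1e-5)

test_gather_mm()
```

A fuller test would also cover edge cases such as a single row per segment and non-contiguous inputs.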

tvercaut commented 1 month ago

For the record, the initial workaround

import torch

def my_gather_mm(a, b, idx_b):
  # mimic https://docs.dgl.ai/generated/dgl.ops.gather_mm.html
  R, D1, D2 = b.shape
  N = idx_b.shape[0]

  # Sanity check sizes
  assert a.shape == (N, D1)

  torchdevice = a.device
  src_idx = torch.arange(N, device=torchdevice)

  # Ideally the conversions below to nested tensor would be handled without for loops and without copy.
  # Note: squeeze(-1) rather than squeeze() so that a segment with a single row keeps a 1-d index.
  nested_a = torch.nested.as_nested_tensor(
      [torch.index_select(a, dim=0, index=torch.nonzero(idx_b == i).squeeze(-1)) for i in range(R)])
  src_idx_reshuffled = torch.cat(
      [torch.index_select(src_idx, dim=0, index=torch.nonzero(idx_b == i).squeeze(-1)) for i in range(R)])
  nested_b = torch.nested.as_nested_tensor(
      [b[i, :, :] for i in range(R)])

  # The actual gather matmul computation
  nested_ab = torch.matmul(nested_a, nested_b)

  # Convert back to tensors; again, ideally this would be handled natively with no copy
  ab_segmented = torch.cat(nested_ab.unbind(), dim=0)
  ab = torch.empty((N, D2), dtype=a.dtype, device=torchdevice)
  ab[src_idx_reshuffled] = ab_segmented
  return ab

can be simplified a bit

import torch

def my_gather_mm(a, b, idx_b):
  # mimic https://docs.dgl.ai/generated/dgl.ops.gather_mm.html
  R, D1, D2 = b.shape
  N = idx_b.shape[0]

  # Sanity check sizes
  assert a.shape == (N, D1)

  torchdevice = a.device
  src_idx = torch.arange(N, device=torchdevice)

  # Ideally the conversions below to nested tensor would be handled without for loops and without copy
  nested_a = torch.nested.as_nested_tensor([a[idx_b == i, :] for i in range(R)])
  src_idx_reshuffled = torch.cat([src_idx[idx_b == i] for i in range(R)])
  nested_b = torch.nested.as_nested_tensor([b[i, :, :] for i in range(R)])

  # The actual gather matmul computation
  nested_ab = torch.matmul(nested_a, nested_b)

  # Convert back to tensors; again, ideally this would be handled natively with no copy
  ab_segmented = torch.cat(nested_ab.unbind(), dim=0)
  ab = torch.empty((N, D2), dtype=a.dtype, device=torchdevice)
  ab[src_idx_reshuffled] = ab_segmented
  return ab
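The same nested-tensor approach would extend to segment_mm. A sketch, assuming DGL's segment_mm(a, b, seglen_a) semantics (segment i of a, of length seglen_a[i], is multiplied by b[i]); my_segment_mm is a hypothetical name, and unlike gather_mm no index reshuffling is needed since segments are already contiguous:

```python
import torch

def my_segment_mm(a, b, seglen_a):
    # mimic https://docs.dgl.ai/generated/dgl.ops.segment_mm.html
    R, D1, D2 = b.shape

    # Sanity check sizes
    assert seglen_a.shape[0] == R and int(seglen_a.sum()) == a.shape[0]

    # Segments of a are already contiguous, so torch.split suffices;
    # ideally this nested-tensor conversion would also be copy-free.
    nested_a = torch.nested.as_nested_tensor(
        list(torch.split(a, seglen_a.tolist(), dim=0)))
    nested_b = torch.nested.as_nested_tensor([b[i] for i in range(R)])

    # Per-segment matmul, then concatenate back into an (N, D2) tensor
    nested_ab = torch.matmul(nested_a, nested_b)
    return torch.cat(nested_ab.unbind(), dim=0)
```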