Closed: barkincavdaroglu closed this 1 year ago.
Closes #33
# Pack each node's outgoing-neighbour features into a PackedSequence, with the
# neighbours of every node ordered by descending edge feature.
#
# NOTE(review): assumes `edges` is a (2, E) index tensor (row 0 = source,
# row 1 = target) and `edge_fts` holds one floating-point scalar per edge;
# the argsort/broadcast below only works for 1-D `edge_fts`. TODO confirm.
# NOTE(review): to_dense_batch requires its batch vector (here edges[0]) to be
# sorted in ascending order; presumably edges are pre-sorted by source node —
# verify against the caller.
num_nodes = node_fts.size(0)

# Out-degree of every node. dim_size must be the number of nodes (previously
# node_fts[0], i.e. the first feature row, was passed by mistake). The ones
# tensor is created on edges' device so this also works for GPU inputs.
lengths = torch_scatter.scatter_add(
    torch.ones(edges.shape[1], device=edges.device),
    edges[0],
    dim=0,
    dim_size=num_nodes,
)
max_neighborhood_size = int(lengths.max().item())

# Dense (num_nodes, max_neighborhood_size, ...) layouts of neighbour node
# features and edge features. batch_size=num_nodes keeps the row count aligned
# with `lengths` instead of letting to_dense_batch infer it from edges[0].max().
neighbs, _ = to_dense_batch(
    node_fts[edges[1]], edges[0], 0, max_neighborhood_size, batch_size=num_nodes
)
edg, edg_mask = to_dense_batch(
    edge_fts, edges[0], 0, max_neighborhood_size, batch_size=num_nodes
)

# Sort neighbours by descending edge feature. Padding slots are forced to -inf
# in the sort key so they always land after the real neighbours; with the raw
# 0 fill, negative edge features would rank behind padding and then be cut off
# by pack_padded_sequence (which keeps only the first `lengths[i]` slots).
sort_key = edg.masked_fill(~edg_mask, float("-inf"))
sorted_edg = torch.argsort(sort_key, dim=1, descending=True)
# Broadcast the (num_nodes, max_neighborhood_size) index over all trailing
# feature dims of `neighbs` rather than hard-coding them (was expand(-1, -1, 2, 12)).
sorted_edg = sorted_edg.reshape(
    sorted_edg.shape + (1,) * (neighbs.dim() - 2)
).expand_as(neighbs)
res = torch.gather(neighbs, 1, sorted_edg)

# NOTE(review): pack_padded_sequence rejects zero lengths, so this raises if
# any node has no outgoing edges — confirm that cannot happen upstream.
packed = nn.utils.rnn.pack_padded_sequence(
    res, lengths.to("cpu"), batch_first=True, enforce_sorted=False
)
Closes #33