barkincavdaroglu / Link-Prediction-Mesh-Network

PyTorch Implementation of a Deep Learning Model for Temporal Link Prediction in MANETs

Vectorized neighborhood padding for sequential aggregation #35

Closed: barkincavdaroglu closed this 1 year ago

barkincavdaroglu commented 1 year ago

Closes #33

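```python
import torch
import torch.nn as nn
import torch_scatter
from torch_geometric.utils import to_dense_batch

# Shapes inferred from expand(-1, -1, 2, 12): node_fts [N, 2, 12],
# edges [2, E] sorted by source node edges[0], edge_fts [E] (one weight per edge).
num_nodes = node_fts.shape[0]  # node_fts[0] is a feature row, not a size

# Neighborhood size (out-degree) of every node, as int64 for packing.
lengths = torch_scatter.scatter_add(
    torch.ones(edges.shape[1], device=edges.device),
    edges[0],
    dim=0,
    dim_size=num_nodes,
).long()
max_neighborhood_size = int(lengths.max().item())

# Dense [N, max_neighborhood_size, 2, 12] neighbor features, zero-padded.
# batch_size keeps trailing nodes that have no outgoing edges.
neighbs, _ = to_dense_batch(
    node_fts[edges[1]], edges[0], fill_value=0.0,
    max_num_nodes=max_neighborhood_size, batch_size=num_nodes,
)

# Pad edge weights with -inf so padding always sorts behind real neighbors.
edg, _ = to_dense_batch(
    edge_fts, edges[0], fill_value=float("-inf"),
    max_num_nodes=max_neighborhood_size, batch_size=num_nodes,
)

# Reorder each neighborhood by descending edge weight; broadcast the index
# over the trailing feature dims instead of hard-coding (2, 12).
order = torch.argsort(edg, dim=1, descending=True)
order = order.reshape(order.shape + (1,) * (neighbs.dim() - 2)).expand_as(neighbs)
res = torch.gather(neighbs, 1, order)

# pack_padded_sequence rejects zero-length sequences, so isolated nodes
# get a single all-zero step; lengths must be an int64 CPU tensor.
packed = nn.utils.rnn.pack_padded_sequence(
    res, lengths.clamp(min=1).cpu(), batch_first=True, enforce_sorted=False
)
```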