barkincavdaroglu / Link-Prediction-Mesh-Network

PyTorch Implementation of a Deep Learning Model for Temporal Link Prediction in MANETs

Vectorize Neighborhood Padding for LSTM Aggregation #33

Closed: barkincavdaroglu closed this issue 1 year ago

barkincavdaroglu commented 1 year ago

Currently, we construct each node's padded neighborhood, sorted by edge weight, to pass into the LSTM as follows:

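import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence

# Inside the model's forward(); self.rnn is an nn.LSTM built with batch_first=True.
# Assumes dst_node_fts holds one row per edge (the destination node's features),
# aligned with edges and edge_fts.
seqs, lengths = [], []
for node_idx in range(node_fts.shape[0]):
    out_edges = edges[0] == node_idx  # edges whose source is node_idx
    # argsort indexes into the *filtered* edges, so filter dst_node_fts the same way
    order = torch.argsort(edge_fts[out_edges], descending=True)  # heaviest edge first
    seqs.append(dst_node_fts[out_edges][order])
    lengths.append(len(seqs[-1]))
lengths = torch.tensor(lengths)
seqs = pad_sequence(seqs, batch_first=True)  # (num_nodes, max_degree, feat_dim)
packed = nn.utils.rnn.pack_padded_sequence(
    seqs, lengths.to("cpu"), batch_first=True, enforce_sorted=False
)
_, (dst_node_fts_neigh_agg_final, _) = self.rnn(packed)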

This is very inefficient because of the Python for-loop: each graph has 207 nodes and our batch size is usually 64 or 128, so the loop runs 207 × 64 = 13,248 or 207 × 128 = 26,496 times per batch. The padded neighborhoods can instead be formed with:

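from torch_geometric.utils import to_dense_batch

# One row per edge (destination node features), grouped by source node.
# to_dense_batch expects its batch vector (here edge_index[0]) to be sorted.
neighbs, mask = to_dense_batch(batch.x[batch.edge_index[1]], batch.edge_index[0], 0, max_neigh_size)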
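For reference, the padding step itself can be sketched in plain PyTorch without torch_geometric. This is a minimal sketch under the assumption that edge_index[0] is sorted, so each node's outgoing edges are contiguous (to_dense_batch likewise expects an ordered batch vector). dense_neighborhoods is a hypothetical helper name; it takes num_nodes explicitly so that every node gets a row, even one with no outgoing edges.

```python
import torch


def dense_neighborhoods(x, edge_index, num_nodes, max_neigh):
    """Hypothetical helper: pad each node's outgoing neighborhood into
    a (num_nodes, max_neigh, *feat_shape) tensor plus a boolean mask.

    Assumes edge_index[0] is sorted (edges grouped by source node).
    """
    src, dst = edge_index
    feats = x[dst]  # one row per edge: the destination node's features
    counts = torch.bincount(src, minlength=num_nodes)  # out-degree per node
    starts = torch.cumsum(counts, 0) - counts  # index of each node's first edge
    slot = torch.arange(src.numel(), device=src.device) - starts[src]  # position within neighborhood
    keep = slot < max_neigh  # neighbors beyond max_neigh are dropped
    out = feats.new_zeros((num_nodes, max_neigh, *feats.shape[1:]))
    mask = torch.zeros(num_nodes, max_neigh, dtype=torch.bool, device=src.device)
    out[src[keep], slot[keep]] = feats[keep]
    mask[src[keep], slot[keep]] = True
    return out, mask


# Toy graph: node 0 -> {1, 2}, node 1 -> {0}, node 2 -> {0, 1, 3}, node 3 has no out-edges.
x = torch.arange(4.0).unsqueeze(1) * 10
edge_index = torch.tensor([[0, 0, 1, 2, 2, 2], [1, 2, 0, 0, 1, 3]])
out, mask = dense_neighborhoods(x, edge_index, num_nodes=4, max_neigh=2)
```

With a PyG batch this would be called as dense_neighborhoods(batch.x, batch.edge_index, batch.num_nodes, max_neigh_size). Like to_dense_batch, it keeps edges in their original order and does not sort them by weight.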

However, to_dense_batch does not sort neighbors by edge weight. A quick experiment that pads the edge weights the same way and reorders with argsort:

neighbs, mask = to_dense_batch(a.x[a.edge_index[1]], a.edge_index[0], 0, 18)  # (num_nodes, 18, 2, 12)
edg, m = to_dense_batch(a.edge_attr, a.edge_index[0], 0, 18)  # (num_nodes, 18)
sorted_edg = torch.argsort(edg, dim=1, descending=True)
# broadcast the per-slot order over the (2, 12) feature dims
sorted_edg = sorted_edg.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 2, 12)
res = torch.zeros_like(neighbs).scatter_(1, sorted_edg, neighbs)

However, there is a bug in this approach that I cannot figure out. For a single node, here are neighbs, its edge weights (edg), the argsort indices, and res:

tensor([[[ 0.4729,  0.3832,  0.3155,  0.1908,  0.3383,  0.3440,  0.3319,  0.4024,  0.1138,  0.0077,  0.2421,  0.1138],
         [ 0.0521,  0.0556,  0.0590,  0.0625,  0.0660,  0.0694,  0.0729,  0.0764,  0.0799,  0.0833,  0.0868,  0.0903]],

        [[ 0.6781,  0.6332,  0.7088,  0.5948,  0.7038,  0.6233,  0.5755,  0.5819,  0.6204,  0.4409,  0.6268,  0.6076],
         [ 0.0521,  0.0556,  0.0590,  0.0625,  0.0660,  0.0694,  0.0729,  0.0764,  0.0799,  0.0833,  0.0868,  0.0903]],

        [[ 0.7038,  0.4216,  0.7031,  0.6396,  0.6204,  0.4751,  0.5370,  0.5691,  0.5499,  0.6404,  0.4537,  0.4665],
         [ 0.0521,  0.0556,  0.0590,  0.0625,  0.0660,  0.0694,  0.0729,  0.0764,  0.0799,  0.0833,  0.0868,  0.0903]],

        [[ 0.6396,  0.6396,  0.7202,  0.6268,  0.6909,  0.5834,  0.6525,  0.6012,  0.6589,  0.2927,  0.6140,  0.6140],
         [ 0.0521,  0.0556,  0.0590,  0.0625,  0.0660,  0.0694,  0.0729,  0.0764,  0.0799,  0.0833,  0.0868,  0.0903]],

        [[ 0.5691,  0.5755,  0.2870,  0.3383,  0.5819,  0.4409,  0.4922,  0.4986,  0.3254,  0.2243,  0.4024,  0.3062],
         [ 0.0521,  0.0556,  0.0590,  0.0625,  0.0660,  0.0694,  0.0729,  0.0764,  0.0799,  0.0833,  0.0868,  0.0903]],

        [[ 0.2549,  0.2228,  0.5150,  0.4857,  0.5755,  0.5264,  0.5370,  0.6396,  0.4088, -1.0183, -2.0856, -2.3613],
         [ 0.0521,  0.0556,  0.0590,  0.0625,  0.0660,  0.0694,  0.0729,  0.0764,  0.0799,  0.0833,  0.0868,  0.0903]],

        [[ 0.3383,  0.6845,  0.4580,  0.4922,  0.4216,  0.2699,  0.4280,  0.6204,  0.4280,  0.3269,  0.4345,  0.0433],
         [ 0.0521,  0.0556,  0.0590,  0.0625,  0.0660,  0.0694,  0.0729,  0.0764,  0.0799,  0.0833,  0.0868,  0.0903]],

        [[ 0.1908,  0.3447,  0.2642,  0.2549,  0.2549,  0.2129,  0.5306,  0.3236,  0.2357,  0.1730,  0.1908,  0.1715],
         [ 0.0521,  0.0556,  0.0590,  0.0625,  0.0660,  0.0694,  0.0729,  0.0764,  0.0799,  0.0833,  0.0868,  0.0903]],

        [[ 0.6461,  0.6461,  0.5720,  0.4152,  0.4986,  0.5948,  0.4986,  0.6268,  0.5819,  0.4409,  0.5883,  0.4216],
         [ 0.0521,  0.0556,  0.0590,  0.0625,  0.0660,  0.0694,  0.0729,  0.0764,  0.0799,  0.0833,  0.0868,  0.0903]],

        [[ 0.5819,  0.7487,  0.4922,  0.3703,  0.3575,  0.6347,  0.6653,  0.5755,  0.4409,  0.5663,  0.5691,  0.3832],
         [ 0.0521,  0.0556,  0.0590,  0.0625,  0.0660,  0.0694,  0.0729,  0.0764,  0.0799,  0.0833,  0.0868,  0.0903]],

        [[ 0.7422,  0.3896,  0.7373,  0.5178,  0.6781,  0.7259,  0.5883,  0.6461,  0.6140,  0.5378,  0.5691,  0.6781],
         [ 0.0521,  0.0556,  0.0590,  0.0625,  0.0660,  0.0694,  0.0729,  0.0764,  0.0799,  0.0833,  0.0868,  0.0903]],

        [[ 0.6461,  0.6653,  0.4295,  0.5050,  0.5499,  0.5378,  0.5435,  0.5691,  0.3896,  0.5036,  0.3960,  0.4216],
         [ 0.0521,  0.0556,  0.0590,  0.0625,  0.0660,  0.0694,  0.0729,  0.0764,  0.0799,  0.0833,  0.0868,  0.0903]],

        [[ 0.4601,  0.1074,  0.4124,  0.3896,  0.2293,  0.3212, -0.1298,  0.0754,  0.2228,  0.1160,  0.4024,  0.1523],
         [ 0.0521,  0.0556,  0.0590,  0.0625,  0.0660,  0.0694,  0.0729,  0.0764,  0.0799,  0.0833,  0.0868,  0.0903]],

        [[ 0.5948,  0.7047,  0.6404,  0.4857,  0.6167,  0.6974,  0.6781,  0.7230,  0.6012,  0.4580,  0.6845,  0.5435],
         [ 0.0521,  0.0556,  0.0590,  0.0625,  0.0660,  0.0694,  0.0729,  0.0764,  0.0799,  0.0833,  0.0868,  0.0903]],

        [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]]], dtype=torch.float64)
tensor([1.0000, 0.6337, 0.1193, 0.5226, 0.7409, 0.2701, 0.5134, 0.1940, 0.1615, 0.2635, 0.4548, 0.9461, 0.1531, 0.3762, 0.0000, 0.0000, 0.0000, 0.0000])
tensor([ 0, 11,  4,  1,  3,  6, 10, 13,  5,  9,  7,  8, 12,  2, 14, 15, 16, 17])
tensor([[[ 0.4729,  0.3832,  0.3155,  0.1908,  0.3383,  0.3440,  0.3319,  0.4024,  0.1138,  0.0077,  0.2421,  0.1138],
         [ 0.0521,  0.0556,  0.0590,  0.0625,  0.0660,  0.0694,  0.0729,  0.0764,  0.0799,  0.0833,  0.0868,  0.0903]],

        [[ 0.6396,  0.6396,  0.7202,  0.6268,  0.6909,  0.5834,  0.6525,  0.6012,  0.6589,  0.2927,  0.6140,  0.6140],
         [ 0.0521,  0.0556,  0.0590,  0.0625,  0.0660,  0.0694,  0.0729,  0.0764,  0.0799,  0.0833,  0.0868,  0.0903]],

        [[ 0.5948,  0.7047,  0.6404,  0.4857,  0.6167,  0.6974,  0.6781,  0.7230,  0.6012,  0.4580,  0.6845,  0.5435],
         [ 0.0521,  0.0556,  0.0590,  0.0625,  0.0660,  0.0694,  0.0729,  0.0764,  0.0799,  0.0833,  0.0868,  0.0903]],

        [[ 0.5691,  0.5755,  0.2870,  0.3383,  0.5819,  0.4409,  0.4922,  0.4986,  0.3254,  0.2243,  0.4024,  0.3062],
         [ 0.0521,  0.0556,  0.0590,  0.0625,  0.0660,  0.0694,  0.0729,  0.0764,  0.0799,  0.0833,  0.0868,  0.0903]],

        [[ 0.7038,  0.4216,  0.7031,  0.6396,  0.6204,  0.4751,  0.5370,  0.5691,  0.5499,  0.6404,  0.4537,  0.4665],
         [ 0.0521,  0.0556,  0.0590,  0.0625,  0.0660,  0.0694,  0.0729,  0.0764,  0.0799,  0.0833,  0.0868,  0.0903]],

        [[ 0.6461,  0.6461,  0.5720,  0.4152,  0.4986,  0.5948,  0.4986,  0.6268,  0.5819,  0.4409,  0.5883,  0.4216],
         [ 0.0521,  0.0556,  0.0590,  0.0625,  0.0660,  0.0694,  0.0729,  0.0764,  0.0799,  0.0833,  0.0868,  0.0903]],

        [[ 0.2549,  0.2228,  0.5150,  0.4857,  0.5755,  0.5264,  0.5370,  0.6396,  0.4088, -1.0183, -2.0856, -2.3613],
         [ 0.0521,  0.0556,  0.0590,  0.0625,  0.0660,  0.0694,  0.0729,  0.0764,  0.0799,  0.0833,  0.0868,  0.0903]],

        [[ 0.7422,  0.3896,  0.7373,  0.5178,  0.6781,  0.7259,  0.5883,  0.6461,  0.6140,  0.5378,  0.5691,  0.6781],
         [ 0.0521,  0.0556,  0.0590,  0.0625,  0.0660,  0.0694,  0.0729,  0.0764,  0.0799,  0.0833,  0.0868,  0.0903]],

        [[ 0.6461,  0.6653,  0.4295,  0.5050,  0.5499,  0.5378,  0.5435,  0.5691,  0.3896,  0.5036,  0.3960,  0.4216],
         [ 0.0521,  0.0556,  0.0590,  0.0625,  0.0660,  0.0694,  0.0729,  0.0764,  0.0799,  0.0833,  0.0868,  0.0903]],

        [[ 0.5819,  0.7487,  0.4922,  0.3703,  0.3575,  0.6347,  0.6653,  0.5755,  0.4409,  0.5663,  0.5691,  0.3832],
         [ 0.0521,  0.0556,  0.0590,  0.0625,  0.0660,  0.0694,  0.0729,  0.0764,  0.0799,  0.0833,  0.0868,  0.0903]],

        [[ 0.3383,  0.6845,  0.4580,  0.4922,  0.4216,  0.2699,  0.4280,  0.6204,  0.4280,  0.3269,  0.4345,  0.0433],
         [ 0.0521,  0.0556,  0.0590,  0.0625,  0.0660,  0.0694,  0.0729,  0.0764,  0.0799,  0.0833,  0.0868,  0.0903]],

        [[ 0.6781,  0.6332,  0.7088,  0.5948,  0.7038,  0.6233,  0.5755,  0.5819,  0.6204,  0.4409,  0.6268,  0.6076],
         [ 0.0521,  0.0556,  0.0590,  0.0625,  0.0660,  0.0694,  0.0729,  0.0764,  0.0799,  0.0833,  0.0868,  0.0903]],

        [[ 0.4601,  0.1074,  0.4124,  0.3896,  0.2293,  0.3212, -0.1298,  0.0754,  0.2228,  0.1160,  0.4024,  0.1523],
         [ 0.0521,  0.0556,  0.0590,  0.0625,  0.0660,  0.0694,  0.0729,  0.0764,  0.0799,  0.0833,  0.0868,  0.0903]],

        [[ 0.1908,  0.3447,  0.2642,  0.2549,  0.2549,  0.2129,  0.5306,  0.3236,  0.2357,  0.1730,  0.1908,  0.1715],
         [ 0.0521,  0.0556,  0.0590,  0.0625,  0.0660,  0.0694,  0.0729,  0.0764,  0.0799,  0.0833,  0.0868,  0.0903]],

        [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]]], dtype=torch.float64)
barkincavdaroglu commented 1 year ago

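Solved. I first suspected that the DataLoader was returning data in random order. That was not exactly it: shuffle was enabled, so data did come back in a different order every time, but the actual problem was using tensor.scatter_. argsort returns, for each output slot, the index of the source slot that belongs there, so the reordering must read res[i, j] = neighbs[i, sorted_edg[i, j]], which is what neighbs.gather(1, sorted_edg) does. scatter_ writes in the opposite direction (res[i, sorted_edg[i, j]] = neighbs[i, j]) and therefore applies the inverse permutation, as the output above shows: neighbs[1] ended up in slot 11. Resolved with torch.gather.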
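Putting the fix together, here is a minimal sketch of the vectorized reordering followed by the packed LSTM call. It assumes dense inputs laid out like the to_dense_batch output above (neighbor features of shape (N, K, F), edge weights and mask of shape (N, K)); sort_neighborhoods is a hypothetical helper name, and the random toy tensors stand in for to_dense_batch output so the example runs without torch_geometric. Padded slots get a weight of -inf so they always sort after real neighbors, which zero padding does not guarantee if a real edge can have weight 0.

```python
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence


def sort_neighborhoods(neighbs, weights, mask):
    """Hypothetical helper: reorder each padded neighborhood by descending edge weight.

    neighbs: (N, K, *F) dense neighbor features
    weights: (N, K) dense edge weights, same slot layout as neighbs
    mask:    (N, K) bool, True where a slot holds a real neighbor
    """
    # Padded slots get -inf so they sort after every real neighbor.
    w = weights.masked_fill(~mask, float("-inf"))
    order = torch.argsort(w, dim=1, descending=True)  # order[i, j] = source slot for output slot j
    # gather reads out[i, j] = neighbs[i, order[i, j]]; scatter_ would write the inverse permutation.
    idx = order.view(*order.shape, *([1] * (neighbs.dim() - 2))).expand_as(neighbs)
    return neighbs.gather(1, idx), mask.gather(1, order)


# Toy batch: 3 nodes, up to 4 neighbors each, 5 features per neighbor.
torch.manual_seed(0)
N, K, F = 3, 4, 5
lengths = torch.tensor([4, 2, 1])  # real neighbors per node; must all be > 0 for packing
mask = torch.arange(K).unsqueeze(0) < lengths.unsqueeze(1)
neighbs = torch.randn(N, K, F) * mask.unsqueeze(-1)  # zero padding, like to_dense_batch
weights = torch.rand(N, K) * mask

sorted_neighbs, _ = sort_neighborhoods(neighbs, weights, mask)
rnn = nn.LSTM(input_size=F, hidden_size=8, batch_first=True)
packed = pack_padded_sequence(sorted_neighbs, lengths, batch_first=True, enforce_sorted=False)
_, (h_n, _) = rnn(packed)
print(h_n.shape)  # torch.Size([1, 3, 8])
```

Because padding is pushed to the end, the lengths for packing are simply mask.sum(1); nodes with no neighbors would need to be filtered out or handled separately before calling pack_padded_sequence.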