barkincavdaroglu / Link-Prediction-Mesh-Network

PyTorch Implementation of a Deep Learning Model for Temporal Link Prediction in MANETs

Batch support #15

Closed: barkincavdaroglu closed this 1 year ago

barkincavdaroglu commented 1 year ago

Closes #12. Implements a custom collate_fn to handle batching, since PyTorch's default collate_fn cannot stack variable-sized graph data.
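Here is a quick illustration of why the default collate_fn fails. PyTorch's default_collate combines tensors with torch.stack, which requires every tensor to have the same shape. Two edge_index tensors with different edge counts therefore raise a RuntimeError. The toy graphs are made up for illustration.

```python
import torch
from torch.utils.data.dataloader import default_collate

# Two toy graphs with different edge counts: edge_index has shape (2, E_i).
g1 = torch.tensor([[0, 1], [1, 0]])        # 2 edges
g2 = torch.tensor([[0, 1, 2], [1, 2, 0]])  # 3 edges

try:
    default_collate([g1, g2])
    stacked = True
except RuntimeError:
    # torch.stack needs equal shapes, so variable-sized graphs cannot be stacked.
    stacked = False

print(stacked)  # False
```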

collate_fn returns a list [batched_edges, batched_node_fts, batched_edge_fts, batched_graph_fts, batched_target] where:

- batched_edges has dim (2, total number of edges in the batch)
- batched_node_fts has dim (total number of nodes in the batch, node feature dim)
- batched_edge_fts has dim (total number of edges in the batch, edge feature dim)
- batched_graph_fts has dim (total number of graphs in the batch, graph feature dim)
- batched_target has dim (total number of nodes in the batch, max number of nodes in any single graph in the batch)
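The following is a minimal sketch of a collate_fn that produces the shapes listed above. Several details are assumptions and are not taken from the PR's actual code:

- The per-sample tuple layout (edge_index, node_fts, edge_fts, graph_fts, target).
- The name graph_collate_fn.
- The edge-index offsetting. Each graph's edge indices are shifted by the number of nodes in the preceding graphs, so they index into the concatenated node feature matrix. This is the common disjoint-union batching technique.

Each adjacency target is zero-padded on the column side to the largest node count in the batch.

```python
from typing import List, Tuple

import torch
import torch.nn.functional as F

# Assumed per-sample layout (hypothetical):
# (edge_index (2, E_i), node_fts (N_i, F_n), edge_fts (E_i, F_e),
#  graph_fts (F_g,), target (N_i, N_i) adjacency matrix)
Sample = Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]


def graph_collate_fn(samples: List[Sample]) -> List[torch.Tensor]:
    max_nodes = max(s[1].size(0) for s in samples)
    edges, node_fts_list, edge_fts_list, graph_fts_list, targets = [], [], [], [], []
    offset = 0  # number of nodes in all previous graphs of the batch
    for edge_index, node_fts, edge_fts, graph_fts, target in samples:
        # Shift node indices so edges point into the concatenated node matrix.
        edges.append(edge_index + offset)
        node_fts_list.append(node_fts)
        edge_fts_list.append(edge_fts)
        graph_fts_list.append(graph_fts.reshape(1, -1))
        # Zero-pad target columns so every row has max_nodes entries.
        targets.append(F.pad(target, (0, max_nodes - target.size(1))))
        offset += node_fts.size(0)
    return [
        torch.cat(edges, dim=1),           # (2, total edges)
        torch.cat(node_fts_list, dim=0),   # (total nodes, node feature dim)
        torch.cat(edge_fts_list, dim=0),   # (total edges, edge feature dim)
        torch.cat(graph_fts_list, dim=0),  # (num graphs, graph feature dim)
        torch.cat(targets, dim=0),         # (total nodes, max nodes)
    ]


# Usage with two toy graphs (2 nodes / 2 edges and 3 nodes / 3 edges).
g1 = (torch.tensor([[0, 1], [1, 0]]), torch.randn(2, 4), torch.randn(2, 3),
      torch.randn(5), torch.zeros(2, 2))
g2 = (torch.tensor([[0, 1, 2], [1, 2, 0]]), torch.randn(3, 4), torch.randn(3, 3),
      torch.randn(5), torch.zeros(3, 3))
batch = graph_collate_fn([g1, g2])
print([tuple(t.shape) for t in batch])
# [(2, 5), (5, 4), (5, 3), (2, 5), (5, 3)]
```

The function would then be handed to the loader, e.g. DataLoader(dataset, batch_size=32, collate_fn=graph_collate_fn), where the batch size is arbitrary.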

Since the model's task is to predict the adjacency matrix, we can currently only train on graphs with a fixed number of nodes (the number of edges may vary). In the future, padding every graph to the maximum number of nodes could remove this restriction.
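The padding idea could look roughly like the following. This is a hypothetical sketch, not part of this PR. pad_graph is an illustrative helper that appends dummy nodes up to max_nodes and returns a boolean mask of the real nodes. Masking the binary cross-entropy loss so dummy node pairs are ignored is one possible choice, assumed here.

```python
import torch
import torch.nn.functional as F


def pad_graph(node_fts: torch.Tensor, adj: torch.Tensor, max_nodes: int):
    """Hypothetical helper: pad one graph with dummy nodes up to max_nodes."""
    n = node_fts.size(0)
    extra = max_nodes - n
    node_fts = F.pad(node_fts, (0, 0, 0, extra))  # append zero rows (dummy nodes)
    adj = F.pad(adj, (0, extra, 0, extra))         # append zero rows and columns
    mask = torch.zeros(max_nodes, dtype=torch.bool)
    mask[:n] = True                                # True marks real nodes
    return node_fts, adj, mask


# Usage: pad a 2-node graph to 3 nodes, then compute a loss over real node pairs only.
x, a, m = pad_graph(torch.ones(2, 4), torch.tensor([[0.0, 1.0], [1.0, 0.0]]), max_nodes=3)
pair_mask = m[:, None] & m[None, :]  # (3, 3), True where both nodes are real
logits = torch.zeros(3, 3)           # stand-in for model predictions
loss = F.binary_cross_entropy_with_logits(logits[pair_mask], a[pair_mask])
print(x.shape, a.shape, m.tolist())  # torch.Size([3, 4]) torch.Size([3, 3]) [True, True, False]
```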

Note: Loading all pickle files at init creates some startup overhead, but since batching now makes training much faster, this overhead is negligible. Also, passing batches of filenames instead of loaded samples doesn't make sense.
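Here is a minimal sketch of the eager-loading approach described in the note. It assumes each pickle file holds one sample, here a dict with a node_fts tensor. PickledGraphDataset and the file layout are hypothetical. Files are read once in __init__, so __getitem__ only indexes into memory. The lambda collate_fn simply returns the list of samples and stands in for the real graph collate_fn.

```python
import os
import pickle
import tempfile
from typing import List

import torch
from torch.utils.data import DataLoader, Dataset


class PickledGraphDataset(Dataset):
    """Hypothetical dataset: every pickle file is loaded once, in __init__."""

    def __init__(self, paths: List[str]):
        self.samples = []
        for path in paths:
            with open(path, "rb") as f:
                self.samples.append(pickle.load(f))  # one-time startup cost

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # Returns an already-loaded sample, so no disk I/O happens per batch.
        return self.samples[idx]


# Usage: write three toy graphs (2, 3 and 4 nodes) to temporary pickle files.
tmp = tempfile.mkdtemp()
paths = []
for i, n in enumerate([2, 3, 4]):
    path = os.path.join(tmp, f"graph_{i}.pkl")
    with open(path, "wb") as f:
        pickle.dump({"node_fts": torch.randn(n, 4)}, f)
    paths.append(path)

dataset = PickledGraphDataset(paths)
loader = DataLoader(dataset, batch_size=2, collate_fn=lambda batch: batch)
sizes = [[s["node_fts"].size(0) for s in batch] for batch in loader]
print(sizes)  # [[2, 3], [4]]
```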