rusty1s / pyg_autoscale

Implementation of "GNNAutoScale: Scalable and Expressive Graph Neural Networks via Historical Embeddings" in PyTorch
http://arxiv.org/abs/2106.05609
MIT License

SubgraphLoader for heterogeneous graph #17

Open Chen-Cai-OSU opened 2 years ago

Chen-Cai-OSU commented 2 years ago

Hi, I am trying to apply pyg_autoscale to a heterogeneous graph and have to modify the compute_subgraph method of the SubgraphLoader class. Could you elaborate on what offset and count are, and what relabel_fn does? My current understanding is that compute_subgraph basically takes the subgraph spanned by n_id. Is this understanding accurate? Many thanks!

    def compute_subgraph(self, batches: List[Tuple[int, Tensor]]) -> SubData:
        batch_ids, n_ids = zip(*batches)
        n_id = torch.cat(n_ids, dim=0)
        batch_id = torch.tensor(batch_ids)

        # We collect the in-mini-batch size (`batch_size`), the offset of each
        # partition in the mini-batch (`offset`), and the number of nodes in
        # each partition (`count`)
        batch_size = n_id.numel()
        offset = self.ptr[batch_id]
        count = self.ptr[batch_id.add_(1)].sub_(offset)

        rowptr, col, value = self.data.adj_t.csr()
        rowptr, col, value, n_id = relabel_fn(rowptr, col, value, n_id,
                                              self.bipartite)

        adj_t = SparseTensor(rowptr=rowptr, col=col, value=value,
                             sparse_sizes=(rowptr.numel() - 1, n_id.numel()),
                             is_sorted=True)

        data = self.data.__class__(adj_t=adj_t)
        for k, v in self.data:
            if isinstance(v, Tensor) and v.size(0) == self.data.num_nodes:
                data[k] = v.index_select(0, n_id)

        return SubData(data, batch_size, n_id, offset, count)
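The offset/count bookkeeping at the top of compute_subgraph can be mirrored in plain Python. This is a minimal sketch under two assumptions: ptr is a partition pointer in which partition i covers nodes ptr[i] to ptr[i + 1] - 1 (as the self.ptr[batch_id] and self.ptr[batch_id + 1] lookups suggest), and each entry of batches pairs a partition index with that partition's contiguous node range. The helper offsets_and_counts and the ptr values are hypothetical, chosen to reproduce the chunks used in the answer below.

```python
from typing import List, Sequence, Tuple


def offsets_and_counts(ptr: Sequence[int],
                       batch_ids: Sequence[int]) -> Tuple[List[int], List[int]]:
    """Hypothetical helper mirroring `offset` and `count` in compute_subgraph."""
    # offset = ptr[batch_id]: first global node index of each selected partition
    offset = [ptr[b] for b in batch_ids]
    # count = ptr[batch_id + 1] - offset: number of nodes in each partition
    count = [ptr[b + 1] - ptr[b] for b in batch_ids]
    return offset, count


# Hypothetical partition pointer: 5 contiguous partitions over 14 nodes,
# i.e. [0, 1, 2], [3, 4], [5, 6, 7], [8, 9], [10, 11, 12, 13].
ptr = [0, 3, 5, 8, 10, 14]

# Assumed shape of `batches`: (partition index, node indices of that partition).
batches = [(b, list(range(ptr[b], ptr[b + 1]))) for b in (0, 2, 4)]
batch_ids, n_ids = zip(*batches)
n_id = [v for chunk in n_ids for v in chunk]  # plays the role of torch.cat(n_ids)

offset, count = offsets_and_counts(ptr, batch_ids)
batch_size = len(n_id)  # plays the role of n_id.numel()

print(n_id)        # [0, 1, 2, 5, 6, 7, 10, 11, 12, 13]
print(offset)      # [0, 5, 10]
print(count)       # [3, 3, 4]
print(batch_size)  # 10
```

Because the chunks are exactly the selected partitions, sum(count) always equals batch_size here.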

rusty1s commented 2 years ago

Yes, that is correct. Importantly, batches denotes a list of contiguous chunks of node indices that we want to group into one single mini-batch/subgraph, for example [[0, 1, 2], [5, 6, 7], [10, 11, 12, 13]], for which offset would be [0, 5, 10] and count would be [3, 3, 4]. relabel_fn then computes the subgraph induced by these chunks of nodes and relabels their node indices to [0, ..., 9].
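To make the "induced subgraph + relabel" step concrete, here is a plain-Python sketch operating on the same CSR triple (rowptr, col, value) that self.data.adj_t.csr() returns. The function induced_subgraph_csr and the path graph are hypothetical illustrations, not the library's relabel_fn: the sketch keeps only edges whose endpoints both lie in n_id and maps global indices to their positions in n_id.

```python
from typing import Dict, List, Optional, Sequence, Tuple


def induced_subgraph_csr(
    rowptr: Sequence[int],
    col: Sequence[int],
    value: Optional[Sequence[float]],
    n_id: Sequence[int],
) -> Tuple[List[int], List[int], Optional[List[float]]]:
    """Hypothetical stand-in for relabel_fn: keep only edges whose endpoints
    are both in n_id, and relabel nodes to 0, ..., len(n_id) - 1."""
    # Map each global node index to its position (local index) in n_id.
    local: Dict[int, int] = {g: i for i, g in enumerate(n_id)}
    out_rowptr = [0]
    out_col: List[int] = []
    out_value: Optional[List[float]] = [] if value is not None else None
    for g in n_id:  # one output row per kept node, in n_id order
        for e in range(rowptr[g], rowptr[g + 1]):
            if col[e] in local:  # drop edges leaving the node set
                out_col.append(local[col[e]])
                if out_value is not None:
                    out_value.append(value[e])
        out_rowptr.append(len(out_col))
    return out_rowptr, out_col, out_value


# Hypothetical example graph: an undirected path 0 - 1 - ... - 13 in CSR form.
num_nodes = 14
rowptr: List[int] = [0]
col: List[int] = []
for i in range(num_nodes):
    col.extend(j for j in (i - 1, i + 1) if 0 <= j < num_nodes)
    rowptr.append(len(col))

n_id = [0, 1, 2, 5, 6, 7, 10, 11, 12, 13]
out_rowptr, out_col, out_value = induced_subgraph_csr(rowptr, col, None, n_id)
print(out_rowptr)  # [0, 1, 3, 4, 5, 7, 8, 9, 11, 13, 14]
print(out_col)     # [1, 0, 2, 1, 4, 3, 5, 4, 7, 6, 8, 7, 9, 8]
```

Edges into nodes 3, 4, 8 and 9 are dropped, and the surviving columns only use local indices 0 to 9. One caveat: in the question's code, sparse_sizes uses n_id.numel() taken after relabel_fn, separately from batch_size, which hints that the library's relabel_fn may also append out-of-batch neighbors (depending on bipartite) instead of returning a strictly induced subgraph. This sketch does not model that.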