jasonlin316 / ARGO

Official implementation of ARGO, an auto-tuning runtime system for GNN training on CPU.

Support for HeteroData #7

Open XGZepto opened 2 months ago

XGZepto commented 2 months ago

In our PyG example, train_sampler was defined like this:

    train_sampler = DistributedSampler(
        train_idx,
        num_replicas=world_size,
        rank=rank,
    )

where `train_idx` is the provided dataset of training node indices.

However, in my current project, where we are handling HeteroData, we configure NeighborLoader not with a `train_idx` but with a specific node type:

    train_loader = NeighborLoader(
        ...,
        input_nodes='someType',
        ...
    )

I cannot figure out a way to translate the example directly to our use case.

How could we handle HeteroData?

chen-yy20 commented 2 months ago

Hi! Have you tried generating an index list for the nodes in HeteroData and using the index list corresponding to `'someType'` as `train_idx` in DistributedSampler?

XGZepto commented 2 months ago

> Hi! Have you tried generating an index list for the nodes in HeteroData and use the index mapping list corresponding to 'someType' as train_idx in DistributedSampler?

Since it's HeteroData, a node index only makes sense when it's coupled with a node type. I guess for now I'll need to manually feed each rank a different set of node indices.

Edit: it turned out you can directly give DistributedSampler `torch.arange(0, num_someType_nodes, dtype=torch.long)` and it works as expected. I cannot find documentation on how PyG's implementation handles a custom sampler (as opposed to the `node_sampler` they added to their derivation of DataLoader), so this came as a surprise.
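For anyone landing here, a minimal sketch of that approach, using plain `torch` only (the node count and `world_size` below are made up; in practice the count would be `data['someType'].num_nodes` from the HeteroData object):

```python
import torch
from torch.utils.data import DistributedSampler

# Hypothetical number of 'someType' nodes; in a real run this would be
# data['someType'].num_nodes.
num_someType_nodes = 10

# Per-type node IDs are just 0..N-1 for that type, so a plain range
# serves as train_idx.
train_idx = torch.arange(0, num_someType_nodes, dtype=torch.long)

# DistributedSampler accepts explicit num_replicas/rank, so no process
# group is needed to inspect the per-rank split.
sampler_rank0 = DistributedSampler(train_idx, num_replicas=2, rank=0, shuffle=False)
sampler_rank1 = DistributedSampler(train_idx, num_replicas=2, rank=1, shuffle=False)

# With shuffle=False the sampler strides through the indices,
# so rank 0 gets the even positions and rank 1 the odd ones.
print(sorted(sampler_rank0))  # [0, 2, 4, 6, 8]
print(sorted(sampler_rank1))  # [1, 3, 5, 7, 9]
```

Each rank would then pass its sampled indices (together with the node type) to its loader.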

The only remaining issue is that epoch runtime is now 5x that of my naive implementation, which fed the NeighborLoader on each rank a different slice of the `input_nodes` index. If anyone can verify that it's a Torch issue, I guess we can close this one.
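For reference, the naive per-rank split mentioned above can be sketched with plain tensor ops (node count and world size here are hypothetical):

```python
import torch

world_size = 2
num_someType_nodes = 10  # hypothetical; data['someType'].num_nodes in practice

all_idx = torch.arange(0, num_someType_nodes, dtype=torch.long)

# One contiguous shard per rank; each rank would then build its own
# NeighborLoader with input_nodes=('someType', shards[rank]).
shards = torch.tensor_split(all_idx, world_size)

print(shards[0].tolist())  # [0, 1, 2, 3, 4]
print(shards[1].tolist())  # [5, 6, 7, 8, 9]
```

Unlike DistributedSampler, this gives contiguous rather than strided shards, but the per-rank workload is the same size either way.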