Open XGZepto opened 2 months ago
Hi! Have you tried generating an index list for the nodes in `HeteroData` and using the index mapping corresponding to `'someType'` as `train_idx` in `DistributedSampler`?
Since it's `HeteroData`, a node index only makes sense when it's coupled with the node type. I guess for now I'll need to manually feed each rank with different node indices.
Edit: it turns out you can directly give `DistributedSampler` `torch.arange(0, num_someType_nodes, dtype=torch.long)` and it works as expected. I cannot find documentation on how PyG's implementation handles a custom sampler (not the `node_sampler` they added to their derivation of `DataLoader`), so this came as a surprise.

The only remaining issue is that epoch runtime is now 5x compared to my naive implementation of feeding `NeighborLoader` on each rank with a different `input_nodes` index. If anyone can verify that it's a Torch issue, I guess we can close this one.
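The edit above can be sketched roughly as follows. This is a minimal illustration, not PyG-specific code: `num_some_type_nodes`, `num_replicas`, and `rank` are assumed values, and the final line only shows how a rank's shard would be passed on to a loader.

```python
import torch
from torch.utils.data import DistributedSampler

# Hypothetical node count for the target node type ('someType' in the thread).
num_some_type_nodes = 1000

# Indices local to the 'someType' node store; in HeteroData a node index
# is only meaningful together with its node type.
train_idx = torch.arange(0, num_some_type_nodes, dtype=torch.long)

# DistributedSampler shards these per-type indices across ranks.
# num_replicas/rank are given explicitly here so no process group is needed.
sampler = DistributedSampler(train_idx, num_replicas=2, rank=0, shuffle=True)

# This rank's shard of 'someType' indices; it could then be fed to
# NeighborLoader as input_nodes=('someType', shard).
shard = train_idx[list(iter(sampler))]
print(shard.numel())  # 500 of the 1000 indices land on rank 0
```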
In our PyG example, `train_sampler` was defined like this, with `train_idx` being the dataset-provided indices:
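The snippet referenced here did not survive extraction; a common pattern for such a definition looks like the following hedged sketch, where the `train_idx` contents, `num_replicas`, and `rank` are all assumed values:

```python
import torch
from torch.utils.data import DistributedSampler

# Hypothetical stand-in for the dataset-provided training indices.
train_idx = torch.arange(100, dtype=torch.long)

# Shard the training indices across distributed ranks.
train_sampler = DistributedSampler(
    train_idx,
    num_replicas=4,  # world size (assumed value)
    rank=0,          # this process's rank (assumed value)
    shuffle=True,
)
print(len(train_sampler))  # 25 indices assigned to this rank
```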
However, in my current project where we are handling `HeteroData`, we configured `NeighborLoader` not with a `train_idx`, but rather with a specific node type:

I cannot figure out a way to directly translate the example into our use case.
How could we handle `HeteroData`?