pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

Batch sampling of `edge_attr` has different shape than `edge_index` #9321

Open: jykr opened this issue 1 month ago

jykr commented 1 month ago

During batch sampling, my `edge_attr` is not subsampled consistently with `edge_index`. Am I missing something here?

Here's my data:

HeteroData(
  source={
    x=[2700, 50],
    size_factor=[2700],
    cont_cov=[2700],
    n_id=[2700],
  },
  target={
    size_factor=[1738],
    x=[1738, 50],
    n_id=[1738],
  },
  (source, rel, target)={
    edge_index=[2, 541461],
    edge_attr=[541461],
    edge_dist='Normal',
  },
  (target, rev_rel, source)={
    edge_index=[2, 541461],
    edge_attr=[541461],
  }
)

Loader:

NeighborLoader(
    train_data,
    num_neighbors={key: [30] * 2 for key in bidirectional_data.edge_types},
    batch_size=128,
    input_nodes=nodetype,
)
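For reference, the `num_neighbors` comprehension builds one list per edge type, with one entry per hop, so `[30] * 2` means "sample up to 30 neighbors per node in each of 2 hops". The plain-Python sketch below shows what it evaluates to for the two edge types in the data above. It assumes `bidirectional_data` has the same edge types as `train_data`, which the snippet does not state.

```python
# Edge types as printed in the HeteroData repr above
# (assumption: bidirectional_data.edge_types yields the same two types).
edge_types = [("source", "rel", "target"), ("target", "rev_rel", "source")]

# Same comprehension as in the loader: one list per edge type,
# one entry per hop (2 hops, up to 30 neighbors per node per hop).
num_neighbors = {key: [30] * 2 for key in edge_types}
print(num_neighbors)
# {('source', 'rel', 'target'): [30, 30], ('target', 'rev_rel', 'source'): [30, 30]}
```

With `batch_size=128` seed nodes of type `source`, the first hop samples edges pointing into those seeds, i.e. `(target, rev_rel, source)`, and can yield at most 128 × 30 = 3840 of them, which matches the `edge_index` size for that edge type in the batch below.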

Batch:

batch_bd HeteroData(
  source={
    x=[2699, 50],
    size_factor=[2699],
    cont_cov=[2699],
    n_id=[2699],
    num_sampled_nodes=[3],
    input_id=[128],
    batch_size=128,
  },
  target={
    size_factor=[1187],
    x=[1187, 50],
    n_id=[1187],
    num_sampled_nodes=[3],
  },
  (source, rel, target)={
    edge_index=[2, 35345],
    edge_dist='Normal',
    edge_attr=[146194],
    edge_attr_index=[2, 146194],
    e_id=[35345],
    num_sampled_edges=[2],
  },
  (target, rev_rel, source)={
    edge_index=[2, 3840],
    edge_attr=[146194],
    edge_attr_index=[2, 146194],
    e_id=[3840],
    num_sampled_edges=[2],
  }
)

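The mismatch is visible in the shapes alone: for both edge types, `edge_index` and `e_id` agree (35345 and 3840 sampled edges), while `edge_attr` has 146194 entries and an extra `edge_attr_index`, absent from the input data, appears. The sketch below is a minimal plain-Python shape check that flags this, plus one possible workaround: re-gathering the attribute from the full graph using the sampled edge IDs. It assumes `e_id` holds each sampled edge's position in the corresponding edge store of the original `train_data`. `find_misaligned` and `gather_by_edge_id` are hypothetical helpers, not PyG API, and the five-edge graph is toy data.

```python
from typing import Dict, List, Sequence, Tuple, TypeVar

T = TypeVar("T")
EdgeType = Tuple[str, str, str]


def find_misaligned(sizes: Dict[EdgeType, Dict[str, int]]) -> List[EdgeType]:
    """Return edge types whose edge-level attributes disagree in length.

    `sizes` maps each edge type to the number of edges each attribute reports:
    columns of `edge_index`, rows of `edge_attr` and length of `e_id`.
    """
    bad = []
    for edge_type, attrs in sizes.items():
        num_edges = attrs["edge_index"]
        # Every edge-level attribute should have exactly one entry per edge.
        if any(n != num_edges for n in attrs.values()):
            bad.append(edge_type)
    return bad


def gather_by_edge_id(full_attr: Sequence[T], e_id: Sequence[int]) -> List[T]:
    """Pick each sampled edge's attribute by its ID in the full edge store."""
    return [full_attr[i] for i in e_id]


# Shapes copied from the batch printed above.
batch_sizes = {
    ("source", "rel", "target"): {"edge_index": 35345, "edge_attr": 146194, "e_id": 35345},
    ("target", "rev_rel", "source"): {"edge_index": 3840, "edge_attr": 146194, "e_id": 3840},
}
print(find_misaligned(batch_sizes))  # both edge types are flagged

# Toy example: 5 edges in the full graph, 3 of them sampled.
full_edge_attr = [0.1, 0.2, 0.3, 0.4, 0.5]
e_id = [4, 0, 2]
print(gather_by_edge_id(full_edge_attr, e_id))  # [0.5, 0.1, 0.3]
```

With PyG tensors, the same gather would be a single indexing step along the lines of `batch[edge_type].edge_attr = train_data[edge_type].edge_attr[batch[edge_type].e_id]`, under the same assumption about `e_id`. This only sidesteps the symptom; it does not explain why the loader returned 146194 entries.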
_Originally posted by @jykr in https://github.com/pyg-team/pytorch_geometric/discussions/9097#discussioncomment-9435545_

rusty1s commented 1 month ago

Answered in https://github.com/pyg-team/pytorch_geometric/discussions/9097#discussioncomment-9522239. Let's keep the discussion in a single place.