dmlc / dgl

Python package built to ease deep learning on graph, on top of existing DL frameworks.
http://dgl.ai
Apache License 2.0

itemsampler using graphbolt ondiskdataset doesn't work with link prediction #7787

Closed: Jaykim148 closed this issue 2 months ago

Jaykim148 commented 2 months ago

🐛 Bug

I observed several bugs related to ItemSampler with OnDiskDataset:

  1. unable to do negative sampling (sample_uniform_negative)
  2. unable to ingest a graph whose node types have different numbers of nodes
  3. sampled results don't look right even without negative sampling
  • compacted_seeds is always {edge_type: [[0, 0], [1, 1]]}
  • sampled blocks always have num_dst_nodes={nodetype: 2}
  • number of edges in the blocks << number of edges in the batches

To Reproduce

Steps to reproduce the behavior:

  1. generate nodes and edges and save them
  2. build graphbolt ondiskdataset
  3. use ItemSampler or DistributedItemSampler to sample edges and subgraphs

Code sample:

# Install required packages.
import os
import torch
import numpy as np
os.environ['TORCH'] = torch.__version__
os.environ['DGLBACKEND'] = "pytorch"

device = torch.device("cuda")

import dgl
import dgl.graphbolt as gb

base_dir = '/test_dgl_on_disk/datasets/ondisk_dataset_heterograph'
os.makedirs(base_dir, exist_ok=True)

# For simplicity, we create a heterogeneous graph with
# 2 node types: `user`, `item`
# 2 edge types: `user:like:item`, `user:follow:user`
# And each node/edge type has the same number of nodes/edges.
user_num_nodes = 1000000
item_num_nodes = 1000000
num_edges = 4 * user_num_nodes

# Edge type: "user:like:item"
like_edges_path = os.path.join(base_dir, "like-edges.npy")
like_edges_user = np.random.randint(0, user_num_nodes, size=(num_edges))
like_edges_item = np.random.randint(0, item_num_nodes, size=(num_edges))
like_edges = np.stack((like_edges_user, like_edges_item), axis=0)#.astype(np.int32)
print(f"Part of [user:like:item] edges: {like_edges[:5]}\n")

np.save(like_edges_path, like_edges)
print(f"[user:like:item] edges are saved into {like_edges_path}\n")

# Edge type: "user:follow:user"
follow_edges_path = os.path.join(base_dir, "follow-edges.npy")
follow_edges = np.random.randint(0, user_num_nodes, size=(2, num_edges))#.astype(np.int32)
print(f"Part of [user:follow:user] edges: {follow_edges[:5]}\n")

np.save(follow_edges_path, follow_edges)
print(f"[user:follow:user] edges are saved into {follow_edges_path}\n")

# Train seeds for user:like:item.
lp_train_like_seeds_path = os.path.join(base_dir, "lp-train-like-seeds.npy")
lp_train_like_seeds = like_edges
print(f"Part of train seeds[user:like:item] for link prediction: {lp_train_like_seeds[:3]}")
np.save(lp_train_like_seeds_path, lp_train_like_seeds)
print(f"LP train seeds[user:like:item] are saved to {lp_train_like_seeds_path}\n")

yaml_content = f"""
    dataset_name: heterogeneous_graph_nc_lp
    graph:
      nodes:
        - type: user
          num: {user_num_nodes}
        - type: item
          num: {item_num_nodes}
      edges:
        - type: "user:like:item"
          format: numpy
          path: {os.path.basename(like_edges_path)}
        - type: "user:follow:user"
          format: numpy
          path: {os.path.basename(follow_edges_path)}
    tasks:
      - name: link_prediction
        num_classes: 10
        train_set:
          - type: "user:like:item"
            data:
              - name: seeds
                format: numpy
                path: {os.path.basename(lp_train_like_seeds_path)}
"""
metadata_path = os.path.join(base_dir, "metadata.yaml")
with open(metadata_path, "w") as f:
  f.write(yaml_content)

dataset = gb.OnDiskDataset(base_dir, force_preprocess=True).load()

item_set = dataset.tasks[0].train_set

datapipe = gb.ItemSampler(
    item_set, 
    1000, 
    shuffle=True, 
)
# datapipe = datapipe.sample_uniform_negative(dataset.graph, 5)
datapipe = datapipe.sample_neighbor(
    dataset.graph, 
    [10],
)
datapipe = datapipe.transform(gb.exclude_seed_edges)
dataloader = gb.DataLoader(datapipe)

for i, data in enumerate(dataloader):
    print(data)
    break

Error message with sample_uniform_negative

AssertionError: Only tensor with shape N*2 is supported for negative sampling, but got torch.Size([2, 4000000]).

Error when node types have different numbers of nodes (user_num_nodes = 10^6, item_num_nodes = 10^3)

AssertionError: The seed nodes should correspond to indptr.
This exception is thrown by __iter__ of CompactPerLayer(datapipe=SamplePerLayer, deduplicate=True)
File <command-66787977880156>, line 18
     16 import time
     17 start = time.time()
---> 18 for i, data in enumerate(dataloader):
     19     print(data)
     20     break
File /databricks/python/lib/python3.11/site-packages/torch/utils/data/datapipes/_hook_iterator.py:181, in hook_iterator.<locals>.wrap_generator(*args, **kwargs)
    179         response = gen.send(None)
    180 else:
--> 181     response = gen.send(None)
    183 while True:
    184     datapipe._number_of_samples_yielded += 1
File /databricks/python/lib/python3.11/site-packages/torch/utils/data/datapipes/iter/callable.py:124, in MapperIterDataPipe.__iter__(self)
    123 def __iter__(self) -> Iterator[T_co]:
--> 124     for data in self.datapipe:
    125         yield self._apply_fn(data)
File /databricks/python/lib/python3.11/site-packages/torch/utils/data/datapipes/_hook_iterator.py:181, in hook_iterator.<locals>.wrap_generator(*args, **kwargs)
    179         response = gen.send(None)
    180 else:
--> 181     response = gen.send(None)
    183 while True:
    184     datapipe._number_of_samples_yielded += 1
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-c7bf6781-b497-48ca-a644-d5130fad80dd/lib/python3.11/site-packages/dgl/graphbolt/base.py:385, in EndMarker.__iter__(self)
    384 def __iter__(self):
--> 385     yield from self.datapipe
File /databricks/python/lib/python3.11/site-packages/torch/utils/data/datapipes/_hook_iterator.py:181, in hook_iterator.<locals>.wrap_generator(*args, **kwargs)
    179         response = gen.send(None)
    180 else:
--> 181     response = gen.send(None)
    183 while True:
    184     datapipe._number_of_samples_yielded += 1
File /databricks/python/lib/python3.11/site-packages/torch/utils/data/datapipes/iter/callable.py:124, in MapperIterDataPipe.__iter__(self)
    123 def __iter__(self) -> Iterator[T_co]:
--> 124     for data in self.datapipe:
    125         yield self._apply_fn(data)
File /databricks/python/lib/python3.11/site-packages/torch/utils/data/datapipes/_hook_iterator.py:181, in hook_iterator.<locals>.wrap_generator(*args, **kwargs)
    179         response = gen.send(None)
    180 else:
--> 181     response = gen.send(None)
    183 while True:
    184     datapipe._number_of_samples_yielded += 1
File /databricks/python/lib/python3.11/site-packages/torch/utils/data/datapipes/iter/callable.py:124, in MapperIterDataPipe.__iter__(self)
    123 def __iter__(self) -> Iterator[T_co]:
--> 124     for data in self.datapipe:
    125         yield self._apply_fn(data)
    [... skipping similar frames: hook_iterator.<locals>.wrap_generator at line 181 (2 times), MapperIterDataPipe.__iter__ at line 124 (1 times)]
File /databricks/python/lib/python3.11/site-packages/torch/utils/data/datapipes/iter/callable.py:124, in MapperIterDataPipe.__iter__(self)
    123 def __iter__(self) -> Iterator[T_co]:
--> 124     for data in self.datapipe:
    125         yield self._apply_fn(data)
File /databricks/python/lib/python3.11/site-packages/torch/utils/data/datapipes/_hook_iterator.py:181, in hook_iterator.<locals>.wrap_generator(*args, **kwargs)
    179         response = gen.send(None)
    180 else:
--> 181     response = gen.send(None)
    183 while True:
    184     datapipe._number_of_samples_yielded += 1
File /databricks/python/lib/python3.11/site-packages/torch/utils/data/datapipes/iter/callable.py:125, in MapperIterDataPipe.__iter__(self)
    123 def __iter__(self) -> Iterator[T_co]:
    124     for data in self.datapipe:
--> 125         yield self._apply_fn(data)
File /databricks/python/lib/python3.11/site-packages/torch/utils/data/datapipes/iter/callable.py:90, in MapperIterDataPipe._apply_fn(self, data)
     88 def _apply_fn(self, data):
     89     if self.input_col is None and self.output_col is None:
---> 90         return self.fn(data)
     92     if self.input_col is None:
     93         res = self.fn(data)
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-c7bf6781-b497-48ca-a644-d5130fad80dd/lib/python3.11/site-packages/dgl/graphbolt/minibatch_transformer.py:38, in MiniBatchTransformer._transformer(self, minibatch)
     37 def _transformer(self, minibatch):
---> 38     minibatch = self.transformer(minibatch)
     39     assert isinstance(
     40         minibatch, (MiniBatch,)
     41     ), "The transformer output should be an instance of MiniBatch"
     42     return minibatch
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-c7bf6781-b497-48ca-a644-d5130fad80dd/lib/python3.11/site-packages/dgl/graphbolt/impl/neighbor_sampler.py:474, in CompactPerLayer._compact_per_layer(self, minibatch)
    469 seeds = minibatch._seed_nodes
    470 if self.deduplicate:
    471     (
    472         original_row_node_ids,
    473         compacted_csc_format,
--> 474     ) = unique_and_compact_csc_formats(subgraph.sampled_csc, seeds)
    475     subgraph = SampledSubgraphImpl(
    476         sampled_csc=compacted_csc_format,
    477         original_column_node_ids=seeds,
    478         original_row_node_ids=original_row_node_ids,
    479         original_edge_ids=subgraph.original_edge_ids,
    480     )
    481 else:
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-c7bf6781-b497-48ca-a644-d5130fad80dd/lib/python3.11/site-packages/dgl/graphbolt/internal/sample_utils.py:199, in unique_and_compact_csc_formats(csc_formats, unique_dst_nodes, async_op)
    197         device = csc_format.indices.device
    198     src_type, _, dst_type = etype_str_to_tuple(etype)
--> 199     assert len(unique_dst_nodes.get(dst_type, [])) + 1 == len(
    200         csc_format.indptr
    201     ), "The seed nodes should correspond to indptr."
    202     indices[src_type].append(csc_format.indices)
    203 indices = {ntype: torch.cat(nodes) for ntype, nodes in indices.items()}

Generated minibatch result

MiniBatch(seeds={'user:like:item': tensor([[890374, 950250, 577883,  ..., 173522, 504136, 545573],
                        [ 13905, 977182,  56186,  ..., 395520, 218738, 729982]],
                       dtype=torch.int32)},
          sampled_subgraphs=[SampledSubgraphImpl(sampled_csc={'item:reverse-like:user': CSCFormatBase(indptr=tensor([ 0,  7, 10], dtype=torch.int32),
                                                                         indices=tensor([ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11], dtype=torch.int32),
                                                           ), 'user:follow:user': CSCFormatBase(indptr=tensor([0, 0, 4], dtype=torch.int32),
                                                                         indices=tensor([2, 3, 4, 5], dtype=torch.int32),
                                                           ), 'user:like:item': CSCFormatBase(indptr=tensor([ 0,  6, 15], dtype=torch.int32),
                                                                         indices=tensor([ 6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
                                                                                        dtype=torch.int32),
                                                           )},
                                               original_row_node_ids={'user': tensor([890374,  13905, 729508, 703170, 619652, 516212, 965760, 836016, 925754,
                                                                             665361, 856709, 536708, 950250, 425713, 507552, 938355, 837611, 907104,
                                                                             661183, 609059, 975945], dtype=torch.int32), 'item': tensor([950250, 977182,  13905, 892309, 310559, 143541, 911025, 422485, 832509,
                                                                             768733, 877056, 566718], dtype=torch.int32)},
                                               original_edge_ids={'user:like:item': tensor([3801462, 3801463, 3801464, 3801465, 3801466, 3801467, 3908709, 3908710,
                                                                         3908711, 3908712, 3908713, 3908714, 3908715, 3908716, 3908717],
                                                                        dtype=torch.int32), 'item:reverse-like:user': tensor([11122797, 11122798, 11122799, 11122800, 11122801, 11122802, 11122803,
                                                                          4111878,  4111879,  4111880], dtype=torch.int32), 'user:follow:user': tensor([4111881, 4111882, 4111883, 4111884], dtype=torch.int32)},
                                               original_column_node_ids={'item': tensor([950250, 977182], dtype=torch.int32), 'user': tensor([890374,  13905], dtype=torch.int32)},
                            )],
          node_features=None,
          labels=None,
          input_nodes={'user': tensor([890374,  13905, 729508, 703170, 619652, 516212, 965760, 836016, 925754,
                              665361, 856709, 536708, 950250, 425713, 507552, 938355, 837611, 907104,
                              661183, 609059, 975945], dtype=torch.int32), 'item': tensor([950250, 977182,  13905, 892309, 310559, 143541, 911025, 422485, 832509,
                              768733, 877056, 566718], dtype=torch.int32)},
          indexes=None,
          edge_features=None,
          compacted_seeds={'user:like:item': tensor([[0, 0],
                                  [1, 1]], dtype=torch.int32)},
          blocks=[Block(num_src_nodes={'item': 12, 'user': 21},
                       num_dst_nodes={'item': 2, 'user': 2},
                       num_edges={('item', 'reverse-like', 'user'): 10, ('user', 'follow', 'user'): 4, ('user', 'like', 'item'): 15},
                       metagraph=[('item', 'user', 'reverse-like'), ('user', 'user', 'follow'), ('user', 'item', 'like')])],
       )

mfbalin commented 2 months ago

I was able to run your code with the latest DGL and encountered the same issue after enabling negative sampling.

https://colab.research.google.com/drive/1ydHac4edafYsOl4dqIN05uirRnArc8I1?usp=sharing

frozenbugs commented 2 months ago

sample_uniform_negative requires the item sampler to return an N*2-shaped tensor, where each row holds 2 items as src and dst. In your case:

seeds={'user:like:item': tensor([[890374, 950250, 577883,  ..., 173522, 504136, 545573],
                        [ 13905, 977182,  56186,  ..., 395520, 218738, 729982]],

the seeds are 2*N.

We did not fix the length of each row at 2 because there are use cases for hyperedges, which may involve more than two nodes; that is why the shape can cause confusion.
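
In other words, the shape contract looks like this (a minimal sketch using the array names from the repro script above; the sizes are illustrative):

import numpy as np

# Graph edges on disk are (2, N): row 0 holds src IDs, row 1 holds dst IDs.
like_edges = np.random.randint(0, 1000, size=(2, 4000))

# Seeds passed to sample_uniform_negative must be (N, 2):
# one (src, dst) pair per row.
lp_train_like_seeds = like_edges.T
assert lp_train_like_seeds.shape == (4000, 2)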

Jaykim148 commented 2 months ago

I was able to run your code with the latest DGL and encountered the same issue after enabling negative sampling.

sample_uniform_negative requires the item sampler to return an N*2-shaped tensor, where each row holds 2 items as src and dst. In your case:

seeds={'user:like:item': tensor([[890374, 950250, 577883,  ..., 173522, 504136, 545573],
                        [ 13905, 977182,  56186,  ..., 395520, 218738, 729982]],

the seeds are 2*N.

We did not fix the length of each row at 2 because there are use cases for hyperedges, which may involve more than two nodes; that is why the shape can cause confusion.

The seeds come from OnDiskDataset.task.train_set. I believe that for link prediction, the itemset from OnDiskDataset should generate results that work with the later pipeline stages. I also reported two other issues:

  1. When negative sampling was disabled, the sampled graphs only contained two edges.
  2. Sampling failed when node counts varied across node types.

These problems might be related to the data shape coming from OnDiskDataset.task.train_set (2*N instead of N*2): a (2, N) tensor would be read as just two (src, dst) seed pairs, which would explain why compacted_seeds is always [[0, 0], [1, 1]] and num_dst_nodes is always 2.

Jaykim148 commented 2 months ago

I just confirmed that all the issues I mentioned in this report came from the input data shape. I converted the data from OnDiskDataset.task.train_set using the following code, and everything works without any of the issues mentioned above.

item_set = gb.ItemSetDict({
    # Transpose each (2, N) seeds tensor to (N, 2).
    key: gb.ItemSet((val._items[0].T,), names=('seeds',))
    for key, val in dataset.tasks[0].train_set._itemsets.items()
})
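
Note that this goes through the private _items / _itemsets attributes, so it is a workaround rather than a supported API. As a quick sanity check (a sketch reusing the same private attributes):

for key, val in item_set._itemsets.items():
    seeds = val._items[0]
    # Each row should now be one (src, dst) pair.
    assert seeds.ndim == 2 and seeds.shape[1] == 2, (key, seeds.shape)
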
mfbalin commented 2 months ago

I just confirmed that all the issues I mentioned in this report came from the input data shape. I converted the data from OnDiskDataset.task.train_set using the following code, and everything works without any of the issues mentioned above.

item_set = gb.ItemSetDict({
    # Transpose each (2, N) seeds tensor to (N, 2).
    key: gb.ItemSet((val._items[0].T,), names=('seeds',))
    for key, val in dataset.tasks[0].train_set._itemsets.items()
})

A better way to resolve this could be to save the edges that go into the itemset separately, in transposed form.
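
For instance, at data-preparation time (a sketch reusing the array and path names from the repro script; only the seeds file changes):

# The graph structure file stays (2, N), as the on-disk edge format requires.
np.save(like_edges_path, like_edges)

# Save the link-prediction train seeds separately as (N, 2),
# which is the shape the downstream samplers expect.
np.save(lp_train_like_seeds_path, like_edges.T)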

Jaykim148 commented 2 months ago

A better way to resolve this could be to save the edges that go into the itemset separately, in transposed form.

There is no way to change the numpy array shape from the beginning; the numpy array should have (2, N) shape according to this doc (https://docs.dgl.ai/en/2.1.x/stochastic_training/ondisk_dataset_heterograph.html).

mfbalin commented 2 months ago

A better way to resolve this could be to save the edges that go into the itemset separately, in transposed form.

There is no way to change the numpy array shape from the beginning; the numpy array should have (2, N) shape according to this doc (https://docs.dgl.ai/en/2.1.x/stochastic_training/ondisk_dataset_heterograph.html).

See the updated notebook: https://colab.research.google.com/drive/1ydHac4edafYsOl4dqIN05uirRnArc8I1?usp=sharing

Jaykim148 commented 2 months ago

There is no way to change the numpy array shape from the beginning; the numpy array should have (2, N) shape according to this doc (https://docs.dgl.ai/en/2.1.x/stochastic_training/ondisk_dataset_heterograph.html).

Oh, I got it. I thought that all the edges should have (2, N) shape, including the training set, but the (2, N) shape is required only for defining the graph edges. Thank you for debugging.