dmlc / dgl

Python package built to ease deep learning on graph, on top of existing DL frameworks.
http://dgl.ai
Apache License 2.0

[GraphBolt] Temporal link prediction example crash #7505

Open mfbalin opened 6 days ago

mfbalin commented 6 days ago

🐛 Bug

To Reproduce

Steps to reproduce the behavior:

  1. python temporal_link_prediction.py
Training in cpu-cuda mode.
Loading data
Downloading datasets/diginetica-r2ne.zip from https://dgl-data.s3-accelerate.amazonaws.com/dataset/diginetica-r2ne.zip...
datasets/diginetica-r2ne.zip: 100%|█████████████████████████████████████████████████████████████████████████████████████████████| 404M/404M [00:09<00:00, 40.9MB/s]
Extracting file to datasets
Start to preprocess the on-disk dataset.
/localscratch/dgl-3/python/dgl/graphbolt/impl/ondisk_dataset.py:460: DGLWarning: Edge feature is stored, but edge IDs are not saved.
  dgl_warning("Edge feature is stored, but edge IDs are not saved.")
Finish preprocessing the on-disk dataset.
Training...
0it [00:00, ?it/s]/localscratch/dgl-3/python/dgl/graphbolt/item_sampler.py:94: DGLWarning: Unknown item name 'node_pairs' is detected and added into `MiniBatch`. You probably need to provide a customized `MiniBatcher`.
  dgl_warning(
/localscratch/dgl-3/python/dgl/graphbolt/item_sampler.py:94: DGLWarning: Unknown item name 'YEAR(timestamp)' is detected and added into `MiniBatch`. You probably need to provide a customized `MiniBatcher`.
  dgl_warning(
/localscratch/dgl-3/python/dgl/graphbolt/item_sampler.py:94: DGLWarning: Unknown item name 'MONTH(timestamp)' is detected and added into `MiniBatch`. You probably need to provide a customized `MiniBatcher`.
  dgl_warning(
/localscratch/dgl-3/python/dgl/graphbolt/item_sampler.py:94: DGLWarning: Unknown item name 'DAY(timestamp)' is detected and added into `MiniBatch`. You probably need to provide a customized `MiniBatcher`.
  dgl_warning(
/localscratch/dgl-3/python/dgl/graphbolt/item_sampler.py:94: DGLWarning: Unknown item name 'DAYOFWEEK(timestamp)' is detected and added into `MiniBatch`. You probably need to provide a customized `MiniBatcher`.
  dgl_warning(
/localscratch/dgl-3/python/dgl/graphbolt/item_sampler.py:94: DGLWarning: Unknown item name 'TIMESTAMP(timestamp)' is detected and added into `MiniBatch`. You probably need to provide a customized `MiniBatcher`.
  dgl_warning(
/localscratch/dgl-3/python/dgl/graphbolt/item_sampler.py:94: DGLWarning: Unknown item name 'timestamp' is detected and added into `MiniBatch`. You probably need to provide a customized `MiniBatcher`.
  dgl_warning(
0it [00:00, ?it/s]
Traceback (most recent call last):
  File "/localscratch/dgl-3/examples/sampling/graphbolt/pyg/labor/../../temporal_link_prediction.py", line 322, in <module>
    main(args)
  File "/localscratch/dgl-3/examples/sampling/graphbolt/pyg/labor/../../temporal_link_prediction.py", line 317, in main
    train(args, model, graph, features, train_set, encoders)
  File "/localscratch/dgl-3/examples/sampling/graphbolt/pyg/labor/../../temporal_link_prediction.py", line 169, in train
    for step, data in tqdm.tqdm(enumerate(dataloader)):
  File "/usr/local/lib/python3.10/dist-packages/tqdm/std.py", line 1181, in __iter__
    for obj in iterable:
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 629, in __next__
    data = self._next_data()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 672, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py", line 41, in fetch
    data = next(self.dataset_iter)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/datapipes/_hook_iterator.py", line 150, in __next__
    return self._get_next()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/datapipes/_hook_iterator.py", line 138, in _get_next
    result = next(self.iterator)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/datapipes/_hook_iterator.py", line 222, in wrap_next
    result = next_func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/datapipes/datapipe.py", line 383, in __next__
    return next(self._datapipe_iter)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/datapipes/_hook_iterator.py", line 179, in wrap_generator
    response = gen.send(None)
  File "/localscratch/dgl-3/python/dgl/graphbolt/base.py", line 287, in __iter__
    yield from self.datapipe
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/datapipes/_hook_iterator.py", line 179, in wrap_generator
    response = gen.send(None)
  File "/localscratch/dgl-3/python/dgl/graphbolt/base.py", line 274, in __iter__
    for data in self.datapipe:
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/datapipes/_hook_iterator.py", line 179, in wrap_generator
    response = gen.send(None)
  File "/usr/local/lib/python3.10/dist-packages/torchdata/datapipes/iter/util/prefetcher.py", line 103, in __iter__
    raise data
  File "/usr/local/lib/python3.10/dist-packages/torchdata/datapipes/iter/util/prefetcher.py", line 73, in thread_worker
    item = next(itr)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/datapipes/_hook_iterator.py", line 179, in wrap_generator
    response = gen.send(None)
  File "/localscratch/dgl-3/python/dgl/graphbolt/base.py", line 346, in __iter__
    for data in self.datapipe:
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/datapipes/_hook_iterator.py", line 179, in wrap_generator
    response = gen.send(None)
  File "/localscratch/dgl-3/python/dgl/graphbolt/base.py", line 313, in __iter__
    for data in self.datapipe:
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/datapipes/_hook_iterator.py", line 179, in wrap_generator
    response = gen.send(None)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/datapipes/iter/callable.py", line 124, in __iter__
    for data in self.datapipe:
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/datapipes/_hook_iterator.py", line 179, in wrap_generator
    response = gen.send(None)
  File "/localscratch/dgl-3/python/dgl/graphbolt/dataloader.py", line 95, in __iter__
    yield from self.dataloader
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 629, in __next__
    data = self._next_data()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 672, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py", line 41, in fetch
    data = next(self.dataset_iter)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/datapipes/_hook_iterator.py", line 150, in __next__
    return self._get_next()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/datapipes/_hook_iterator.py", line 138, in _get_next
    result = next(self.iterator)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/datapipes/_hook_iterator.py", line 222, in wrap_next
    result = next_func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/datapipes/datapipe.py", line 383, in __next__
    return next(self._datapipe_iter)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/datapipes/_hook_iterator.py", line 179, in wrap_generator
    response = gen.send(None)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/datapipes/iter/callable.py", line 124, in __iter__
    for data in self.datapipe:
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/datapipes/_hook_iterator.py", line 179, in wrap_generator
    response = gen.send(None)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/datapipes/iter/callable.py", line 124, in __iter__
    for data in self.datapipe:
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/datapipes/_hook_iterator.py", line 179, in wrap_generator
    response = gen.send(None)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/datapipes/iter/callable.py", line 125, in __iter__
    yield self._apply_fn(data)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/datapipes/iter/callable.py", line 90, in _apply_fn
    return self.fn(data)
  File "/localscratch/dgl-3/python/dgl/graphbolt/minibatch_transformer.py", line 38, in _transformer
    minibatch = self.transformer(minibatch)
  File "/localscratch/dgl-3/python/dgl/graphbolt/subgraph_sampler.py", line 78, in _sample
    ) = self.sample_subgraphs(
  File "/localscratch/dgl-3/python/dgl/graphbolt/impl/temporal_neighbor_sampler.py", line 76, in sample_subgraphs
    subgraph = self.sampler(
  File "/localscratch/dgl-3/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py", line 1106, in temporal_sample_neighbors
    self._check_sampler_arguments(nodes, fanouts, probs_or_mask)
  File "/localscratch/dgl-3/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py", line 779, in _check_sampler_arguments
    assert nodes.dtype == self.indices.dtype, (
AssertionError: Data type of nodes must be consistent with indices.dtype(torch.int32), but got torch.int64.
This exception is thrown by __iter__ of MiniBatchTransformer(datapipe=MiniBatchTransformer, transformer=<bound method SubgraphSampler._sample of TemporalNeighborSampler>)

Expected behavior

The example should run to completion without crashing.

mfbalin commented 6 days ago

I think the dataset dtypes are inconsistent: the seed nodes arrive as torch.int64, while the graph's indices are torch.int32, so the dtype check in `_check_sampler_arguments` fails. @frozenbugs @Rhett-Ying
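
For anyone hitting the same assertion, here is a minimal sketch of the mismatch and one possible workaround, using plain PyTorch tensors only. The helper name `cast_to_index_dtype` and the toy tensors are hypothetical and are not part of GraphBolt; this is not necessarily what the actual fix does, only an illustration of casting the seeds to the dtype of the graph's indices before sampling.

```python
import torch


def cast_to_index_dtype(seeds: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: return `seeds` with the same dtype as `indices`."""
    if seeds.dtype == indices.dtype:
        return seeds
    # Refuse to narrow (e.g. int64 -> int32) if any value would overflow.
    info = torch.iinfo(indices.dtype)
    if seeds.numel() > 0:
        lo, hi = int(seeds.min()), int(seeds.max())
        if lo < info.min or hi > info.max:
            raise ValueError(
                f"seed values [{lo}, {hi}] do not fit into {indices.dtype}"
            )
    return seeds.to(indices.dtype)


# Toy stand-ins: the graph stores int32 indices, the item set yields int64 seeds.
indices = torch.tensor([1, 2, 0], dtype=torch.int32)
seeds = torch.tensor([[0, 1], [2, 0]], dtype=torch.int64)

# This is the condition that the assertion in the traceback rejects.
assert seeds.dtype != indices.dtype

fixed = cast_to_index_dtype(seeds, indices)
print(fixed.dtype)
```

Casting on the fly only hides the inconsistency; making the dataset and the graph agree on one dtype at preprocessing time is probably the cleaner long-term fix.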

mfbalin commented 5 days ago

Regression tests do not show any result for the temporal link prediction example either.

mfbalin commented 5 days ago

Temporarily fixed in #7503; see the TODO: https://github.com/dmlc/dgl/blob/ec09676f8e102adfa172b5067b92f8bf4eea6a7d/examples/graphbolt/temporal_link_prediction.py#L298-L305