dmlc / dgl

Python package built to ease deep learning on graph, on top of existing DL frameworks.
http://dgl.ai
Apache License 2.0

HeteroGraphConv documentation sample code crash #7687

Open FatherOfDragons opened 3 months ago

FatherOfDragons commented 3 months ago

🐛 Bug

To Reproduce

Steps to reproduce the behavior: running the code snippet from the HeteroGraphConv documentation as an actual script results in a crash.

```python
import dgl
import dgl.graphbolt as gb
import dgl.nn as dglnn

import torch as th
import numpy as np

n_users = 50
n_games = 10
n_stores = 5

follows_src = np.random.randint(0, n_users, 100)
follows_dst = np.random.randint(0, n_users, 100)

plays_src = np.random.randint(0, n_users, 50)
plays_dst = np.random.randint(0, n_games, 50)

sells_src = np.random.randint(0, n_stores, 20)
sells_dst = np.random.randint(0, n_games, 20)

g = dgl.heterograph(
    {
        ("user", "follows", "user"): (follows_src, follows_dst),
        ("user", "plays", "game"): (plays_src, plays_dst),
        ("store", "sells", "game"): (sells_src, sells_dst),
    }
)

input_dim = 16
out_dim = 8

hetero_conv = dglnn.HeteroGraphConv(
    {
        "follows": dglnn.GraphConv(input_dim, out_dim),
        "plays": dglnn.GraphConv(input_dim, out_dim),
        "sells": dglnn.GraphConv(input_dim, out_dim),
    },
    aggregate="sum",
)

h1 = {"user": th.randn((g.num_nodes("user"), input_dim))}

h2 = hetero_conv(g, h1)
print(h2.keys())
```
```
$ python heterographconv_example.py
Traceback (most recent call last):
  File "/Users/yuri/remix/projects/easy-platform/platform-python/pinsage-recommender/heterographconv_example.py", line 44, in <module>
    h2 = hetero_conv(g, h1)
         ^^^^^^^^^^^^^^^^^^
  File "/Users/yuri/remix/projects/easy-platform/platform-python/pinsage-recommender/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/yuri/remix/projects/easy-platform/platform-python/pinsage-recommender/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/yuri/remix/projects/easy-platform/platform-python/pinsage-recommender/.venv/lib/python3.11/site-packages/dgl-2.3.0-py3.11-macosx-11.1-arm64.egg/dgl/nn/pytorch/hetero.py", line 212, in forward
    (inputs[stype], inputs[dtype]),
                    ~~~~~~^^^^^^^
KeyError: 'game'
```

Expected behavior

dict_keys(['user', 'game'])

Environment

Additional context

When 'game' features are added to h1, the example works. I'm not sure whether this is an implementation bug or a documentation issue.
It seems reasonable to expect the documented example to work as shown when 'game' nodes have no intrinsic features. Also, when a block is passed to the forward function instead of a graph, the failure is silent: the call returns an empty dictionary.
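The traceback points at `(inputs[stype], inputs[dtype])` inside `HeteroGraphConv.forward`. The pure-Python sketch below is a simplified model of that per-relation loop, assuming (from the traceback and the block behavior described above) that the graph path skips a relation only when its source type has no features, while the block path skips it when either endpoint type is missing. It only tracks which relations would run; `dispatch` and `CANONICAL_ETYPES` are hypothetical names, not DGL APIs.

```python
# Hypothetical, simplified model of how HeteroGraphConv.forward dispatches
# one sub-module per relation. It tracks which relations run; no tensor math.

CANONICAL_ETYPES = [
    ("user", "follows", "user"),
    ("user", "plays", "game"),
    ("store", "sells", "game"),
]


def dispatch(inputs, is_block=False):
    """Return {dst_type: [etype, ...]} for the relations that would run."""
    outputs = {}
    for stype, etype, dtype in CANONICAL_ETYPES:
        if is_block:
            # Assumed block path: skip unless BOTH endpoint types have features.
            if stype not in inputs or dtype not in inputs:
                continue
        elif stype not in inputs:
            # Assumed graph path: only the SOURCE type is checked.
            continue
        # Mirrors `(inputs[stype], inputs[dtype])` from the traceback; on the
        # graph path this raises KeyError when the destination type is missing.
        _pair = (inputs[stype], inputs[dtype])
        outputs.setdefault(dtype, []).append(etype)
    return outputs


h1 = {"user": "user-features"}  # only 'user' has features, as in the report

try:
    dispatch(h1)
except KeyError as err:
    print("graph path:", repr(err))  # graph path: KeyError('game')

print("block path:", dispatch(h1, is_block=True))  # {'user': ['follows']}
print("block path, no features:", dispatch({}, is_block=True))  # {}
```

Under this model, the graph path fails on the first relation whose destination type lacks features, which matches `KeyError: 'game'`, while the block path drops such relations quietly and can end with an empty dict when no relation has features for both endpoints.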

rudongyu commented 3 months ago

In the current implementation, we assume features are provided for all node types. We will update the documentation later.