stellargraph / stellargraph

StellarGraph - Machine Learning on Graphs
https://stellargraph.readthedocs.io/
Apache License 2.0

Deep Graph Infomax demo error: HinSAGE with multiple node types #1788

Status: Open. arglog opened this issue 4 years ago.

arglog commented 4 years ago

Describe the bug

In the Deep Graph Infomax demo, the function run_deep_graph_infomax raises a ValueError when base_model is HinSAGE and the graph has multiple node types.

To Reproduce

A minimal example that triggers the error:

from stellargraph.mapper import (
    CorruptedGenerator,
    HinSAGENodeGenerator,
)
from stellargraph import StellarGraph
from stellargraph.layer import DeepGraphInfomax, HinSAGE
import pandas as pd

square_edges = pd.DataFrame(
    {"source": ["a", "b", "c", "d", "a"], "target": ["b", "c", "d", "a", "c"]}
)
square_foo = pd.DataFrame({"y": [0.5], "z": [50]}, index=["a"])
square_bar = pd.DataFrame(
    {"y": [0.4, 0.1, 0.9], "z": [100, 200, 300]}, index=["b", "c", "d"]
)
G = StellarGraph({"foo": square_foo, "bar": square_bar}, square_edges)

def run_deep_graph_infomax(
    base_model, generator, epochs, reorder=lambda sequence, subjects: subjects
):
    corrupted_generator = CorruptedGenerator(generator)
    gen = corrupted_generator.flow(G.nodes())  # <== this line raises the error below
    infomax = DeepGraphInfomax(base_model, corrupted_generator)

hinsage_generator = HinSAGENodeGenerator(
    G, batch_size=1000, num_samples=[5], head_node_type="bar")
hinsage_model = HinSAGE(
    layer_sizes=[128], activations=["relu"], generator=hinsage_generator)
run_deep_graph_infomax(hinsage_model, hinsage_generator, epochs=10)

Observed behavior

Error message

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-5-f6c1390f7340> in <module>
     13     layer_sizes=[128], activations=["relu"], generator=hinsage_generator
     14 )
---> 15 run_deep_graph_infomax(hinsage_model, hinsage_generator, epochs=10)

<ipython-input-5-f6c1390f7340> in run_deep_graph_infomax(base_model, generator, epochs, reorder)
      3 ):
      4     corrupted_generator = CorruptedGenerator(generator)
----> 5     gen = corrupted_generator.flow(G.nodes())
      6     infomax = DeepGraphInfomax(base_model, corrupted_generator)
      7 

/data/work/yunan/software/anaconda3/lib/python3.7/site-packages/stellargraph/mapper/corrupted.py in flow(self, *args, **kwargs)
    117         """
    118         return CorruptedSequence(
--> 119             self.base_generator.flow(*args, **kwargs),
    120             self.corrupt_index_groups,
    121             self.base_generator.num_batch_dims(),

/data/work/yunan/software/anaconda3/lib/python3.7/site-packages/stellargraph/mapper/sampled_node_generators.py in flow(self, node_ids, targets, shuffle, seed)
    145         if len(invalid) > 0:
    146             raise ValueError(
--> 147                 f"node_ids: expected all nodes to be of type {expected_node_type}, "
    148                 f"found some nodes with wrong type: {comma_sep(invalid, stringify=format)}"
    149             )

ValueError: node_ids: expected all nodes to be of type bar, found some nodes with wrong type: 3
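To make the failure mode concrete, here is a simplified, hypothetical re-creation in plain Python of the node-type check that fails inside HinSAGENodeGenerator.flow. It is not StellarGraph's implementation: the helper name check_head_node_type and the node_types dict are assumptions of this sketch. It shows that passing every node of the reproducer's graph is rejected, while a single-type graph (like Cora, where every node is a "paper") always passes. The real message in the log above reports the offending node as `3`, so the sketch's message formatting should not be read as matching the library's.

```python
# Hypothetical, simplified stand-in for the node-type check that fails in
# HinSAGENodeGenerator.flow (sampled_node_generators.py in the traceback).
# Illustration only; this is NOT StellarGraph's actual implementation.


def check_head_node_type(node_ids, node_types, expected_node_type):
    """Raise ValueError unless every ID in node_ids has type expected_node_type.

    node_types maps node ID -> node type name (an assumption of this sketch).
    """
    invalid = [n for n in node_ids if node_types[n] != expected_node_type]
    if len(invalid) > 0:
        raise ValueError(
            f"node_ids: expected all nodes to be of type {expected_node_type}, "
            f"found some nodes with wrong type: {', '.join(map(str, invalid))}"
        )
    return list(node_ids)


# Heterogeneous toy graph from the reproducer: "a" is "foo", the rest are "bar".
square_types = {"a": "foo", "b": "bar", "c": "bar", "d": "bar"}

# Passing every node, as G.nodes() does, includes "a" and is rejected.
try:
    check_head_node_type(list(square_types), square_types, "bar")
except ValueError as err:
    print(err)

# A single-type graph (like Cora, where every node is a "paper") always passes,
# which is why the original demo never hit this error.
cora_like_types = {"p1": "paper", "p2": "paper", "p3": "paper"}
print(check_head_node_type(list(cora_like_types), cora_like_types, "paper"))
```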

Expected behavior

It should not raise an error.

Comments

The original demo doesn't crash because the Cora dataset has only one node type (paper). My proposed solution is to add a node_ids argument to run_deep_graph_infomax and replace the line

gen = corrupted_generator.flow(G.nodes())

with

gen = corrupted_generator.flow(node_ids)

huonw commented 4 years ago

Hi, thanks for filing an issue, but I'm a little confused! Are you suggesting the demo could be changed to make it easier to copy and paste the code to run DGI with HinSAGE?

As you note, the correct thing to do is to only pass in nodes of the type head_node_type. This can be done with, for instance, G.nodes(node_type="bar") instead of G.nodes().
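As a rough illustration of what the filtering does, the plain-Python sketch below mimics how the reproducer assigns node types (each key of the dict passed to StellarGraph is a type, and each frame's index holds that type's node IDs) and then keeps only the head_node_type nodes. The helpers build_node_types and nodes_of_type are hypothetical stand-ins for the result of G.nodes(node_type="bar"), not library code. The final comment shows how the call might look with the reporter's proposed node_ids argument; that keyword usage is likewise an assumption.

```python
# Hypothetical sketch: how node types arise in the reproducer and how to keep
# only the head_node_type nodes. In real code, G.nodes(node_type="bar") does the
# filtering; these plain-Python helpers only mimic its result on the toy graph.


def build_node_types(node_frames):
    """Map node ID -> type, where node_frames maps type name -> node IDs.

    Mirrors StellarGraph({"foo": square_foo, "bar": square_bar}, ...), where
    each dict key is a node type and each frame's index holds its node IDs.
    """
    return {node: ntype for ntype, ids in node_frames.items() for node in ids}


def nodes_of_type(node_types, node_type):
    """Return IDs of nodes whose type is node_type, preserving order."""
    return [node for node, ntype in node_types.items() if ntype == node_type]


node_types = build_node_types({"foo": ["a"], "bar": ["b", "c", "d"]})
head_nodes = nodes_of_type(node_types, "bar")
print(head_nodes)

# With the reporter's proposed (hypothetical) node_ids argument, the demo call
# might become (StellarGraph code, kept as a comment since it needs the library):
#     run_deep_graph_infomax(hinsage_model, hinsage_generator, epochs=10,
#                            node_ids=G.nodes(node_type="bar"))
```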