My PR is mainly to fix the issue in neighbor sampling. After the bug fix, I need to expose the sampling result, so I added a Python class. I'm not sure how to test the code without creating a class for the sampling result.
In general, I agree with the concept and API. Some small questions:
class GCNLayer(nn.Module):
    def __init__(self):
        super(GCNLayer, self).__init__()
        self.fc = nn.Linear(20, 20)

    def forward(self, g, layerid, h):  # type(g) == dgl.NodeFlow
        g.ldata[layerid]['h'] = h  # ldata[layerid] is short for g.layers[layerid].data
        g.pull(g.layers(layerid), fn.copy_src('h', 'm'), fn.sum('m', 'h'))  # g.layers(layerid) gives node ids
        return self.fc(g.ldata[layerid].pop('h'))
It seems you might want to use g.ldata[layerid-1]['h'] = h?
Can we use layer_data and flow_data instead of ldata and fdata?
It seems you might want to use g.ldata[layerid-1]['h'] = h?

You are right.

Can we use layer_data and flow_data instead of ldata and fdata?

Try to be consistent with ndata and edata, but we could definitely change that if it is too short to be clear.
My PR is mainly to fix the issue in neighbor sampling. After the bug fix, I need to expose the sampling result, so I added a Python class. I'm not sure how to test the code without creating a class for the sampling result.
How about having a separate PR first just about the NodeFlow? We will try to finish that PR ASAP. Then you could fix the neighbor sampling based on that.
It's a little awkward. Without neighbor sampling, how are we going to test NodeFlow?
Test the basic interface, just like how we test DGLGraph. And I'm afraid such a test is necessary because NodeFlow is actually a data structure visible to the user, not only for sampling.
If it can be easily tested, it works for me.
Some comments on NodeFlow:
- Layer in the GCN community typically means a new layer with a different set of parameters; I would think slice is better.
- If we use slice instead, we need to use sdata, not ldata.
- fdata can be confused with features in the user's mind.
- Layer in the GCN community typically means a new layer with a different set of parameters; I would think slice is better.
- If we use slice instead, we need to use sdata, not ldata.
- fdata can be confused with features in the user's mind.
I'm fine with slice. In that case, I suggest following Da's proposal: slice_data and flow_data to avoid confusion.
- are there cases where sampling needs to be done after one slice is computed? i.e., are there scenarios where a complete, L-slice (or layer) nodeflow's construction is dynamic?
Interesting point. I am not aware of such a scenario at the moment, but it seems totally reasonable. We actually thought about the following append API:
def append(nodeflow1, nodeflow2):
    """Return a new nodeflow that is [nodeflow1->nodeflow2]."""
    pass
Note that the API must return a new nodeflow because a nodeflow is an immutable data structure (for best efficiency).
The question is how efficient this append will be. My gut feeling is that it could be efficient. Here is some pseudo-code for a dynamic nodeflow:
nflow = dgl.NodeFlow()  # initial empty flow
h = ...  # initial representation
for i in range(L):
    delta = compute_new_nodeflow(h)
    nflow = nflow.append(delta)  # append returns a new nodeflow (immutable)
    h = GCNLayer(nflow, i)       # pseudo: apply the i-th GCN layer
- node relabeling, that's internal and opaque, correct? One use case can be that we want to follow the evolution of a node's hidden representation over time (like what Murphy does in the batch graph classification tutorial, but he's tracking them across layers).
Another interesting point. It seems that you want to keep the "pruned nodes" in the flow and activate them based on certain condition?
@zheng-da, that's why register has layerid as the argument.
I didn't notice layerid.
I still find this API inconvenient to use. In my implementation of gcn_cv_updater, the first layer (the layer close to the input data) doesn't need msg_func and reduce_func because its result has been computed before, so I simply apply node_update_func.
Maybe something like this:
def compute(self, flow_id, mfunc, rfunc, afunc):
    """Compute a flow using the given UDFs."""
    pass
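For concreteness, a minimal usage sketch of this proposed per-flow compute API; the compute method, the flow IDs, and the node_update function below are assumptions based on this discussion, not an existing DGL call:

import dgl.function as fn

def node_update(nodes):
    # toy apply function: identity on the reduced feature
    return {'h': nodes.data['h']}

num_flows = 2  # a NodeFlow with 3 layers has 2 flows (layer i -> layer i+1)
for i in range(num_flows):
    # nflow is assumed to be a NodeFlow built by a sampler
    nflow.compute(i,                      # flow id
                  fn.copy_src('h', 'm'),  # message UDF
                  fn.sum('m', 'h'),       # reduce UDF
                  node_update)            # apply (node update) UDF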
We should allow users to customize the computation for each layer. If the computation in a layer is very customized, the user needs to fall back to the original DGLGraph API.
Here is the code of what GCN with control variate + updater looks like.
Here is the trainer code:
class GCNLayer(gluon.Block):
    def __init__(self, ind, in_feats, out_feats, activation, dropout, **kwargs):
        super(GCNLayer, self).__init__(**kwargs)
        self.ind = ind
        self.in_feats = in_feats
        self.node_update = NodeUpdate(out_feats, activation, dropout)

    def forward(self, subg):
        if self.ind == 0:
            subg.layers[1].data['h'] = subg.layers[1].data['agg_h_0']
            assert subg.layers[1].data['h'].shape[1] == self.in_feats
            subg.apply_nodes(self.node_update, v=subg.layer_nid(1))
        else:
            subg.layers[self.ind].input_data['h'] = h
            # control variate
            subg.compute(self.ind + 1, partial(gcn_msg, ind=self.ind),
                         partial(gcn_reduce, ind=self.ind), self.node_update)
        return subg.layers[self.ind + 1].data['h']
class GCN(gluon.Block):
    def __init__(self, in_feats, n_hidden, n_classes, n_layers, activation, dropout, **kwargs):
        super(GCN, self).__init__(**kwargs)
        self.n_layers = n_layers
        assert n_layers >= 2
        with self.name_scope():
            #self.linear = gluon.nn.Dense(n_hidden, activation)
            self.layers = gluon.nn.Sequential()
            # input layer
            self.layers.add(GCNLayer(0, in_feats, n_hidden, activation, dropout))
            # hidden layers
            for i in range(1, n_layers-1):
                self.layers.add(GCNLayer(i, n_hidden, n_hidden, activation, dropout))
            # output layer
            self.layers.add(GCNLayer(n_layers-1, n_hidden, n_classes, None, dropout))

    def forward(self, subg):
        for i, layer in enumerate(self.layers):
            h = layer(subg)
        return h
Here is the updater code:
class GCNForwardLayer(gluon.Block):
    def __init__(self, ind, in_feats, out_feats, activation, dropout, **kwargs):
        super(GCNForwardLayer, self).__init__(**kwargs)
        self.ind = ind
        self.node_update = NodeUpdate(out_feats, activation, dropout)

    def forward(self, g, h):
        g.ndata['h'] = h
        g.update_all(fn.copy_src(src='h', out='m'), fn.sum(msg='m', out='h'))
        g.ndata['h'] = g.ndata['h'] * g.ndata['deg_norm']
        agg_h = g.ndata['h']
        g.apply_nodes(self.node_update)
        return agg_h, g.ndata.pop('h')

class GCNUpdate(gluon.Block):
    def __init__(self, in_feats, n_hidden, n_classes, n_layers, activation, dropout, **kwargs):
        super(GCNUpdate, self).__init__(**kwargs)
        self.n_layers = n_layers
        assert n_layers >= 2
        with self.name_scope():
            #self.linear = gluon.nn.Dense(n_hidden, activation)
            self.layers = gluon.nn.Sequential()
            # input layer
            self.layers.add(GCNForwardLayer(0, in_feats, n_hidden, activation, dropout))
            # hidden layers
            for i in range(1, n_layers-1):
                self.layers.add(GCNForwardLayer(i, n_hidden, n_hidden, activation, dropout))
            # output layer
            self.layers.add(GCNForwardLayer(n_layers-1, n_hidden, n_classes, None, dropout))

    def forward(self, g):
        h = g.ndata['in']
        for i, layer in enumerate(self.layers):
            agg_h, h = layer(g, h)
            g.ndata['h_%d' % (i + 1)] = h
            g.ndata['agg_h_%d' % i] = agg_h
        return h
It's unlikely that the updater code can be written the same way as the trainer code, at least not for now. Writing data to a layer of a NodeFlow requires index_copy, which generates a lot of unnecessary memory copies.
Comments for the trainer code:
class GCNLayer(gluon.Block):
    def __init__(self, ind, in_feats, out_feats, activation, dropout, **kwargs):
        super(GCNLayer, self).__init__(**kwargs)
        self.ind = ind
        self.in_feats = in_feats
        self.node_update = NodeUpdate(out_feats, activation, dropout)

    def forward(self, subg):
        if self.ind == 0:
            subg.layers[0].data['h'] = subg.layers[0].data['agg_h_0']
            assert subg.layers[0].data['h'].shape[1] == self.in_feats
            subg.apply_nodes(self.node_update, v=subg.layer_nid(0))
        else:
            subg.layers[self.ind - 1].data['h'] = h
            # control variate
            subg.pull(subg.layer_nid(self.ind), partial(gcn_msg, ind=self.ind),
                      partial(gcn_reduce, ind=self.ind), self.node_update)
        return subg.layers[self.ind].data['h']
I see what you mean. How about this:
- The first layer is always the input layer.
- Since there is no incoming link to the input layer, there is no need to define computation for it. The node apply in self.ind == 0 is actually not needed. Even if a node apply is required, this can be done outside of NodeFlow.

With this in mind, the code will look like this (also fixed some syntax):
class GCNLayer(gluon.Block):
    def __init__(self, ind, in_feats, out_feats, activation, dropout, **kwargs):
        super(GCNLayer, self).__init__(**kwargs)
        self.ind = ind
        self.in_feats = in_feats
        self.node_update = NodeUpdate(out_feats, activation, dropout)

    def forward(self, nflow, h):
        nflow.layer_data[self.ind-1]['h'] = h
        # control variate
        nflow.pull(nflow.layers(self.ind), fn.copy_src('h', 'm'), fn.sum('m', 'h'), self.node_update)
        return nflow.layer_data[self.ind]['h']
And the GCN block:
class GCN(gluon.Block):
    def __init__(self, in_feats, n_hidden, n_classes, n_layers, activation, dropout, **kwargs):
        super(GCN, self).__init__(**kwargs)
        self.n_layers = n_layers
        with self.name_scope():
            self.layers = gluon.nn.Sequential()
            # hidden layers
            for i in range(0, n_layers-1):
                self.layers.add(GCNLayer(i, n_hidden, n_hidden, activation, dropout))
            # output layer
            self.layers.add(GCNLayer(n_layers-1, n_hidden, n_classes, None, dropout))

    def forward(self, nflow, h):
        for layer in self.layers:
            h = layer(nflow, h)
        return h
Since such a pattern is very common, we could provide the compute interface as follows. Note that the GCNLayer class is fused into the GCN class.
class GCN(gluon.Block):
    def __init__(self, in_feats, n_hidden, n_classes, n_layers, activation, dropout, **kwargs):
        super(GCN, self).__init__(**kwargs)
        with self.name_scope():
            self.mfunc = []
            self.rfunc = []
            self.afunc = gluon.nn.Sequential()
            # hidden layers
            for i in range(0, n_layers-1):
                self.mfunc.append(fn.copy_src('h', 'm'))
                self.rfunc.append(fn.sum('m', 'h'))
                self.afunc.add(NodeUpdate(n_hidden, activation, dropout))
            # output layer
            self.mfunc.append(fn.copy_src('h', 'm'))
            self.rfunc.append(fn.sum('m', 'h'))
            self.afunc.add(NodeUpdate(n_classes, None, dropout))

    def forward(self, nflow, h):
        h = nflow.compute(h, self.mfunc, self.rfunc, self.afunc)
        return h
Writing data to a layer of NodeFlow requires index_copy, which generates a lot of unnecessary memory copies

Isn't this something we should avoid in NodeFlow? I think the feature data of each layer should be stored separately, so nflow.layer_data[self.ind-1]['h'] = h only replaces the tensor reference.
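A toy sketch (plain PyTorch, not DGL's storage code; the frame layout is an assumption) contrasting an index_copy-style write with per-layer storage where assignment only rebinds a reference:

import torch

num_layers, nodes_per_layer, feat_dim = 3, 4, 8

# (a) single flat tensor for all layers: writing one layer needs an index_copy_,
#     i.e. an extra memory copy of that layer's features
flat = torch.zeros(num_layers * nodes_per_layer, feat_dim)
h = torch.randn(nodes_per_layer, feat_dim)
layer = 1
idx = torch.arange(layer * nodes_per_layer, (layer + 1) * nodes_per_layer)
flat.index_copy_(0, idx, h)

# (b) one frame (dict of tensors) per layer: assignment only rebinds a reference,
#     no data is copied
layer_frames = [{} for _ in range(num_layers)]
layer_frames[layer]['h'] = h
assert layer_frames[layer]['h'] is h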
I prototyped PinSage with the idea of NodeFlow here. That's the most natural implementation I can think of, given the current DGL release version (0.1.3).
I'm a little bit concerned that the implementation is not using the DGL message passing API. Are we at all going to handle cases like this in the NodeFlow system?
Writing data to a layer of NodeFlow requires index_copy, which generates a lot of unnecessary memory copies

Isn't this something we should avoid in NodeFlow? I think the feature data of each layer should be stored separately, so nflow.layer_data[self.ind-1]['h'] = h only replaces the tensor reference. If so, I think we shouldn't inherit from DGLGraph. Otherwise, what is ndata?
Yeah, that's something that should be discussed. What is the data structure for storing node/edge features, and how do we support the DGLGraph interface?
The other problem is that if we replace ndata with layer_data, we need to change the scheduler that creates the execution plan. It seems to require a significant rewrite. I guess you or Linfan might be a better person to do so.
@BarclayII PinSage code seems to fit well in the NodeFlow API. Why do you think it can't use the DGL message passing API?
Suppose we use a list of frames to store node/edge features:
class NodeFlow:
    def __init__(self, ...):
        self.graph_index = ...
        self.num_layers = ...
        self.layer_frames = [Frame() for _ in range(self.num_layers)]
        self.flow_frames = [Frame() for _ in range(self.num_layers - 1)]
Some thoughts:
- ndata and edata require some tweaks. For example, to get the feature of node #10, we need to first find out which layer it resides in and then query the corresponding layer frame.
- If NodeFlow inherits from DGLGraph, it gets all of the following APIs:

apply_nodes([func, v, inplace]) | Apply the function on the nodes to update their features.
apply_edges([func, edges, inplace]) | Apply the function on the edges to update their features.
send([edges, message_func]) | Send messages along the given edges.
recv([v, reduce_func, ...]) | Receive and reduce incoming messages and update the features of node(s) v.
send_and_recv(edges[, ...]) | Send messages along edges and let destinations receive them.
pull(v[, message_func, ...]) | Pull messages from the node(s)' predecessors and then update their features.
push(u[, message_func, ...]) | Send messages from the node(s) to their successors and update them.
update_all([message_func, ...]) | Send messages through all edges and update all nodes.
prop_nodes(nodes_generator[, ...]) | Propagate messages using graph traversal by triggering pull() on nodes.
prop_edges(edges_generator[, ...]) | Propagate messages using graph traversal by triggering send_and_recv() on edges.
filter_nodes(predicate[, nodes]) | Return a tensor of node IDs that satisfy the given predicate.
filter_edges(predicate[, edges]) | Return a tensor of edge IDs that satisfy the given predicate.
You can see that some APIs do not make sense in the NodeFlow scenario. Since NodeFlow already represents a computation graph, it does not make sense to enable all of them. I feel the only reasonable APIs are the ones that propagate messages from layer i to layer j (j > i).
Thoughts?
Even the propagation needs a tweak: not all layers have the same msg_func, red_func and node update func. Maybe we can create a DGLBaseGraph that contains the graph structure APIs, and create a new set of message passing APIs for NodeFlow:
apply_nodes(layer_id, func, inplace)
apply_edges(layer_id, func, inplace)
flow_compute(layer_id, msg_func, red_func, update_func)
I feel the same. How about:
apply_layers(layer_id, func, v, inplace)
apply_flows(flow_id, func, edges, inplace)
forward(msg_func, red_func, update_func, range=ALL)
range can be an int, or a slice like 1:3. If it is a slice or ALL, the arguments should be lists.
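A quick sketch of how this proposed forward(..., range=...) signature might be used; nflow, the node update blocks, and the method itself are assumptions from this proposal, not an existing DGL API:

import dgl.function as fn

# single flow: scalar UDF arguments
nflow.forward(fn.copy_src('h', 'm'), fn.sum('m', 'h'), node_update, range=1)

# a slice of flows (1:3): the UDF arguments become lists, one entry per flow
nflow.forward([fn.copy_src('h', 'm')] * 2,
              [fn.sum('m', 'h')] * 2,
              [node_update1, node_update2],
              range=slice(1, 3))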
@zzhang-cn I think layer is better than slice. Based on the discussion, it seems that, except possibly for the input layer, each layer will have separate msg, reduce, apply_node, and apply_flow functions, and is likely to have a different set of parameters. This naming also matches what the authors of the papers tend to use (the i-th GNN layer).
Some other thoughts:
- For GNN training with sampling, we will be very likely to encounter the case of dynamic construction when some prioritized sampling is going to be made based on the latest node features. @jermainewang Clustering within GNN computation can also be a good example for supporting dynamic layer (flow) construction with NodeFlow.
- It might still be good to allow message passing within a layer particularly when the graph structure is static for some consecutive layers.
Not relevant at current stage:
Here is the updater code:
class GCNForwardLayer(gluon.Block):
    def __init__(self, ind, in_feats, out_feats, activation, dropout, **kwargs):
        super(GCNForwardLayer, self).__init__(**kwargs)
        self.ind = ind
        self.node_update = NodeUpdate(out_feats, activation, dropout)

    def forward(self, g, h):
        g.ndata['h'] = h
        g.update_all(fn.copy_src(src='h', out='m'), fn.sum(msg='m', out='h'))
        g.ndata['h'] = g.ndata['h'] * g.ndata['deg_norm']
        agg_h = g.ndata['h']
        g.apply_nodes(self.node_update)
        return agg_h, g.ndata.pop('h')

class GCNUpdate(gluon.Block):
    def __init__(self, in_feats, n_hidden, n_classes, n_layers, activation, dropout, **kwargs):
        super(GCNUpdate, self).__init__(**kwargs)
        self.n_layers = n_layers
        assert n_layers >= 2
        with self.name_scope():
            #self.linear = gluon.nn.Dense(n_hidden, activation)
            self.layers = gluon.nn.Sequential()
            # input layer
            self.layers.add(GCNForwardLayer(0, in_feats, n_hidden, activation, dropout))
            # hidden layers
            for i in range(1, n_layers-1):
                self.layers.add(GCNForwardLayer(i, n_hidden, n_hidden, activation, dropout))
            # output layer
            self.layers.add(GCNForwardLayer(n_layers-1, n_hidden, n_classes, None, dropout))

    def forward(self, g):
        h = g.ndata['in']
        for i, layer in enumerate(self.layers):
            agg_h, h = layer(g, h)
            g.ndata['h_%d' % (i + 1)] = h
            g.ndata['agg_h_%d' % i] = agg_h
        return h
It's unlikely that the updater code can be written the same way as the trainer code, at least not for now. Writing data to a layer of a NodeFlow requires index_copy, which generates a lot of unnecessary memory copies.

I don't understand why the updater's logic isn't exactly the same as the vanilla GCN. Because NodeFlow will change across epochs (each epoch presumably builds a different DAG via sampling), the conservative way would be to just run GCN over the full graph. Or are you assuming the updater is running over the NodeFlow of a new epoch?
Comments for the trainer code:
class GCNLayer(gluon.Block):
    def __init__(self, ind, in_feats, out_feats, activation, dropout, **kwargs):
        super(GCNLayer, self).__init__(**kwargs)
        self.ind = ind
        self.in_feats = in_feats
        self.node_update = NodeUpdate(out_feats, activation, dropout)

    def forward(self, subg):
        if self.ind == 0:
            subg.layers[0].data['h'] = subg.layers[0].data['agg_h_0']
            assert subg.layers[0].data['h'].shape[1] == self.in_feats
            subg.apply_nodes(self.node_update, v=subg.layer_nid(0))
        else:
            subg.layers[self.ind - 1].data['h'] = h
            # control variate
            subg.pull(subg.layer_nid(self.ind), partial(gcn_msg, ind=self.ind),
                      partial(gcn_reduce, ind=self.ind), self.node_update)
        return subg.layers[self.ind].data['h']
I see what you mean. How about this:
- The first layer is always the input layer.
- Since there is no incoming link to the input layer, there is no need to define computation for it. The node apply in self.ind == 0 is actually not needed. Even if a node apply is required, this can be done outside of NodeFlow.

With this in mind, the code will look like this (also fixed some syntax):
class GCNLayer(gluon.Block):
    def __init__(self, ind, in_feats, out_feats, activation, dropout, **kwargs):
        super(GCNLayer, self).__init__(**kwargs)
        self.ind = ind
        self.in_feats = in_feats
        self.node_update = NodeUpdate(out_feats, activation, dropout)

    def forward(self, nflow, h):
        nflow.layer_data[self.ind-1]['h'] = h
        # control variate
        nflow.pull(nflow.layers(self.ind), fn.copy_src('h', 'm'), fn.sum('m', 'h'), self.node_update)
        return nflow.layer_data[self.ind]['h']
And the GCN block:
class GCN(gluon.Block):
    def __init__(self, in_feats, n_hidden, n_classes, n_layers, activation, dropout, **kwargs):
        super(GCN, self).__init__(**kwargs)
        self.n_layers = n_layers
        with self.name_scope():
            self.layers = gluon.nn.Sequential()
            # hidden layers
            for i in range(0, n_layers-1):
                self.layers.add(GCNLayer(i, n_hidden, n_hidden, activation, dropout))
            # output layer
            self.layers.add(GCNLayer(n_layers-1, n_hidden, n_classes, None, dropout))

    def forward(self, nflow, h):
        for layer in self.layers:
            h = layer(nflow, h)
        return h
Since such a pattern is very common, we could provide the compute interface as follows. Note that the GCNLayer class is fused into the GCN class.

class GCN(gluon.Block):
    def __init__(self, in_feats, n_hidden, n_classes, n_layers, activation, dropout, **kwargs):
        super(GCN, self).__init__(**kwargs)
        with self.name_scope():
            self.mfunc = []
            self.rfunc = []
            self.afunc = gluon.nn.Sequential()
            # hidden layers
            for i in range(0, n_layers-1):
                self.mfunc.append(fn.copy_src('h', 'm'))
                self.rfunc.append(fn.sum('m', 'h'))
                self.afunc.add(NodeUpdate(n_hidden, activation, dropout))
            # output layer
            self.mfunc.append(fn.copy_src('h', 'm'))
            self.rfunc.append(fn.sum('m', 'h'))
            self.afunc.add(NodeUpdate(n_classes, None, dropout))

    def forward(self, nflow, h):
        h = nflow.compute(h, self.mfunc, self.rfunc, self.afunc)
        return h
@mufeili @VoVAllen and I got together and looked through the codes. We assume the one we quoted is the most up to date; we also took a look at @BarclayII 's implementation of pinSage.
Couple of thoughts:
- How will mini-batch training get incorporated? (say, using PinSage as an example)
- Is the control variate correct? It looks like only the input layer uses it (see the use of agg_h_0).

class GCNLayer(gluon.Block):
    def __init__(self, ind, in_feats, out_feats, activation, dropout, **kwargs):
        super(GCNLayer, self).__init__(**kwargs)
        self.ind = ind
        self.in_feats = in_feats
        self.node_update = NodeUpdate(out_feats, activation, dropout)

    def forward(self, subg):
        if self.ind == 0:
            subg.layers[1].data['h'] = subg.layers[1].data['agg_h_0']
            assert subg.layers[1].data['h'].shape[1] == self.in_feats
            subg.apply_nodes(self.node_update, v=subg.layer_nid(1))
        else:
            subg.layers[self.ind].input_data['h'] = h
            # control variate
            subg.compute(self.ind + 1, partial(gcn_msg, ind=self.ind),
                         partial(gcn_reduce, ind=self.ind), self.node_update)
        return subg.layers[self.ind + 1].data['h']
Here are some additional thoughts:
From @mufeili
Some other thoughts:
- For GNN training with sampling, we will be very likely to encounter the case of dynamic construction when some prioritized sampling is going to be made based on the latest node features. @jermainewang Clustering within GNN computation can also be a good example for supporting dynamic layer (flow) construction with NodeFlow.
Can the append API handle this?
- It might still be good to allow message passing within a layer particularly when the graph structure is static for some consecutive layers.
Don't understand.
From @zzhang-cn
I don't understand why the updater's logic isn't exactly the same as the vanilla GCN. Because NodeFlow will change across epochs (each epoch presumably builds a different DAG via sampling), the conservative way would be to just run GCN over the full graph. Or are you assuming the updater is running over the NodeFlow of a new epoch?
They should be the same, but I guess @zheng-da also considered the efficiency of using NodeFlow for full-graph GCN. The issue is the memory consumption of the graph structure. Suppose we have a graph G(V, E). Using DGLGraph, we only need to store the graph once, so the memory consumption is O(|V| + |E|). However, using NodeFlow, since the graph structure is unfolded for L layers, the total memory consumption is O(L * (|V| + |E|)).
Note that the above is only for the graph structure (e.g. CSR or adjacency list). The memory consumption of node/edge features is always O(L * (|V| + |E|) * d) because they all need to be saved.
Couple of thoughts:
- How will mini-batch training get incorporated? (say, using PinSage as an example)
- Is the control variate correct? It looks like only the input layer uses it (see the use of agg_h_0).
I think we haven't got the code including the training loop. @zheng-da could you put the code in a gist like Quan?
Another question from @zzhang-cn:
The neighbors of a node in the DAG can be in 2 mutually exclusive states: 1) ignored (out-of-sample), or 2) instantiated. In the 2nd case, the value can be: a) an input if in the input layer, b) an output from the previous layer, or c) stale historical data. I am under the impression that when a neighbor is ignored, the aggregation function needs to rescale; where do we handle it? Also, if the value is historical, shall we expose it to the message function?
I drew a figure below. See whether this makes sense to you guys:
Several notes:
- copy_from_parent API to fetch the feature data from the flat graph.
- NodeFlow with the layer id.

over the full graph. Or are you assuming the updater is running over the NodeFlow of a new epoch?
The updater runs on the full graph. Its purpose is to prepare the history data. In my implementation, in addition to computing the history data on nodes, I also compute the aggregation of the history data, so that we don't need all neighbors of a sampled node. In this way, we can use a normal neighbor sampler directly.
As shown below, the white nodes are sampled. We store 'agg_h_%d' on the white nodes, and we don't need the grey nodes in NodeFlow for training.
Does that mean, in my figure, there is no blue node on layer 0?
Here is the full implementation of GCN + CV + updater with the NodeFlow API. https://github.com/zheng-da/dgl-1/blob/fix_sampler1/examples/mxnet/gcn/gcn_cv_updater.py
@jermainewang It's not necessary. But it's really up to the user. We should give users the flexibility of implementing the algorithm in the way they like, but we can provide a recommended way of implementing it.
To implement control variate sampling, we have two choices. The first is to sample nodes and all of their neighbors, as shown in the figure above. It's doable to sample such a graph if we write a specific sampler for this algorithm. However, it's not very easy to implement. A potential approach is as follows:
class GCNLayer(gluon.Block):
    def __init__(self, ind, in_feats, out_feats, activation, dropout, bias=True, **kwargs):
        super(GCNLayer, self).__init__(**kwargs)
        self.ind = ind
        self.in_feats = in_feats
        self.node_update = NodeUpdate(out_feats, activation, dropout)

    def forward(self, subg, h):
        # this is to aggregate the history data of all neighbors.
        subg.flow_compute(self.ind, fn.copy_src(src='h_%d' % self.ind, out='m'),
                          fn.sum(msg='m', out='agg_h'),
                          lambda node: {'agg_h_%d' % self.ind: node.data['agg_h'] * node.data['deg_norm']})
        # TODO: how to get the active edges.
        subg.flow_send_and_recv(self.ind, gcn_cv_msg, gcn_cv_reduce, self.node_update, active_edges)
        return subg.layers[self.ind + 1].data['h']
The problem with this approach is that it's hard to distinguish the edges that connect to the sampled nodes (the white nodes) from the ones that connect to nodes with history data (the grey nodes). The sampler needs to return extra information, and our NodeFlow object should contain this information.
The full demo code is shown here: https://github.com/zheng-da/dgl-1/blob/fix_sampler1/examples/mxnet/gcn/gcn_cv_updater1.py
The other approach is to only sample the nodes on which we perform computation and to aggregate the history data in advance, as shown in the figure above. In my implementation, I use the updater to aggregate the history data and store it in 'agg_h_%d'. We can also use other components to do so, e.g., the kvstore. If we do so, NodeFlow no longer needs the nodes that only carry history data (the grey nodes). We can use a normal neighbor sampler to generate the NodeFlow. Then we can implement the algorithm as follows:
class GCNLayer(gluon.Block):
    def __init__(self, ind, in_feats, out_feats, activation, dropout, bias=True, **kwargs):
        super(GCNLayer, self).__init__(**kwargs)
        self.ind = ind
        self.in_feats = in_feats
        self.node_update = NodeUpdate(out_feats, activation, dropout)

    def forward(self, subg, h):
        if self.ind == 0:
            subg.layers[1].data['h'] = subg.layers[1].data['agg_h_0']
            assert subg.layers[1].data['h'].shape[1] == self.in_feats
            subg.apply_nodes(self.node_update)
        else:
            subg.layers[self.ind].data['h'] = h
            # control variate
            subg.flow_compute(self.ind, partial(gcn_cv_msg, ind=self.ind),
                              partial(gcn_cv_reduce, ind=self.ind), self.node_update)
        return subg.layers[self.ind + 1].data['h']
This significantly simplifies the implementation of both the user code and the DGL code. I think we should use the second approach to implement control variate sampling. I don't know if it's necessary to support a complex NodeFlow whose nodes and edges have different types.
The full demo code is here: https://github.com/zheng-da/dgl-1/blob/fix_sampler1/examples/mxnet/gcn/gcn_cv_updater.py
To get a better sense of how the aggregated data is used, please check the message passing functions below:
def gcn_cv_msg(edge, ind):
    msg = edge.src['h'] - edge.src['h_%d' % ind]
    return {'m': msg}

def gcn_cv_reduce(node, ind):
    accum = mx.nd.sum(node.mailbox['m'], 1) * node.data['norm'] + node.data['agg_h_%d' % ind]
    return {'h': accum}

class NodeUpdate(gluon.Block):
    def __init__(self, out_feats, activation=None, dropout=0, **kwargs):
        super(NodeUpdate, self).__init__(**kwargs)
        self.linear = gluon.nn.Dense(out_feats, activation=activation)
        self.dropout = dropout

    def forward(self, node):
        accum = node.data['h']
        if self.dropout:
            accum = mx.nd.Dropout(accum, p=self.dropout)
        accum = self.linear(accum)
        return {'h': accum}
Here are the NodeFlow semantics I would like to propose. Each layer has its own ID space. Nodes from the flat graph may appear in multiple layers, but can only appear once within the same layer. For simplicity, all nodes in a layer should have the same features/embeddings and perform the same computation.
I think these semantics should be sufficient for neighbor sampling, control variate sampling, layer-wise sampling and PinSage.
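To make the per-layer ID space concrete, a toy sketch in plain Python (the mapping list and helper are illustrative assumptions, not DGL code):

# Per-layer ID spaces in a 2-layer NodeFlow:
# layer_parent_nid[i][k] is the parent-graph node ID of local node k in layer i.
layer_parent_nid = [
    [5, 7, 9],      # layer 0: local IDs 0..2 map to parent nodes 5, 7, 9
    [2, 5, 7, 11],  # layer 1: nodes 5 and 7 appear again (allowed across layers),
                    # but no parent node appears twice within the same layer
]

def to_parent(layer, local_id):
    # map a (layer, local id) pair back to the parent graph's node ID
    return layer_parent_nid[layer][local_id]

assert to_parent(0, 1) == 7
assert to_parent(1, 2) == 7  # same parent node, different layer-local ID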
Here are the NodeFlow semantics I would like to propose.
What's the difference with the image above? (the blue/black nodes one)
The biggest difference is that all nodes in a layer should have the same features/embeddings and perform the same computation. That is, we don't need to distinguish blue/black nodes, which otherwise makes it difficult to define computation on the nodes and edges.
From @mufeili
Some other thoughts:
- For GNN training with sampling, we will be very likely to encounter the case of dynamic construction when some prioritized sampling is going to be made based on the latest node features. @jermainewang Clustering within GNN computation can also be a good example for supporting dynamic layer (flow) construction with NodeFlow.
Can the append API handle this?
I think so. A hard clustering algorithm can be used to cluster the nodes based on connectivity and node embeddings, and then a new flow and layer can be added indicating the mapping between the original graph and the pooled/coarsened/reduced graph.
- It might still be good to allow message passing within a layer particularly when the graph structure is static for some consecutive layers.
Don't understand.
As you mentioned, since the graph structure is unfolded, we end up with O(L * (|V| + |E|)) memory consumption. There can be cases where several consecutive layers in a NodeFlow share the same graph structure and message passing (e.g. a 2-layer full GCN within a NodeFlow). In such cases, it might reduce memory cost to still allow updates within a layer before propagating through the next flow.
We use this issue to discuss a proposal (by @zheng-da @BarclayII) of a new API: NodeFlow.

Motivation
In current DGL, DGLGraph is the core data structure to represent a graph. It is very convenient for representing a flat data graph. For example, we use DGLGraph to represent the citation graph from CoraDataset, the sentence parse tree from SSTDataset, etc.

However, when it comes to defining a multi-layer graph neural network, a flat graph is somewhat not ideal. For example, in the current API, defining a multi-layer GCN looks like the following:
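A minimal sketch of the flat-graph pattern being described here, assuming a simple two-layer GCN; the class and variable names below are illustrative, not the exact snippet from the original proposal:

import torch.nn as nn
import dgl.function as fn

class GCNLayer(nn.Module):
    def __init__(self, in_feats, out_feats):
        super(GCNLayer, self).__init__()
        self.fc = nn.Linear(in_feats, out_feats)

    def forward(self, g, h):  # g is a flat DGLGraph
        g.ndata['h'] = h
        g.update_all(fn.copy_src('h', 'm'), fn.sum('m', 'h'))
        return self.fc(g.ndata.pop('h'))

# both layers pass messages on the *same* graph structure g
layer0 = GCNLayer(16, 16)
layer1 = GCNLayer(16, 16)

def gcn_forward(g, features):
    h = layer0(g, features)
    h = layer1(g, h)
    return h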
Note that both layer0 and layer1 share the same graph structure g. However, what if the two layers have different graph structures? For example, in graph sampling (control variate, neighbor sampling, etc.), messages are passed on a different sampled subgraph for each layer.

Such a hierarchical structure makes it hard for a flat DGLGraph, designed for representing graph data, to cleanly express the logic. In such a case, the code needs to be adjusted. There are at least two awkward points in the above implementation:
Proposal: NodeFlow

NodeFlow is our proposed solution. It is a graph structure designed for hierarchical computation on graphs. Compared with DGLGraph:

Below is a picture of how a 2-layer HeteroGCN is expressed using NodeFlow:
Notes: the NodeFlow class still inherits DGLGraph to share its APIs, although the computation usually follows topological traversal.

Here is a demo snippet of the HeteroGCN using NodeFlow:
APIs

First of all, none of the API names are finalized. Feel free to propose better names. For example, I personally found NodePipeline might be more intuitive than NodeFlow. The critical part is the concept, which we should all agree on.

There is an ongoing effort by @zheng-da to implement this API (see #361). Right now, the API is buried in many other changes. @zheng-da could you separate the important class definition out of the PR? Also, the current PR name is not explicit about this change.

The class definition is as follows (I incorporated some of my own thoughts in the API, so it might be different from the one in #361):
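A hedged skeleton of what the class definition might look like, pieced together from the names used in this thread (layers, flows, layer_data/flow_data, apply_layers, apply_flows, flow_compute); this is an assumption, not the actual definition in #361:

class NodeFlow(object):
    """Sketch: a layered graph for hierarchical, layer-by-layer computation.

    Layers 0..L are node layers; flow i connects layer i to layer i+1.
    In the real proposal this might inherit from a DGLBaseGraph that holds the
    graph structure APIs, and use DGL's Frame instead of plain dicts.
    """
    def __init__(self, graph_index, num_layers):
        self.graph_index = graph_index
        self.num_layers = num_layers
        # one feature frame per layer / per flow (see the Storage section)
        self._layer_frames = [dict() for _ in range(num_layers)]
        self._flow_frames = [dict() for _ in range(num_layers - 1)]

    def layers(self, layer_id):
        """Return the node IDs in the given layer."""
        raise NotImplementedError

    def flows(self, flow_id):
        """Return the edge IDs in the given flow."""
        raise NotImplementedError

    @property
    def layer_data(self):
        """layer_data[i] is the feature dictionary of layer i."""
        return self._layer_frames

    @property
    def flow_data(self):
        """flow_data[i] is the feature dictionary of flow i."""
        return self._flow_frames

    def apply_layers(self, layer_id, func, v=None, inplace=False):
        """Apply a node UDF on (a subset of) the nodes in a layer."""
        raise NotImplementedError

    def apply_flows(self, flow_id, func, edges=None, inplace=False):
        """Apply an edge UDF on (a subset of) the edges in a flow."""
        raise NotImplementedError

    def flow_compute(self, flow_id, mfunc, rfunc, afunc):
        """Propagate messages along one flow: layer flow_id -> layer flow_id + 1."""
        raise NotImplementedError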
Utilities

To create a NodeFlow, we could suggest users use some provided utility functions. A more realistic use case is via the graph sampling APIs (see the sketch below):
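For illustration, a sketch of how a sampler could produce NodeFlows; the NeighborSampler name, its arguments, and the copy_from_parent usage below are assumptions for this sketch rather than a finalized API:

# Hypothetical sampler loop producing one NodeFlow per mini-batch.
# g is assumed to be the flat parent DGLGraph and model a NodeFlow-aware module.
from dgl.contrib.sampling import NeighborSampler

sampler = NeighborSampler(g,
                          batch_size=1000,   # seed nodes per NodeFlow
                          expand_factor=10,  # sampled neighbors per node
                          num_hops=2)        # number of layers in each NodeFlow

for nflow in sampler:
    nflow.copy_from_parent()   # pull node/edge features from the flat graph
    h = model(nflow)           # run the layer-by-layer computation
    # ... compute loss, backward, update parameters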
Storage

To achieve the best efficiency, the NodeFrames of the layers should be separated. This is good for two reasons:

However, this requires some effort in implementation if we want to make NodeFlow completely compatible with DGLGraph's APIs. We could also disable APIs such as update_all for NodeFlow, as they are somewhat meaningless. More ideas on this are appreciated.
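As a toy illustration of the storage idea (the frame layout and the disabled method are assumptions for this sketch, not DGL's actual implementation):

class NodeFlowStorage:
    """Toy sketch: one feature frame (dict of tensors) per layer, so writing a
    layer's features only rebinds references, and whole-graph APIs are disabled."""
    def __init__(self, num_layers):
        self.layer_frames = [{} for _ in range(num_layers)]
        self.flow_frames = [{} for _ in range(num_layers - 1)]

    def set_layer_data(self, layer_id, name, tensor):
        # no index_copy into one big flat tensor; just store the reference
        self.layer_frames[layer_id][name] = tensor

    def update_all(self, *args, **kwargs):
        # meaningless for a layered computation graph; force layer/flow-wise APIs
        raise NotImplementedError("use flow-wise propagation instead of update_all")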