Closed by jermainewang 6 years ago.
Hmm, interesting question. From the perspective of networkx, I would expect `DGLGraph` to be mutable. In fact, I would probably write the example above like this:
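```python
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph  # early DGL API used throughout this thread

# GCN module
class GCN(nn.Module):
    def __init__(self):
        super().__init__()  # required before assigning submodules
        self.updmod = NodeModule()  # node update module, defined elsewhere

    def forward(self, nx_graph, features):
        # A new DGLGraph is built on every forward call.
        g = DGLGraph(nx_graph)
        g.set_n_repr(features)
        g.update_all(message_func='from_src',
                     reduce_func='sum',
                     update_func=self.updmod,
                     batchable=True)
        return g.get_n_repr()

# main loop (load_data and MAX_EPOCH are defined elsewhere)
nx_graph, features, labels = load_data('cora')
gcn = GCN()
for epoch in range(MAX_EPOCH):
    logits = gcn(nx_graph, features)
    loss = F.nll_loss(logits, labels)
    loss.backward()
    # ...
```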
EDIT: if `nx_graph` is fixed, we can of course store it (and its `DGLGraph` counterpart) as a module attribute during instantiation.
There are two concerns here: (1) Currently, we maintain a cached graph object that converts the networkx graph into a more efficient C++ graph structure. With the code above, the cached graph has to be reconstructed in every iteration. (2) It is still error-prone, since users remain free to write code that mutates the graph (and possibly use it incorrectly).
(1) Based on my understanding of the implementation in `dgl/cached_graph.py`, the cached graph is essentially a mapping from `(src_id, dst_id)` to `edge_id`. In that case, I guess we can keep the cached graph itself intact and only reinitialize the underlying tensor storage (which we can't avoid anyway). Correct me if I'm wrong though.
(2) `networkx.DiGraph` is itself mutable, and since `DGLGraph` is a subclass of `networkx.DiGraph`, I think it is more natural for it to be mutable as well. Users should be aware of this (hence my code above).
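To make point (1) concrete, here is a minimal sketch, assuming the cached graph really is just a `(src_id, dst_id) -> edge_id` mapping. `CachedStructure` and `reset_storage` are hypothetical names, not DGL's actual code: the mapping is built once, and each iteration only swaps the feature storage.

```python
class CachedStructure:
    """Hypothetical stand-in for a cached graph: structure built once, storage swapped."""

    def __init__(self, edges):
        # (src_id, dst_id) -> edge_id, assigned in insertion order.
        self.edge_id = {(u, v): eid for eid, (u, v) in enumerate(edges)}
        self.num_nodes = 1 + max(max(u, v) for u, v in edges)
        self.node_storage = {}

    def reset_storage(self, name, values):
        # Only the feature storage is replaced; the edge_id mapping is untouched.
        if len(values) != self.num_nodes:
            raise ValueError("expected one value per node")
        self.node_storage[name] = values


cg = CachedStructure([(0, 1), (1, 2), (2, 0)])
mapping = cg.edge_id
for epoch in range(3):
    cg.reset_storage("h", [float(epoch)] * cg.num_nodes)
```

After the loop, `cg.edge_id` is still the very same dict object, so the structural work is paid only once.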
The cached graph is an actual graph. It may be stored as CSR, CSC, or COO depending on the configuration. Creating the cached graph is somewhat costly because we need to convert the networkx graph into it (networkx stores graphs in Python dictionaries). Currently, I don't know whether the conversion can be moved to the C++ side, but since it is O(V+E), we should assume it is costly to repeat in every iteration.
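For reference, this is roughly what an O(V+E) conversion to CSR looks like, assuming a plain dict-of-dicts shaped like networkx's adjacency structure. It is a pure-Python sketch (`dict_to_csr` is a hypothetical helper), not the cached graph's real conversion code, and it covers CSR only.

```python
def dict_to_csr(adj):
    """Convert {src: {dst: attrs}} (networkx-style) into CSR arrays in O(V + E)."""
    nodes = list(adj)                         # keep the dict's node order
    index = {n: i for i, n in enumerate(nodes)}
    indptr, indices = [0], []
    for n in nodes:                           # each node visited once
        for dst in adj[n]:                    # each edge visited once
            indices.append(index[dst])
        indptr.append(len(indices))
    return indptr, indices


adj = {"a": {"b": {}, "c": {}}, "b": {"c": {}}, "c": {}}
indptr, indices = dict_to_csr(adj)
# indptr == [0, 2, 3, 3], indices == [1, 2, 2]
```

Every node and edge is touched once, which is why rebuilding this inside `forward` adds a full pass over the graph per iteration.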
In your code, since the `DGLGraph` is created and destroyed within the forward function, there is no chance for us to keep the cached graph. One way to solve this (as you've suggested) is to create the `DGLGraph` in the `__init__` function as follows:
```python
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph  # early DGL API used throughout this thread

# GCN module
class GCN(nn.Module):
    def __init__(self, nx_graph):
        super().__init__()
        # The DGLGraph (and its cached graph) is built once and reused.
        self.g = DGLGraph(nx_graph)
        self.updmod = NodeModule()  # node update module, defined elsewhere

    def forward(self, features):
        self.g.set_n_repr(features)
        self.g.update_all(message_func='from_src',
                          reduce_func='sum',
                          update_func=self.updmod,
                          batchable=True)
        return self.g.get_n_repr()

# main loop (load_data and MAX_EPOCH are defined elsewhere)
nx_graph, features, labels = load_data('cora')
gcn = GCN(nx_graph)
for epoch in range(MAX_EPOCH):
    # <------
    # At the beginning of the second iteration, the `logits` tensor is still alive
    # because the GCN object holds a reference to it.
    logits = gcn(features)
    loss = F.nll_loss(logits, labels)
    loss.backward()
    # ...
```
However, there is another issue with this implementation. The updated repr is stored in the `DGLGraph` object, which is in turn stored in the GCN object. Because the graph holds a reference to the repr tensor, that tensor cannot be freed until it is replaced in the next iteration (as marked in the comment).
I don't entirely follow. For the GCN example, this bug is something the user should be aware of: each iteration is supposed to create a new layer of representation, but the code as written performs an in-place mutation.
The figure is in fact an example of immutable representation. In the above case, convolution is an operator that returns a new representation: `y = conv(x)`. If you think of graph conv as a conv that depends on the graph structure, then the natural style would be `y = g.update_all(x)`. By contrast, our style is:

```python
g.set_n_repr(x)
g.update_all()
y = g.get_n_repr()
```

The difference is that `g` is not stateless like `conv`: it holds references to the representation that is being updated.
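A toy sketch of the two styles, assuming a three-node graph given as in-neighbour lists. `conv` and `StatefulGraph` are hypothetical stand-ins whose method names only loosely mirror the DGL calls above. The last two lines reproduce the bug from the original question: calling `update_all` again without resetting the input convolves the already-updated repr.

```python
# Toy graph: node -> list of in-neighbours (an assumption for illustration).
IN_NEIGHBORS = {0: [1, 2], 1: [0], 2: [0, 1]}


def conv(x):
    """Stateless style: y = conv(x) sums in-neighbour features and leaves x alone."""
    return [sum(x[u] for u in IN_NEIGHBORS[v]) for v in range(len(x))]


class StatefulGraph:
    """Stateful style, loosely mirroring set_n_repr / update_all / get_n_repr."""

    def __init__(self):
        self._repr = None

    def set_n_repr(self, x):
        self._repr = x

    def update_all(self):
        # The graph overwrites the representation it holds.
        self._repr = conv(self._repr)

    def get_n_repr(self):
        return self._repr


x = [1.0, 2.0, 3.0]
y = conv(x)                       # [5.0, 1.0, 3.0]; x is unchanged

g = StatefulGraph()
g.set_n_repr(x)
g.update_all()
y_first = g.get_n_repr()          # same values as y

g.update_all()                    # forgot to call set_n_repr(x) again
y_second = g.get_n_repr()         # conv applied twice: [4.0, 5.0, 6.0]
```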
I think I'm lost... Why do we want to destroy the repr tensor or the `DGLGraph` object before the next iteration begins? We never know how the tensors are going to be used anyway, and PyTorch keeps a tensor alive as long as any reference points to it.
In the case above, where the graph does not change, we can always reuse the same `DGLGraph` object and only change the underlying tensor storage during `set_n_repr()`, right? And if the graph does change (e.g. the algorithm starts from the same graph but gradually adds different things in each run), I guess the corresponding `DGLGraph` object has to be reconstructed no matter what, correct?
@BarclayII This is at the beginning of a training epoch; set_n_repr is the same as loading the data.
@zzhang-cn I understand that. My question is regarding the fifth comment about the cached graph from @jermainewang
@BarclayII The tensor memory is reclaimed once its reference count drops to zero. However, since the graph object always holds a reference, the tensors cannot be reclaimed even when the variable name appears to be reused. Consider two examples here, one in plain PyTorch and one in DGL:

```python
X = ...  # some initial input
for epoch in range(MAX_EPOCH):
    logits = my_model(X)
    logits = 2 * logits  # After this statement, the old `logits` tensor is reclaimed.
    loss = my_loss(logits)
    loss.backward()
```

```python
X = ...  # some initial input
my_model = MyModel(nx_graph)  # The DGLGraph is created inside the model object.
for epoch in range(MAX_EPOCH):
    logits = my_model(X)
    logits = 2 * logits  # After this statement, the old `logits` tensor is still
                         # alive because the model holds a reference to it.
    loss = my_loss(logits)
    loss.backward()
```
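The difference can be observed without PyTorch at all. This sketch relies on CPython's reference counting and uses `weakref.ref` to check whether the old object is still alive; `Tensor`, `Model`, and `plain_model` are hypothetical stand-ins, with `Model` storing its output the way the GCN above stores it on the graph.

```python
import weakref


class Tensor:
    """Stand-in for a tensor (plain objects support weak references)."""

    def __init__(self, data):
        self.data = data


class Model:
    """Hypothetical model that, like the GCN above, stores its output on a graph."""

    def __init__(self):
        self.graph_repr = None

    def __call__(self, x):
        self.graph_repr = Tensor([2 * v for v in x.data])
        return self.graph_repr


def plain_model(x):
    return Tensor([2 * v for v in x.data])


X = Tensor([1, 2, 3])

logits = plain_model(X)
old_plain = weakref.ref(logits)
logits = Tensor([2 * v for v in logits.data])   # rebinding drops the last reference
plain_freed = old_plain() is None               # True in CPython

model = Model()
logits = model(X)
old_graph = weakref.ref(logits)
logits = Tensor([2 * v for v in logits.data])   # the model still references the old one
graph_freed = old_graph() is None               # False: kept alive by model.graph_repr
```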
First thing, I think in the first case the old `logits` tensor can't be reclaimed, since PyTorch needs the old tensor for backpropagation.
Secondly, I personally don't view the tensors not being reclaimed immediately after the forward phase as a problem. After all, we do need that much memory to update the graph, and maybe we need other fields in the graph later (e.g. to visualize graph attention).

> First thing, I think in the first case the old logits tensor can't be reclaimed, since PyTorch needs the old tensor for backpropagation.

No, it is not required. The gradient with respect to the old `logits` tensor is simply 2 times the gradient with respect to the new one, so the old tensor's values are not needed.
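A one-line chain-rule check of that claim, for this scaling step only (other operations, such as an elementwise product of two tensors, do need their inputs). With $y = 2x$ elementwise and loss $L$:

$$\frac{\partial L}{\partial x} = \frac{\partial L}{\partial y}\,\frac{\partial y}{\partial x} = 2\,\frac{\partial L}{\partial y}$$

The right-hand side involves only the incoming gradient and the constant 2, never the values of $x$, so mathematically nothing about the old `logits` tensor has to be kept for this step.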
How about having another PyTorch module instead of `2 * logits`? I'm not sure it's worthwhile to optimize for these kinds of cases.
It won't help. The only way is to explicitly delete or replace the column in the graph. Whether that's worthwhile depends on the application; I just want to point out that this might be a place to squeeze out some memory.
After a face-to-face meeting last weekend, we decided to keep the mutable semantics and provide `pop_n_repr` and `pop_e_repr` APIs to easily remove columns from the graph. This change has been merged in #32.
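A minimal sketch of the "pop a column" idea, assuming node data lives in a frame that maps column names to data. `FrameGraph` and its methods are hypothetical and do not reproduce the signatures merged in #32; the point is only that popping both returns the column and drops the graph's reference to it.

```python
class FrameGraph:
    """Toy graph with a node frame: column name -> column data (hypothetical)."""

    def __init__(self):
        self._node_frame = {}

    def set_n_repr(self, name, column):
        self._node_frame[name] = column

    def pop_n_repr(self, name):
        # Remove the column and return it; the graph no longer references it,
        # so its memory can be reclaimed once the caller drops it as well.
        return self._node_frame.pop(name)


g = FrameGraph()
g.set_n_repr("h", [0.5, 1.5])
h = g.pop_n_repr("h")
```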
Here is a question we ran into when writing models with DGL. Let's first look at the following GCN example:
Can you spot the bug in this code?
The bug is that we need to reset the input features to the graph at the beginning of each iteration:
This is a very subtle mistake, but we feel that it is worth our attention. The reason for such a mistake is that we are very used to what autograd DNN frameworks provide us: immutability. In PyTorch, standard tensor operations (apart from the explicit in-place variants such as `add_`) return new tensors instead of modifying their inputs. As a result, the following PyTorch code is totally fine:
By contrast, because DGL is derived from networkx, our APIs are mutable. For example, `update_all` does not return a new graph; it changes the node representations internally. Because of this, even if the user writes `gg = gcn(g)`, `gg` and `g` point to the same `DGLGraph`, and the node reprs have already been updated after `update_all`.
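The aliasing can be reproduced with a toy mutable graph. `ToyGraph` and `gcn` here are hypothetical stand-ins for the DGL objects, assuming only that `update_all` changes the graph's own node reprs and that `gcn` returns the graph it was given.

```python
class ToyGraph:
    """Hypothetical mutable graph: update_all changes node reprs in place."""

    def __init__(self, reprs):
        self.n_repr = list(reprs)

    def update_all(self):
        # Mutates the graph's own state; no new graph is created.
        self.n_repr = [2 * v for v in self.n_repr]


def gcn(g):
    g.update_all()
    return g  # returns the very object it received


g = ToyGraph([1, 2, 3])
gg = gcn(g)
# gg is g, and g.n_repr is now [2, 4, 6]: the "input" graph changed too.
```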
This is an inherent conflict between networkx and PyTorch/TF/MX. I want to hear the opinions of model developers (@GaiYu0 @BarclayII @ivanbrugere @zzhang-cn). What do you think? Do you find this bug very subtle, or actually a cunning pitfall? Do you like the current mutable graph, or do you prefer immutable objects?
Here I also want to share more about what it would take to support an immutable graph object. What are the challenges and solutions? To make `DGLGraph` immutable, we need to handle two things:
- For transformations on node/edge features (e.g. `sendto`, `recv`, `update_all`, ...), we can return a new graph containing different data, as follows. In the following example, note that the `update_all` function returns a new graph g2. Initially the graph (i.e., g1) has three node features a1, a2, a3. The node update function reads all node features but only updates the a1 attribute. As a result, the node storage of the new g2 (called a node frame) reuses the other two feature columns and uses the newly generated a1 column. Such a change to the system should have little overhead.
- For transformations on graph structures, this is a little tricky. Users who are used to networkx actually expect mutable behavior, as follows:
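For the first case (feature transformations), the column reuse can be sketched with a plain dict standing in for the node frame. `update_all_immutable` is a hypothetical helper that models only the frame, not the graph: it builds a new frame that shares the untouched columns with the old one.

```python
def update_all_immutable(frame, update_fn):
    """Hypothetical immutable update: returns a new frame and leaves `frame` alone.

    `update_fn` may read every column but returns only the columns it changed.
    """
    new_columns = update_fn(frame)
    # Shallow merge: unchanged columns are shared with the old frame, not copied.
    return {**frame, **new_columns}


g1 = {"a1": [1, 2], "a2": [3, 4], "a3": [5, 6]}
g2 = update_all_immutable(
    g1,
    lambda f: {"a1": [x + y + z for x, y, z in zip(f["a1"], f["a2"], f["a3"])]},
)
# g2["a1"] == [9, 12]; g2["a2"] is g1["a2"]; g1 is unchanged.
```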
We could make it immutable by using copy-on-write. The question is: should we?
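A minimal copy-on-write sketch for the structural case, assuming edges are kept in a tuple that versions of the graph can share safely. `COWGraph` is hypothetical, and a real implementation would avoid copying the whole edge list on every write; this only shows the semantics: the old graph is never modified, and unmodified parts stay shared.

```python
class COWGraph:
    """Hypothetical immutable graph: structure is shared until someone writes to it."""

    def __init__(self, edges=(), features=None):
        self.edges = edges  # a tuple, shared between graph versions
        # Shared as well; by convention never mutated in place.
        self.features = features if features is not None else {}

    def add_edge(self, u, v):
        # A structural "write" copies the edge tuple; self is left untouched.
        return COWGraph(self.edges + ((u, v),), self.features)


g1 = COWGraph(((0, 1),))
g2 = g1.add_edge(1, 2)
# g1.edges == ((0, 1),); g2.edges == ((0, 1), (1, 2)); features still shared.
```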