dmlc / dgl

Python package built to ease deep learning on graph, on top of existing DL frameworks.
http://dgl.ai
Apache License 2.0

To mutable or not to mutable? #33

jermainewang closed this issue 6 years ago

jermainewang commented 6 years ago

Here is a question we ran into while writing models with DGL. Let's first look at the following GCN example:

import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph
# Node update module
class NodeModule(nn.Module):
  def __init__(self):
    super().__init__()
    self.linear = nn.Linear(in_features=100, out_features=10)
  def forward(self, node, accum):
    return F.relu(self.linear(accum))
# GCN module
class GCN(nn.Module):
  def __init__(self):
    super().__init__()
    self.updmod = NodeModule()
  def forward(self, g):
    g.update_all(message_func='from_src',
                 reduce_func='sum',
                 update_func=self.updmod,
                 batchable=True)
    return g
# main loop (load_data and MAX_EPOCH are placeholders)
nx_graph, features, labels = load_data('cora')
gcn = GCN()
g = DGLGraph(nx_graph)
g.set_n_repr(features)
for epoch in range(MAX_EPOCH):
  gg = gcn(g)
  logits = gg.get_n_repr()
  loss = F.nll_loss(logits, labels)
  loss.backward()
  # ...

Can you spot the bug in this code?

The bug is that we need to reset the input features to the graph at the beginning of each iteration:

# ...
g = DGLGraph(nx_graph)
for epoch in range(MAX_EPOCH):
  g.set_n_repr(features)  # Need to reset the features every iter!!
  gg = gcn(g)
# ...

This is a very subtle mistake, but we feel it is worth our attention. The reason for such a mistake is that we are used to what autograd DNN frameworks give us: immutability. In PyTorch, the tensor operations we normally use are out-of-place (i.e., they return new tensors instead of modifying their inputs). As a result, the following PyTorch code is totally fine:

X = ...  # placeholder: the initial input features
for epoch in range(MAX_EPOCH):
  y = my_model(X)  # returns a new tensor; X is left untouched
  loss = my_loss_func(y)
  loss.backward()

By contrast, because DGL is derived from networkx, our APIs are mutable. For example, update_all does not return a new graph; it updates the node representations in place. Because of this, even if the user writes gg = gcn(g), gg and g point to the same DGLGraph, and its node reprs have already been overwritten by update_all.
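To make the pitfall concrete without depending on DGL, here is a minimal sketch built on a hypothetical `ToyGraph` class (not DGL's API). Its `update_all` only sums each node's in-neighbor values and overwrites the stored representation, mimicking the mutable semantics described above. The sketch shows that `gg is g`, and that calling the model again without resetting the features propagates the already-updated values.

```python
class ToyGraph:
    """Hypothetical stand-in for a mutable graph; not the DGL API."""

    def __init__(self, edges, num_nodes):
        self.edges = list(edges)          # (src, dst) pairs
        self.n_repr = [0.0] * num_nodes   # node representations live on the graph

    def set_n_repr(self, values):
        self.n_repr = list(values)

    def get_n_repr(self):
        return list(self.n_repr)

    def update_all(self):
        # Each node's new value is the sum of its in-neighbors' current values.
        new = [0.0] * len(self.n_repr)
        for src, dst in self.edges:
            new[dst] += self.n_repr[src]
        self.n_repr = new  # the previous values are no longer reachable from g
        return self        # the same object, not a new graph


def gcn(g):
    return g.update_all()


features = [1.0, 2.0, 3.0]
g = ToyGraph([(0, 1), (1, 2), (2, 0)], num_nodes=3)
g.set_n_repr(features)

gg = gcn(g)
print(gg is g)            # True: gg and g are the same object
print(gg.get_n_repr())    # [3.0, 1.0, 2.0]
gg = gcn(g)               # forgot g.set_n_repr(features) before this call
print(gg.get_n_repr())    # [2.0, 3.0, 1.0]: the input is the previous output
```

Calling `g.set_n_repr(features)` before every `gcn(g)` brings the result back to `[3.0, 1.0, 2.0]`, which is exactly the fix shown earlier.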

This is an inherent conflict between networkx and PyTorch/TF/MXNet. I'd like to hear opinions from the model developers (@GaiYu0 @BarclayII @ivanbrugere @zzhang-cn). What do you think? Do you find this bug obvious, or is it a genuinely cunning pitfall? Do you prefer the current mutable graph or immutable objects?

I also want to share some thoughts on what supporting an immutable graph object would take: what the challenges are and how we might solve them. To make DGLGraph immutable, we need to handle two things:

For transformations on node/edge features (e.g., sendto, recv, update_all, ...), we can return a new graph that contains different data. In the example below, update_all returns a new graph g2. Initially the graph (i.e., g1) has three node features: a1, a2 and a3. The node update function reads all node features but updates only the a1 attribute. As a result, g2's node storage (called the node frame) reuses the other two feature columns and only uses the newly generated a1 column. Such a change to the system should have little overhead.

[Figure: update_all on g1 produces g2, whose node frame shares columns a2 and a3 with g1 and holds a new a1 column]
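A minimal sketch of the column-sharing idea, assuming a node frame is simply a mapping from feature names to column objects. `NodeFrame`, `with_columns` and `update_a1` are hypothetical names, not DGL's API, and the sketch assumes columns are never modified in place (plain lists stand in for tensors). Building g2's frame copies only the small dict of references, so the unchanged columns a2 and a3 are shared rather than duplicated.

```python
class NodeFrame:
    """Hypothetical immutable node frame: feature name -> column."""

    def __init__(self, columns):
        # Shallow copy of the mapping; the column objects themselves are shared.
        self._columns = dict(columns)

    def __getitem__(self, name):
        return self._columns[name]

    def with_columns(self, **updates):
        # Return a new frame: updated columns are replaced, all others are reused.
        merged = dict(self._columns)
        merged.update(updates)
        return NodeFrame(merged)


def update_a1(frame):
    # Reads every feature but only produces a new a1 column.
    new_a1 = [x + y + z for x, y, z in zip(frame["a1"], frame["a2"], frame["a3"])]
    return frame.with_columns(a1=new_a1)


frame1 = NodeFrame({"a1": [1.0, 2.0], "a2": [3.0, 4.0], "a3": [5.0, 6.0]})
frame2 = update_a1(frame1)
print(frame2["a1"])                  # [9.0, 12.0]
print(frame2["a2"] is frame1["a2"])  # True: the column is shared, not copied
print(frame1["a1"])                  # [1.0, 2.0]: g1's data is untouched
```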

For transformations on graph structure, things are a little trickier. Users who are used to networkx actually expect mutable behavior, as follows:

import networkx as nx
g = nx.path_graph(3)
print(g.edges())  # (0, 1), (1, 2)
h = g
h.add_edge(0, 2)
print(g.edges())  # (0, 1), (0, 2), (1, 2)

We could make structural changes immutable by using copy-on-write. The question is: should we?
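A minimal copy-on-write sketch, using a hypothetical `CowGraph` class (not part of networkx or DGL). One design consequence is worth noting: in Python, `h = g` binds the same object, so copy-on-write cannot apply to plain assignment. The sketch therefore uses an explicit, O(1) `copy()` that shares the edge storage until one of the graphs is mutated.

```python
class CowGraph:
    """Hypothetical graph whose copies share edge storage until one is mutated."""

    def __init__(self, edges=()):
        self._edges = list(edges)
        self._shared = False  # True if another CowGraph may reference self._edges

    def copy(self):
        # O(1): the new graph points at the same edge list; nothing is copied yet.
        other = CowGraph()
        other._edges = self._edges
        other._shared = True
        self._shared = True
        return other

    def add_edge(self, u, v):
        if self._shared:
            # Copy-on-write: take a private copy right before the first mutation.
            self._edges = list(self._edges)
            self._shared = False
        self._edges.append((u, v))

    def edges(self):
        return list(self._edges)


g = CowGraph([(0, 1), (1, 2)])
h = g.copy()
print(h._edges is g._edges)  # True: storage is shared after copy()
h.add_edge(0, 2)
print(g.edges())             # [(0, 1), (1, 2)]
print(h.edges())             # [(0, 1), (1, 2), (0, 2)]
```

Compared with the networkx snippet above, g keeps its original edges after h is modified. That is the immutable behavior, and it is exactly what networkx users would not expect.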

BarclayII commented 6 years ago

Hmm, interesting question. Coming from networkx, I would expect DGLGraph to be mutable.

In fact, I would probably write the example above like this:

import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph
# NodeModule, load_data and MAX_EPOCH as in the first example
# GCN module
class GCN(nn.Module):
  def __init__(self):
    super().__init__()
    self.updmod = NodeModule()
  def forward(self, nx_graph, features):
    g = DGLGraph(nx_graph)
    g.set_n_repr(features)
    g.update_all(message_func='from_src',
                 reduce_func='sum',
                 update_func=self.updmod,
                 batchable=True)
    return g.get_n_repr()
# main loop
nx_graph, features, labels = load_data('cora')
gcn = GCN()
for epoch in range(MAX_EPOCH):
  logits = gcn(nx_graph, features)
  loss = F.nll_loss(logits, labels)
  loss.backward()
  # ...

EDIT: if nx_graph is fixed, we can of course store it (and the DGLGraph counterpart) as a module attribute during instantiation.

jermainewang commented 6 years ago

There are two concerns here. (1) Currently, we maintain a cached graph object that converts the networkx graph into a more efficient C++ graph structure. Since your code creates a new DGLGraph in every forward call, the cached graph has to be rebuilt in every iteration. (2) It is still error-prone, because users are still free to write mutating code (and possibly use it incorrectly).

BarclayII commented 6 years ago

(1) Based on my understanding of the implementation in dgl/cached_graph.py, the cached graph seems to be essentially a mapping from (src_id, dst_id) to edge_id. In that case, I guess we can keep the cached graph itself intact and only reinitialize the underlying tensor storage (which we can't avoid anyway). Correct me if I'm wrong, though. (2) networkx.DiGraph is itself mutable, and since DGLGraph is a subclass of networkx.DiGraph, I think it is more natural for it to be mutable as well. Users should be aware of this (hence my code above).

jermainewang commented 6 years ago

The cached graph is an actual graph. It may be stored as CSR, CSC or COO, depending on the configuration. Creating it is somewhat costly because we need to convert the networkx graph, which is stored in Python dictionaries. I don't know yet whether the conversion can be lowered to the C++ side, but since it is O(V+E) and currently runs in Python, we should assume it is costly.
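To illustrate what the O(V+E) conversion involves, here is a pure-Python sketch that turns a networkx-style dict-of-dicts adjacency into CSR arrays (`indptr`, `indices`). It assumes nodes are labeled 0..V-1. `to_csr` and `edge_id_map` are hypothetical helpers, not DGL's cached-graph code. In this layout an edge's id is simply its position in `indices`, which also yields a (src_id, dst_id) -> edge_id mapping.

```python
def to_csr(adj):
    """Convert {src: {dst: attrs}} with nodes 0..V-1 into CSR arrays.

    The out-neighbors of node u are indices[indptr[u]:indptr[u + 1]].
    """
    num_nodes = len(adj)
    indptr = [0] * (num_nodes + 1)
    indices = []
    for u in range(num_nodes):   # one pass over the V nodes ...
        indices.extend(adj[u])   # ... plus the out-edges of u, in dict order
        indptr[u + 1] = len(indices)
    return indptr, indices


def edge_id_map(indptr, indices):
    # (src, dst) -> edge id, where the id is the edge's position in `indices`.
    return {
        (u, indices[e]): e
        for u in range(len(indptr) - 1)
        for e in range(indptr[u], indptr[u + 1])
    }


# Same dict-of-dicts layout networkx uses for DiGraph adjacency (no attributes here).
adj = {0: {1: {}, 2: {}}, 1: {2: {}}, 2: {}}
indptr, indices = to_csr(adj)
print(indptr)                        # [0, 2, 3, 3]
print(indices)                       # [1, 2, 2]
print(edge_id_map(indptr, indices))  # {(0, 1): 0, (0, 2): 1, (1, 2): 2}
```

Every node and every edge is touched once, which is the O(V+E) cost. Building these arrays once and reusing them across iterations is what keeping the DGLGraph alive buys.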

In your code, since the DGLGraph is created and destroyed within the forward function, we have no chance to keep the cached graph. One way to solve this (as you've suggested) is to create the DGLGraph in the __init__ function, as follows:

import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph
# NodeModule, load_data and MAX_EPOCH as in the first example
# GCN module
class GCN(nn.Module):
  def __init__(self, nx_graph):
    super().__init__()
    self.g = DGLGraph(nx_graph)  # built once, so the cached graph is reused
    self.updmod = NodeModule()
  def forward(self, features):
    self.g.set_n_repr(features)
    self.g.update_all(message_func='from_src',
                      reduce_func='sum',
                      update_func=self.updmod,
                      batchable=True)
    return self.g.get_n_repr()
# main loop
nx_graph, features, labels = load_data('cora')
gcn = GCN(nx_graph)
for epoch in range(MAX_EPOCH):
  # <------
  #   At the beginning of the second iteration, the `logits` tensor is still alive
  #   because the GCN object holds a reference to it.
  logits = gcn(features)
  loss = F.nll_loss(logits, labels)
  loss.backward()
  # ...

However, this implementation has another issue. The updated repr is stored in the DGLGraph, which in turn is held by the GCN object. Because of that reference, the repr tensor cannot be freed until it is replaced in the next iteration (as noted in the comment).

zzhang-cn commented 6 years ago

I don't entirely follow. For the GCN example, this bug is something the user should be aware of: each iteration creates a new layer of representation, whereas the code as written mutates the representation in place.

[Figure: stacked convolution layers, each producing a new representation]

jermainewang commented 6 years ago

The figure is actually an example of immutable representations. There, convolution is an operator that returns a new representation: y = conv(x). If you think of graph conv as a conv that depends on the graph structure, the natural style would be y = g.update_all(x). By contrast, our style is:

g.set_n_repr(x)
g.update_all()
y = g.get_n_repr()

The difference is that, unlike conv, g is not stateless: it holds references to the representation being updated.
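For comparison, here is a sketch of the stateless style y = g.update_all(x). The graph holds only structure, while the propagation function takes the features as an argument and returns a new list. `graph_conv` is a hypothetical name, and it implements only sum aggregation over in-neighbors, not a full GCN layer.

```python
def graph_conv(edges, x):
    """Sum each node's in-neighbor features; returns a new list and never modifies x."""
    y = [0.0] * len(x)
    for src, dst in edges:
        y[dst] += x[src]
    return y


edges = [(0, 1), (1, 2), (2, 0)]  # graph structure only, no features stored
x = [1.0, 2.0, 3.0]
y = graph_conv(edges, x)
print(y)  # [3.0, 1.0, 2.0]
print(x)  # [1.0, 2.0, 3.0]: the input survives, so every epoch can start from it
```

Because nothing but the caller holds `y`, it can be reclaimed as soon as the caller drops it. This is the memory behavior of `y = conv(x)` in the figure.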

BarclayII commented 6 years ago

I think I'm lost... Why do we want to destroy the repr tensor or the DGLGraph object before the next iteration begins? After all, we never know how the tensors will be used anyway, and PyTorch keeps a tensor alive as long as anything still references it.

In the case above, where the graph does not change, we can always reuse the same DGLGraph object and only replace the underlying tensor storage in set_n_repr(), right? And if the graph does change (e.g., the algorithm starts from the same graph but gradually adds different things in each run), I guess the corresponding DGLGraph object has to be rebuilt no matter what, correct?

zzhang-cn commented 6 years ago

@BarclayII This is at the beginning of a training epoch; set_n_repr is the same as loading the data.

BarclayII commented 6 years ago

@zzhang-cn I understand that. My question is about @jermainewang's fifth comment, on the cached graph.

jermainewang commented 6 years ago

@BarclayII A tensor's memory is reclaimed once its reference count drops to zero. However, since the graph object holds a reference to the tensor, it cannot be reclaimed even when the variable appears to be reused. Consider two examples here, one in plain PyTorch and one in DGL:

# Example 1: plain PyTorch
X = ...  # some initial features
for epoch in range(MAX_EPOCH):
  logits = my_model(X)
  logits = 2 * logits  # After this statement, the old 'logits' tensor will be reclaimed.
  loss = my_loss(logits)
  loss.backward()

# Example 2: DGL
X = ...  # some initial features
my_model = MyModel(nx_graph)  # The DGLGraph is created inside the model object.
for epoch in range(MAX_EPOCH):
  logits = my_model(X)
  logits = 2 * logits  # After this statement, the old 'logits' tensor is still
                       # alive because the model holds a reference to it.
  loss = my_loss(logits)
  loss.backward()

BarclayII commented 6 years ago

First, I think that in the first case the old logits tensor can't be reclaimed, since PyTorch needs it for backpropagation.

Second, I personally don't see it as a problem that the tensors are not reclaimed immediately after the forward phase. After all, we do need that much memory to update the graph, and we may need other fields in the graph later (e.g., to visualize graph attention).

jermainewang commented 6 years ago

"First thing, I think in the first case the old logits tensor can't be reclaimed, since PyTorch needs the old tensor for backpropagation."

No, it is not required. The gradient with respect to the old logits tensor is just 2 times the gradient with respect to the new one, so computing it does not need the old tensor's values.
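The chain-rule argument behind this, written out with x for the old logits, y = 2x for the new ones and L for the loss (elementwise):

$$
y = 2x
\quad\Longrightarrow\quad
\frac{\partial L}{\partial x}
  = \frac{\partial y}{\partial x}\,\frac{\partial L}{\partial y}
  = 2\,\frac{\partial L}{\partial y}
$$

The factor 2 does not depend on x, so the backward pass never reads x. By contrast, for an operation such as y = w ⊙ x with a learnable weight w, the gradient

$$
\frac{\partial L}{\partial w} = x \odot \frac{\partial L}{\partial y}
$$

does need x. In that case autograd keeps the old tensor alive regardless of what the graph object does.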

BarclayII commented 6 years ago

What if another PyTorch module is applied instead of 2 * logits? I'm not sure it's worthwhile to optimize for this kind of case.

jermainewang commented 6 years ago

It won't help; the only way is to explicitly delete or replace the column in the graph. Whether that is worthwhile depends on the application. I just want to point out that this might be a place to squeeze out some memory.

jermainewang commented 6 years ago

After a face-to-face meeting last weekend, we decided to keep the mutable semantics and to provide pop_n_repr and pop_e_repr APIs for easily removing columns from the graph. This change has been merged in #32.
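A minimal sketch of why popping a column frees memory. It uses a hypothetical `ToyNodeFrame` whose `pop` mimics the intent of pop_n_repr; it is not DGL's implementation. `Column` stands in for a tensor (plain lists cannot be weak-referenced), and `weakref` lets us observe when the old column is reclaimed. The example relies on CPython's reference counting, which frees objects immediately.

```python
import weakref


class Column:
    """Stand-in for a feature tensor."""

    def __init__(self, data):
        self.data = data


class ToyNodeFrame:
    """Hypothetical per-graph feature storage; not DGL's implementation."""

    def __init__(self):
        self._columns = {}

    def set(self, name, column):
        self._columns[name] = column

    def get(self, name):
        return self._columns[name]

    def pop(self, name):
        # Remove the column from the graph and hand it back to the caller.
        return self._columns.pop(name)


frame = ToyNodeFrame()
frame.set("h", Column([1.0, 2.0]))

logits = frame.get("h")
alive = weakref.ref(logits)
logits = Column([2 * v for v in logits.data])  # like `logits = 2 * logits`
print(alive() is not None)  # True: the frame still references the old column

# Pop the column instead, so the graph no longer holds a reference to it.
logits = frame.pop("h")
logits = Column([2 * v for v in logits.data])
print(alive() is None)      # True on CPython: the last reference is gone
```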