HIPS / autograd

Efficiently computes derivatives of NumPy code.
MIT License

Memory issue? #103

Closed hughsalimbeni closed 8 years ago

hughsalimbeni commented 8 years ago

I've run into an issue with large matrices and memory. There seem to be two problems:

1) Memory isn't being released on successive calls of grad. e.g.

import autograd.numpy as np
from autograd import grad
na = np.newaxis

a = 10000
b = 10000
A = np.random.randn(a)
B = np.random.randn(b)

def fn(x):
    M = A[:, na] + x[na, :]
    return M[0, 0]

g = grad(fn)

for i in range(100):
    g(B)

is ramping up memory on each iteration.

2) Memory isn't being released during the backwards pass e.g.

k = 10
def fn(x):
    res = 0
    for i in range(k):
        res = res + np.sum(x)
    return res
g = grad(fn)
b = 200000
g(np.random.randn(b))

This seems to scale in memory (for each call) as O(k), which I don't think is the desired behaviour. For b=150000 this effect does not happen, however.

hughsalimbeni commented 8 years ago

I seem to have made some progress with (1) above by modifying the Node class:

from weakref import WeakKeyDictionary

class Node(object):
    __slots__ = ['value', 'tapes']
    Rnode = ReverseNode
    type_mappings = WeakKeyDictionary() # instead of {}
    def __init__(self, value, tapes):
        self.value = value
        self.tapes = WeakKeyDictionary() # instead of {}
        for tape in tapes:
            new_rnode = self.Rnode(type(self), value)
            tape.append(new_rnode)
            self.tapes[tape] = new_rnode

but this doesn't fix problem (2)

mattjj commented 8 years ago

Thanks for the excellent bug report. Also a helpful tip I learned from your earlier email: lines like print(psutil.virtual_memory().used / 1024**2) might help us track down the problems.
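The psutil line above measures whole-process memory. As a stdlib-only sketch of the same idea (using tracemalloc is my own substitution here, not what was used in the thread; it tracks only Python-level allocations, not total RSS):

```python
import tracemalloc

# Track allocations made by the Python interpreter itself.
tracemalloc.start()
data = [list(range(10000)) for _ in range(10)]  # stand-in workload
current, peak = tracemalloc.get_traced_memory()
print(current / 1024**2)  # MiB currently allocated by Python objects
tracemalloc.stop()
```

Printing such a number before and after each grad call makes a per-iteration leak easy to spot.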

We had self.tapes as a WeakKeyDictionary at one point but it broke things when we factored the forward and backward pass to separate functions, c.f. #17. It may be the solution now, especially if the nose tests pass, but we should tread carefully.

dg-pb commented 8 years ago

I simply call garbage collector after each call.
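Concretely, that workaround looks something like the following (a minimal sketch; the gradient function here is a hand-written stand-in rather than autograd's grad, so the example stays self-contained):

```python
import gc
import numpy as np

def g(x):
    # Hand-written gradient of sum(x**2), standing in for grad(fn).
    return 2 * x

B = np.random.randn(1000)
for i in range(100):
    g(B)
    gc.collect()  # force a cycle-detecting collection after each call
```

gc.collect() runs the cycle detector, so objects trapped in Node/tape reference cycles get freed even though plain reference counting cannot reclaim them.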

jackkamm commented 8 years ago

Wow, this is very timely for me; issue (1) was causing me problems earlier today. The solution I was experimenting with was to adjust the backward pass to pop items off the tape and the parent_grad_ops (thus breaking the cycle of references between Node and ReverseNode):

    while tape:  # these two lines replace "for node in tape[::-1]:"
        node = tape.pop()
        if node.outgrads:
            cur_outgrad = node.sum_outgrads()
            assert type(new_node(getval(cur_outgrad))) == node.node_type, \
                "Types are {0} and {1}".format(type(new_node(getval(cur_outgrad))), node.node_type)
            while node.parent_grad_ops:  # these two lines replace "for gradfun, parent in node.parent_grad_ops:"
                gradfun, parent = node.parent_grad_ops.pop()
                og = cast_to_node_type(gradfun(cur_outgrad), parent.node_type, parent.node_value)
                parent.outgrads.append(og)
    return cur_outgrad

but this breaks the jacobian because it calls backward_pass multiple times (though this could be fixed by giving backward_pass an extra cleanup flag).

I think @hughsalimbeni 's solution is better (it doesn't break jacobian), but thought I'd mention this alternative.

hughsalimbeni commented 8 years ago

I've had a look at @jackkamm's solution for issue (2) and also thrown in an extra few weakrefs here and there, but it doesn't appear to make much of a difference. I should add that neither guppy nor gc seems to notice the problem.

mattjj commented 8 years ago

The loopy references (which make the reference-counting garbage collection not work and which require some kind of mark-and-sweep strategy) are probably from the fact that nodes refer to tapes and tapes refer to nodes. The tapes aren't really a necessary data structure in the sense that they amount to a topological sorting of the computation graph (which is already encoded by the nodes pointing to each other), plus maybe a bit of extra bookkeeping (tape.complete). We might want to get rid of the tapes entirely.

A possible quick fix to try (that I don't see mentioned above) is to make the tapes keep weak references to nodes (as in a weak reference list).
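The weak-reference-list idea can be sketched in isolation like this (the Node class and tape list here are hypothetical stand-ins, not autograd's actual data structures):

```python
import weakref

class Node:
    """Hypothetical stand-in for an autograd node."""
    pass

tape = []
node = Node()
tape.append(weakref.ref(node))  # tape holds only a weak reference

assert tape[0]() is node  # dereference while the node is still alive
del node
# With the last strong reference gone, CPython's reference counting
# frees the node immediately and the weakref goes dead.
print(tape[0]() is None)
```

Because the tape no longer keeps nodes alive, the node-to-tape reference is no longer part of a cycle.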

mattjj commented 8 years ago

We just did some quick experiments, and calling import gc; gc.collect() explicitly does seem to prevent memory from building up.

Also, we think we identified the problem: originally, these lines in backward_pass

while tape:
    node = tape.pop()
    ...

ensured that there were no more circular dependencies when the backward pass finished by popping things off the tape, so that reference counting would be sufficient to clean up garbage. However, to support efficient jacobians I added this line:

tape = copy.copy(tape)  # <-- problem!
while tape:
    node = tape.pop()
    ...

which had the unintended result of preserving circular dependencies and hence preventing reference counting cleanup.
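The effect is easy to reproduce in isolation: popping from a shallow copy empties only the copy, so the original list keeps strong references to every element (a standalone sketch, not autograd code):

```python
import copy

tape = [object() for _ in range(3)]
working = copy.copy(tape)  # shallow copy: same element references
while working:
    working.pop()           # only the copy is drained

assert len(working) == 0
assert len(tape) == 3       # originals are still strongly referenced
```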

Still, after tape goes out of scope (when grad returns), gc.collect() should be able to garbage collect these things, and so adding explicit calls to gc.collect() to user-level code after calling grad should work. (We could also add it to the definition of grad.)

We're going to implement a fix (and probably eliminate circular references with tapes in the long term), but for now if you're having memory problems try calling gc.collect() in your code after computing a gradient.

mattjj commented 8 years ago

As for (2) (which I actually only just read the code for), reverse-mode autodiff generally requires keeping the forward values around in memory, so the memory usage should scale as \Theta(k) there. If you can rewrite code to generate fewer temporaries on the forward pass then of course memory usage would be reduced, but those temporaries can't be garbage collected until the backward pass is done. An optimizing compiler could in principle rewrite that particular code for you (because some temporaries are superfluous and so they don't need to be kept around), but autograd doesn't do any rewriting; it just follows the forward pass (keeping references to any temporaries generated so they can't be garbage collected).

hughsalimbeni commented 8 years ago

Thanks @mattjj for your previous comment. I've poked around for some more memory savings for my application and I've done a couple of things:

Firstly, after the line cur_outgrad = node.sum_outgrads() I've added node.outgrads = None, which stops the derivatives accumulating in memory after being used (though I don't see why they weren't being deleted along with the node), and also node.node_value = 0., which frees up memory as it runs the tape. While the latter isn't strictly necessary, I found it helpful in memory-critical applications.

Secondly, and more importantly, following @mattjj's advice of creating fewer temporaries, I changed things like np.sum(A[:, :, na] * B[:, na, :], 0) to np.einsum('ab,ac->bc', A, B). This made a huge (i.e. factor-of-a) difference.
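For concreteness, the two expressions are numerically equivalent, but the einsum form contracts over the shared axis directly and never materializes the (a, b, c)-sized temporary (small shapes here are illustrative):

```python
import numpy as np

na = np.newaxis
a, b, c = 50, 6, 7
A = np.random.randn(a, b)
B = np.random.randn(a, c)

# Broadcast version: allocates an (a, b, c) intermediate before summing.
slow = np.sum(A[:, :, na] * B[:, na, :], 0)

# einsum version: sums over the shared 'a' axis as it goes,
# so no (a, b, c) temporary is ever created.
fast = np.einsum('ab,ac->bc', A, B)

assert np.allclose(slow, fast)
```

Since reverse-mode autodiff keeps references to forward-pass temporaries, avoiding the big intermediate saves memory on both the forward and backward passes.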

ericmjl commented 8 years ago

I'm likewise facing a memory issue. Putting gc.collect()s in my code doesn't seem to take memory consumption down.

Line #    Mem usage    Increment   Line Contents
================================================
149     98.9 MiB      0.0 MiB   @profile
150                             def main():
151     99.0 MiB      0.0 MiB       print('Opening CSV file...')
152    106.7 MiB      7.7 MiB       df = open_csv_file()
153    106.7 MiB      0.0 MiB       print('Loading feature array...')
154    288.0 MiB    181.3 MiB       graph_array = load_feat_array()
155    288.0 MiB      0.0 MiB       print('Opening graph_idxs...')
156    313.5 MiB     25.5 MiB       graph_idxs = unpickle_data('../data/graph_idxs.pkl')
157    313.5 MiB      0.0 MiB       print('Opening graph_nodes...')
158    409.8 MiB     96.3 MiB       graph_nodes = unpickle_data('../data/graph_nodes.pkl')
159    409.8 MiB      0.0 MiB       print('Opening nodes_nbrs...')
160    638.5 MiB    228.7 MiB       nodes_nbrs = unpickle_data('../data/nodes_nbrs.pkl')
161
162                                 # Check data
163    638.5 MiB      0.0 MiB       print('Doing data checks...')
164    638.5 MiB      0.0 MiB       assert df.shape == (6660, 13)
165    638.5 MiB      0.0 MiB       assert len(graph_array) == 659895
166    638.5 MiB      0.0 MiB       assert len(graph_idxs) == len(graph_nodes)
167    638.5 MiB      0.0 MiB       assert len(nodes_nbrs) == len(graph_array)
168
169    638.5 MiB      0.0 MiB       print('Preprocessing data...')
170    638.5 MiB      0.0 MiB       pp_data = preprocess_data(df, nodes_nbrs, graph_idxs, graph_nodes,
171   1023.0 MiB    384.5 MiB                                 graph_array)
172    807.8 MiB   -215.2 MiB       graph_arr, nodes_nbrs, graph_idxs, graph_nodes = pp_data
173
174    807.8 MiB      0.0 MiB       assert graph_arr.shape[0] == len(nodes_nbrs)
175    807.8 MiB      0.0 MiB       assert len(graph_idxs) == len(graph_nodes)
176
177    807.8 MiB      0.0 MiB       print('Setting up neural net.')
178    807.8 MiB      0.0 MiB       layers = [GraphConvLayer(weights_shape=(36, 36),
179    807.8 MiB      0.0 MiB                                biases_shape=(1, 36)),
180    807.8 MiB      0.0 MiB                 FingerprintLayer(weights_shape=(36, 36),
181    807.8 MiB      0.0 MiB                                  biases_shape=(1, 36)),
182    807.8 MiB      0.0 MiB                 LinearRegressionLayer(weights_shape=(36, 1),
183    807.8 MiB      0.0 MiB                                       biases_shape=(1, 1)),
184                                           ]
185    807.8 MiB      0.0 MiB       print(layers)
186
187    807.8 MiB      0.0 MiB       print('Initializing network...')
188    807.9 MiB      0.0 MiB       wb = initialize_network(layers_spec=layers)
189    807.9 MiB      0.0 MiB       wb_vect, unflattener = flatten(wb)
190    807.9 MiB      0.0 MiB       print('Network initialized. Weights & biases:')
191    807.9 MiB      0.0 MiB       print(wb)
192
193    894.8 MiB     86.9 MiB       node_rows, node_cols, ones = to_sparse_format(nodes_nbrs)
194
195    894.8 MiB      0.0 MiB       nodes_nbrs_compressed = csr_matrix((ones, (node_rows, node_cols)),
196    894.8 MiB      0.0 MiB                                          shape=(len(nodes_nbrs),
197    906.2 MiB     11.5 MiB                                                 len(nodes_nbrs)))
198
199    906.2 MiB      0.0 MiB       train_losses = []
200    906.2 MiB      0.0 MiB       preds_iter = []
201    906.2 MiB      0.0 MiB       actual_iter = []
202
203    906.2 MiB      0.0 MiB       print('Defining train loss function.')
204
205   4438.9 MiB   3532.6 MiB       def train_loss(wb_vect, unflattener):
206   4438.9 MiB      0.0 MiB           wb_struct = unflattener(wb_vect)
207   4438.9 MiB      0.0 MiB           preds = predict(wb_struct, graph_arr, nodes_nbrs_compressed,
208   4232.0 MiB   -206.9 MiB                           graph_idxs, layers)
209   4232.9 MiB      0.9 MiB           graph_scores = get_actual(graph_idxs, df, preds)
210   4232.9 MiB      0.0 MiB           mse = np.mean(np.power(preds - graph_scores, 2))
211
212   4232.9 MiB      0.0 MiB           train_losses.append(mse)
213   4232.9 MiB      0.0 MiB           preds_iter.append(preds)
214   4232.9 MiB      0.0 MiB           actual_iter.append(graph_scores)
215   4232.9 MiB      0.0 MiB           gc.collect()
216   4232.9 MiB      0.0 MiB           return mse
217
218    906.2 MiB  -3326.7 MiB       traingrad = grad(train_loss)
219
220    906.2 MiB      0.0 MiB       training_losses = []
221
222    906.2 MiB      0.0 MiB       print('Defining callback function...')
223
224   4438.9 MiB   3532.6 MiB       def callback(wb, i):
225   4438.9 MiB      0.0 MiB           start = time()
226   3919.4 MiB   -519.5 MiB           tl = train_loss(*flatten(wb))
227   3919.4 MiB      0.0 MiB           if i % 1 == 0:
228   3919.4 MiB      0.0 MiB               print(tl, time() - start)
229   3919.4 MiB      0.0 MiB           training_losses.append(tl)
230   3919.4 MiB      0.0 MiB           gc.collect()
231
232    906.2 MiB  -3013.2 MiB       print('Training neural network.')
233   3919.4 MiB   3013.2 MiB       wb_vect, unflattener = adam(traingrad, wb, callback=callback, num_iters=10)
234
235   3919.4 MiB      0.0 MiB       print(wb_vect)

On my MacBook Air, in my Activity Monitor, I'm observing that the "Memory" minus the "Compressed Memory" stays roughly at about 1.5GB, but the absolute value of "Memory" and "Compressed Memory" continue to climb beyond the 8GB on my local machine. Is there a way out of this?

ericmjl commented 8 years ago

@hughsalimbeni: Implementing your changes to autograd helped fix my memory leakage issue as well.

jackkamm commented 8 years ago

@mattjj : would you be willing to incorporate @hughsalimbeni 's changes onto master? It's only 2 lines of code, makes a huge memory difference, and doesn't break anything (I have been using it for some months now). I can submit a PR if that would be helpful (or maybe @hughsalimbeni should submit the PR if he'd like the credit :).

My application involves computing a memory-intensive M-estimator (a sum of functions). For memory reasons I compute the gradient on minibatches instead of the final result. Having to call gc.collect() after every minibatch bogs things down quite a bit, and in my own recent tests it was ineffective or only partially effective (agreeing with @ericmjl 's finding).

I would much prefer to just list autograd as a dependency of my project, rather than have to make (and update and maintain) my own fork, and eventually require downstream users to use it, just over 2 lines of code.

mattjj commented 8 years ago

Thanks for bringing this up again. It's very useful feedback and we want to fix it.

As far as I can tell, you're talking about two changes:

  1. set node.outgrads = [] after node.sum_outgrads() is called in backward_pass
  2. set node.node_value = 0. after using it in backward_pass

Does that sound right to you?

I pushed a commit that might address these issues and also should mean you don't need to run gc.collect() explicitly, which can be very slow. The new commit implements change 1 above, but more importantly it addresses the circular dependency problem by adding a preserve_tape argument to backward_pass that is set to False for regular gradients and only True for jacobian (which now cleans up after itself when it's done with the tape). Reference counting should now be able to clean up everything.

I have a basic memory test, copied below, that runs successfully on the new commit but blew up virtual memory on the previous master. However, it might not cover your use case, so please try it out and let us know how things look.

import autograd.numpy as np
from autograd import grad
na = np.newaxis

a = 10000
b = 100
c = 50

A = np.random.randn(a)
B = np.random.randn(b)
C = np.random.randn(c)

def fn(x):
    return np.sum(A[:, na, na]*B[na, :, na]*x[na, na, :])

g = grad(fn)

for i in range(50):
    g(C)

jackkamm commented 8 years ago

Thanks for the commit, I tested it and it fixes the memory issues in my use case!

For the record, the change I was referencing and previously using was slightly different. In particular, I was using the WeakKeyDict as shown by @hughsalimbeni. This requires less code than your commit (literally 2 or 3 lines), but is also less explicit and I know you guys had some concerns about the brittleness of this solution.
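As a quick illustration of why WeakKeyDictionary helps here (a standalone sketch; the Tape class is a hypothetical stand-in, not autograd code): an entry disappears as soon as the last strong reference to its key dies, so keys never keep each other alive in a cycle:

```python
from weakref import WeakKeyDictionary

class Tape:
    """Hypothetical stand-in for an autograd tape object."""
    pass

tapes = WeakKeyDictionary()
t = Tape()
tapes[t] = "rnode"         # key is held only weakly
assert len(tapes) == 1

del t                      # drop the only strong reference
assert len(tapes) == 0     # entry was removed automatically (CPython)
```

This is why replacing the plain dict in Node.tapes with a WeakKeyDictionary breaks the node-to-tape half of the reference cycle with no explicit cleanup code.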

Anyways, thanks again for fixing this, you guys rock!

mattjj commented 8 years ago

Making the nodes only have a weak reference to the tapes would be a more slick solution to the circular dependency problem, but as you say I am a bit wary of it (as in #17). Have you been using that solution without issue for a while?

The commit I added last night is also pretty short, and it arguably follows the "explicit is better than implicit" Zen of Python. Since it seems to solve things, I'm going to close this issue, but let me know if the WeakKeyDictionary has been working well for you and maybe I'll look into that solution.

jackkamm commented 8 years ago

I've been using the WeakKeyDict without issue for a few months now, ever since it was first proposed earlier in this thread. Also, I tried py.test in autograd/test with and without WeakKeyDict a few days ago, and the output of the unit tests was the same.

But I am happy with your solution -- WeakKeyDict is a bit too mysterious/magical for my tastes, and explicitly doing the dereferencing seems less likely to cause errors in the future.