Open demmerichs opened 5 years ago
It seems like a combination of switch (generated by tf.cond) and resource variables. One thing to try could be to add more things to exclude to the list here:
https://github.com/cybertronai/gradient-checkpointing/blob/master/memory_saving_gradients.py#L85
Currently that list filters by op name; a better approach may be to filter by op type, i.e. exclude anything like ReadVariableOp, Switch and Merge. If you figure this out, a pull request would be appreciated!
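A minimal sketch of what filtering by op type could look like. It relies only on the `.type` and `.name` attributes that `tf.Operation` exposes; the `FakeOp` stand-in, the helper name `filter_by_op_type`, and the example op names are hypothetical and exist only so the snippet runs without TensorFlow. It is not the code in memory_saving_gradients.py.

```python
from collections import namedtuple

# Op types to skip, taken from the suggestion above; extend as needed.
EXCLUDED_OP_TYPES = frozenset({"ReadVariableOp", "Switch", "Merge"})


def filter_by_op_type(ops, excluded_types=EXCLUDED_OP_TYPES):
    """Return the ops whose `.type` is not in `excluded_types`."""
    return [op for op in ops if op.type not in excluded_types]


# Hypothetical stand-in for tf.Operation (only .name and .type are used).
FakeOp = namedtuple("FakeOp", ["name", "type"])

ops = [
    FakeOp("dense/kernel/Read/ReadVariableOp", "ReadVariableOp"),
    FakeOp("cond/Switch", "Switch"),
    FakeOp("cond/Merge", "Merge"),
    FakeOp("dense/MatMul", "MatMul"),
]
kept = filter_by_op_type(ops)
print([op.name for op in kept])
```

With real graph ops, the same list comprehension could be applied to the forward ops in place of the name-based filter, assuming the excluded types cover everything that causes the error.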
Thanks for your insight, I will definitely give it a shot on the weekend. Yesterday I also tried to come up with a minimal working example for my use case, but somehow I was not able to. Maybe I still do not fully understand what is going on in TF, or I am doing something wrong, so maybe you could just answer my question here:
I have a layer with comparatively small input and output tensors but large intermediate tensors, which are kept in memory for backpropagation. This is what I tried to recreate in my minimal example (see code below). I also attached an image of the graph. What I do not understand is that even with a high number of layers (n=128) I do not run out of memory during the gradient computation. The forward pass should be fine for any n given enough time, as it can safely delete the intermediate representations right away; however, I thought TF would keep the intermediate tensors for backpropagation of the gradients? What am I missing?
import shutil
import os.path as osp
import numpy as np
import tensorflow as tf
def massive_layer(t):
    with tf.name_scope('massive_layer'):
        upsample = tf.tile(t[None, None, None], [32, 1024, 1024])
        upsample = upsample + tf.reduce_mean(upsample)
        reduce = tf.reduce_mean(upsample * 0.5)
    return reduce
var = tf.Variable(np.random.normal(size=()), dtype=tf.float32)
d = var
n = 3
for i in range(n):
    d = massive_layer(d)
grads = tf.gradients(d, var)[0]
shutil.rmtree(osp.join('/tmp', 'custom_gradients_testtb'), ignore_errors=True)
tb_saver = tf.summary.FileWriter(osp.join(
    '/tmp', 'custom_gradients_testtb',
))
with tf.Session() as s:
    s.run(tf.global_variables_initializer())
    tb_saver.add_graph(s.graph)
    print(s.run(d))
    print(s.run([d, grads]))
TF only stores activations if you request gradients. If you don't request them in your session.run, it'll discard the activations.
I know. Still, even when requesting the gradients (see the last line of the code example), it does not run out of memory for arbitrary sizes.
TensorFlow can store the input of the grad function instead of the output, e.g. it does this for ReLU. Not sure if that's the case for tile, but your input is quite small. Try stacking a couple of layers on top of each other.
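As a rough back-of-envelope check on the sizes under discussion, the sketch below computes how large one intermediate tensor of shape [32, 1024, 1024] is in float32, and what the total would be if one such tensor per layer had to be kept for n=128. The helper name `tensor_mib` is hypothetical. Whether TensorFlow actually retains these tensors depends on which tensors each op's gradient function references, which this arithmetic does not model.

```python
import math


def tensor_mib(shape, bytes_per_element=4):
    """Size of a dense tensor in MiB (float32 uses 4 bytes per element)."""
    return math.prod(shape) * bytes_per_element / 2**20


# One intermediate tensor from the minimal example above.
per_tensor_mib = tensor_mib((32, 1024, 1024))

# Hypothetical worst case: one such tensor retained per layer.
n_layers = 128
worst_case_gib = per_tensor_mib * n_layers / 1024

print(per_tensor_mib, "MiB per tensor")
print(worst_case_gib, "GiB if one tensor per layer were retained")
```

If the run completes on a GPU with much less memory than that worst-case figure, it suggests the large intermediates are not all being kept alive for the backward pass.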
I have the same problem using TensorFlow 1.15 and 1.14.
It all behaves pretty weirdly between versions, e.g. installing it via pip does not work (I think this is due to how conda handles CUDA).
To reproduce: use DenseNet from keras applications and train a classifier.
Setup a)
conda install tensorflow-gpu=1
conda install keras-gpu
This works, although gradient_memory only gives something between a 3-4 memory increase (OOM with batch size 4 for my image size, while batch size 1 works without memgrad).
Setup b) If you do not install keras and switch all imports to tensorflow.keras, you get the error reported by the OP. If you change
import memory_saving_gradients as gc
from tensorflow.python.ops import gradients as tf_gradients
tf_gradients.gradients = gc.gradients_speed
to
import tensorflow as tf
import memory_saving_gradients
# monkey patch tf.gradients to point to our custom version, with automatic checkpoint selection
tf.__dict__["gradients"] = memory_saving_gradients.gradients_speed
there are no errors, but there is no effect on memory usage. This probably means the patched function is not being used.
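One possible explanation, shown here as a TensorFlow-free sketch: rebinding a name on one module does not affect code that looks the function up on a different module. The modules `impl` and `facade` and the function `consumer` are hypothetical stand-ins (for an internal module, the top-level package, and a library routine that calls the gradients function). Whether tensorflow.keras resolves gradients this way in a given TF version is an assumption, not something confirmed in this thread.

```python
import types

# Hypothetical internal module that owns the real function.
impl = types.ModuleType("impl")
impl.gradients = lambda: "original"

# Hypothetical top-level package that merely re-exports it.
facade = types.ModuleType("facade")
facade.gradients = impl.gradients


def consumer():
    # Library code that looks the function up on the internal module
    # at call time, not on the top-level package.
    return impl.gradients()


# Patching only the re-exported name leaves the consumer untouched.
facade.gradients = lambda: "patched"
after_facade_patch = consumer()

# Patching the module the consumer actually reads from takes effect.
impl.gradients = lambda: "patched"
after_impl_patch = consumer()

print(after_facade_patch, after_impl_patch)
```

A quick way to check in a real setup would be to put a print statement or counter inside the patched function and see whether it is ever called during training.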
@yaroslavvb could you have a look and see what could be done? (I don't want to switch back to the original keras, since it will stop being updated in April 2020.)
@shsshs sorry, I haven't kept up with TF backend changes (I mostly stay in PyTorch-land nowadays). I'll be happy to merge any PRs that fix it though.
It seems that changing “/read” to “/Read” in line 90 and line 92 works. Tested on tensorflow.python.keras.applications.InceptionV3 with TF 1.14.
fwd_ops = [op for op in fwd_ops if not '/Read' in op.name]
ts_all = ge.filter_ts(fwd_ops, True) # get the tensors
ts_all = [t for t in ts_all if '/Read' not in t.name]
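Since the fix above hinges on the capitalisation of the name component, a case-insensitive check would cover both spellings at once. This is only a sketch: the helper `is_variable_read` and the example tensor names are hypothetical, and a substring match like this could also exclude unrelated names that happen to contain "/read".

```python
def is_variable_read(name):
    """True if the name contains a '/read' component in any capitalisation."""
    return "/read" in name.lower()


# Hypothetical tensor names covering both spellings.
names = [
    "conv1/kernel/read:0",
    "conv1/kernel/Read/ReadVariableOp:0",
    "conv1/Relu:0",
]
kept = [n for n in names if not is_variable_read(n)]
print(kept)
```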
I tried this out but get an error when computing the gradients with the provided function using manually selected checkpoints. I get three different errors at the same time and am not sure which part of my graph is actually causing them, so I would appreciate some hints so that I could come up with a minimal non-working example. I currently use TF 1.13.1 and in particular tf.keras.layers.BatchNormalization (I mention this because it shows up in the error message). Is there any hope that this would be an easy fix?