cybertronai / gradient-checkpointing

Make huge neural nets fit in memory
MIT License

Gradient checkpointing seems to conflict with Keras batch norm #47

Open: demmerichs opened this issue 5 years ago

demmerichs commented 5 years ago

I tried this out, but I get an error when computing the gradients with the provided function using manually selected checkpoints. I get three different errors at once and am not sure which part of my graph is actually causing them, so I would appreciate some hints so that I can come up with a minimal non-working example. I am currently using TF 1.13.1 and, in particular, tf.keras.layers.BatchNormalization (I mention this because it appears in the error message). Is there any hope that this is an easy fix?

Traceback (most recent call last):
  File "/lhome/davidj2/code/sync/space_and_deformable_time/.venv/lib/python3.6/site-packages/tensorflow/python/ops/gradients_impl.py", line 415, in _MaybeCompile
    xla_compile = op.get_attr("_XlaCompile")
  File "/lhome/davidj2/code/sync/space_and_deformable_time/.venv/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2413, in get_attr
    raise ValueError(str(e))
ValueError: Operation 'optimizer/head/convolve_batch_activate_20/batch_normalization_v1_21/cond/ReadVariableOp_1/Switch' has no attr named '_XlaCompile'.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/lhome/davidj2/code/sync/space_and_deformable_time/.venv/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 455, in _apply_op_helper
    as_ref=input_arg.is_ref)
  File "/lhome/davidj2/code/sync/space_and_deformable_time/.venv/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1240, in internal_convert_n_to_tensor
    ctx=ctx))
  File "/lhome/davidj2/code/sync/space_and_deformable_time/.venv/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1175, in internal_convert_to_tensor
    ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
  File "/lhome/davidj2/code/sync/space_and_deformable_time/.venv/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 977, in _TensorTensorConversionFunction
    (dtype.name, t.dtype.name, str(t)))
ValueError: Tensor conversion requested dtype float32 for Tensor with dtype resource: 'Tensor("optimizer/gradients/optimizer/head/convolve_batch_activate_20/batch_normalization_v1_21/cond/ReadVariableOp_1/Switch_grad/Switch_1:1", shape=(), dtype=resource)'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "./src/sadt.py", line 544, in <module>
    with SpaceAndDeformableTimeNetwork(cfg, datasets) as exp:
  File "/lhome/davidj2/code/sync/space_and_deformable_time/src/xxsflow/experiments/base_experiment.py", line 42, in __enter__
    self.build_graph()
  File "./src/sadt.py", line 298, in build_graph
    self.optimizer_op = self.optimizer
  File "/lhome/davidj2/code/sync/space_and_deformable_time/src/xxsflow/utils.py", line 388, in wrapped_function
    setattr(self, attribute, function(self))
  File "./src/sadt.py", line 267, in optimizer
    grads = grads = tf.gradients(self.loss, tf.trainable_variables())
  File "/lhome/davidj2/code/sync/space_and_deformable_time/packages/gradient_checkpointing/memory_saving_gradients.py", line 40, in gradients_collection
    return gradients(ys, xs, grad_ys, checkpoints='collection', **kwargs)
  File "/lhome/davidj2/code/sync/space_and_deformable_time/packages/gradient_checkpointing/memory_saving_gradients.py", line 227, in gradients
    dv = tf_gradients(ys=copied_ys, xs=boundary+xs, grad_ys=grad_ys, **kwargs)
  File "/lhome/davidj2/code/sync/space_and_deformable_time/packages/gradient_checkpointing/memory_saving_gradients.py", line 27, in tf_gradients
    return tf_gradient_function(ys, *args, **kwargs)
  File "/lhome/davidj2/code/sync/space_and_deformable_time/.venv/lib/python3.6/site-packages/tensorflow/python/ops/gradients_impl.py", line 664, in gradients
    unconnected_gradients)
  File "/lhome/davidj2/code/sync/space_and_deformable_time/.venv/lib/python3.6/site-packages/tensorflow/python/ops/gradients_impl.py", line 965, in _GradientsHelper
    lambda: grad_fn(op, *out_grads))
  File "/lhome/davidj2/code/sync/space_and_deformable_time/.venv/lib/python3.6/site-packages/tensorflow/python/ops/gradients_impl.py", line 420, in _MaybeCompile
    return grad_fn()  # Exit early
  File "/lhome/davidj2/code/sync/space_and_deformable_time/.venv/lib/python3.6/site-packages/tensorflow/python/ops/gradients_impl.py", line 965, in <lambda>
    lambda: grad_fn(op, *out_grads))
  File "/lhome/davidj2/code/sync/space_and_deformable_time/.venv/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_grad.py", line 88, in _SwitchGrad
    return merge([false_grad, true_grad])[0], None
  File "/lhome/davidj2/code/sync/space_and_deformable_time/.venv/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 466, in merge
    return gen_control_flow_ops.merge(inputs, name)
  File "/lhome/davidj2/code/sync/space_and_deformable_time/.venv/lib/python3.6/site-packages/tensorflow/python/ops/gen_control_flow_ops.py", line 418, in merge
    "Merge", inputs=inputs, name=name)
  File "/lhome/davidj2/code/sync/space_and_deformable_time/.venv/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 483, in _apply_op_helper
    raise TypeError("%s that don't all match." % prefix)
TypeError: Tensors in list passed to 'inputs' of 'Merge' Op have types [float32, resource] that don't all match.

yaroslavvb commented 5 years ago

It seems like a combination of Switch ops (generated by tf.cond) and resource variables. One thing to try would be to add more entries to the exclusion list here:

https://github.com/cybertronai/gradient-checkpointing/blob/master/memory_saving_gradients.py#L85

Currently that list filters by op name; it may be better to filter by op type, i.e. exclude ops such as ReadVariableOp, Switch and Merge. If you figure this out, a pull request would be appreciated!
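A minimal sketch of the suggested type-based filter, assuming the ops to filter are tf.Operation objects (which expose `.name` and `.type`). The helper `filter_ops_by_type` and the exact set of excluded types are hypothetical; where such a filter would plug into memory_saving_gradients.py is not shown. A namedtuple stands in for tf.Operation so the sketch runs without TensorFlow.

```python
from collections import namedtuple

# Op types suggested for exclusion in the comment above; the exact set
# is an assumption and may need extending for other graphs.
EXCLUDED_OP_TYPES = frozenset({"ReadVariableOp", "Switch", "Merge"})


def filter_ops_by_type(ops, excluded_types=EXCLUDED_OP_TYPES):
    """Drop ops whose `.type` is in `excluded_types` (hypothetical helper).

    Filtering on the op type rather than on a substring of the op name
    does not depend on how TensorFlow happens to name the ops.
    """
    return [op for op in ops if op.type not in excluded_types]


# Stand-in for tf.Operation so the sketch runs without TensorFlow.
FakeOp = namedtuple("FakeOp", ["name", "type"])

ops = [
    FakeOp("bn/cond/ReadVariableOp_1/Switch", "Switch"),
    FakeOp("bn/cond/ReadVariableOp_1", "ReadVariableOp"),
    FakeOp("conv/Conv2D", "Conv2D"),
]
print([op.name for op in filter_ops_by_type(ops)])  # ['conv/Conv2D']
```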

demmerichs commented 5 years ago

Thanks for your insight, I will definitely give it a shot on the weekend. Yesterday I also tried to come up with a minimal working example for my use case, but somehow I could not. Maybe I still do not fully understand what is going on in TF, or I am doing something wrong, so perhaps you could answer this question:

I have a layer with comparatively small input and output tensors but large intermediate tensors, which are kept in memory for backpropagation. This is what I tried to recreate in my minimal example (see code below); I also attached an image of the graph. What I do not understand is that even with a high number of layers (n=128) I do not run out of memory during the gradient computation. The forward pass should be fine for any n given enough time, since it can delete the intermediate representations immediately, but I thought TF would keep the intermediate tensors for backpropagating the gradients. What am I missing?

import shutil
import os.path as osp
import numpy as np
import tensorflow as tf

def massive_layer(t):
    # Scalar in, scalar out, but with large intermediate tensors:
    # tiling a scalar to [32, 1024, 1024] float32 gives about 128 MiB.
    with tf.name_scope('massive_layer'):
        upsample = tf.tile(t[None, None, None], [32, 1024, 1024])
        upsample = upsample + tf.reduce_mean(upsample)
        reduce = tf.reduce_mean(upsample * 0.5)
        return reduce

var = tf.Variable(np.random.normal(size=()), dtype=tf.float32)
d = var
n = 3  # number of stacked layers (n=128 did not run out of memory either)
for i in range(n):
    d = massive_layer(d)
grads = tf.gradients(d, var)[0]

# Write the graph to TensorBoard for inspection
logdir = osp.join('/tmp', 'custom_gradients_testtb')
shutil.rmtree(logdir, ignore_errors=True)
tb_saver = tf.summary.FileWriter(logdir)

with tf.Session() as s:
    s.run(tf.global_variables_initializer())
    tb_saver.add_graph(s.graph)
    print(s.run(d))           # forward pass only
    print(s.run([d, grads]))  # forward and backward pass
tb_saver.close()

[Attached image: graph of the example above]
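Some back-of-the-envelope arithmetic shows why the missing OOM is surprising. This assumes, hypothetically, that exactly one tile-sized float32 tensor per layer had to be kept for the backward pass (the layer actually creates several such tensors). The helper name `kept_activation_gib` is illustrative only.

```python
def kept_activation_gib(n_layers, shape=(32, 1024, 1024), bytes_per_elem=4):
    """GiB needed if one tensor of `shape` were kept per layer (assumption)."""
    elements = 1
    for dim in shape:
        elements *= dim
    return n_layers * elements * bytes_per_elem / 2**30


print(kept_activation_gib(3))    # 0.375
print(kept_activation_gib(128))  # 16.0
```

At n=128 that would already be 16 GiB for a single tensor per layer, more than many GPUs have, so TF is evidently not keeping these tensors.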

yaroslavvb commented 5 years ago

TF only stores activations if you request gradients; if you don't request them in your session.run, it discards the activations.

demmerichs commented 5 years ago

I know. Still, even when requesting the gradients (see the last line of the code example), it does not run out of memory for arbitrary sizes.

yaroslavvb commented 5 years ago

TensorFlow can store the input of the gradient function instead of the output; it does this for ReLU, for example. Not sure if that's the case for tile, but your input is quite small. Try stacking a couple of layers on top of each other.
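To make the point concrete, here is a hand-derived NumPy sketch of the backward pass of massive_layer from the example above, assuming a scalar input t. The forward pass is a = tile(t, shape), b = a + mean(a), out = mean(0.5 * b). None of the backward steps reads the values of a or b, only their shape, so in principle no large activation has to be kept. This is only the math: `massive_layer_backward` is a hypothetical helper, not TensorFlow's registered gradient code, and whether TF avoids keeping these tensors in practice is not confirmed here.

```python
import numpy as np


def massive_layer_backward(upstream, shape):
    """Gradient of massive_layer's scalar output w.r.t. its scalar input t.

    Only `shape` is used; the forward tensors a and b are never needed.
    """
    size = int(np.prod(shape))
    # out = mean(c): each c_ij receives upstream / size
    g_c = np.full(shape, upstream / size, dtype=np.float64)
    # c = 0.5 * b: scale by the constant (b itself is not needed)
    g_b = 0.5 * g_c
    # b = a + mean(a): identity path plus the mean path broadcast back
    g_a = g_b + g_b.sum() / size
    # a = tile(t): sum the gradient over all copies of t
    return float(g_a.sum())


# Chain three layers (n = 3) with a small shape to keep the sketch fast;
# analytically each layer is the identity on t, so the gradient stays 1.
g = 1.0
for _ in range(3):
    g = massive_layer_backward(g, (2, 3, 4))
print(g)
```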

shsshs commented 4 years ago

I have the same problem with TensorFlow 1.15 and 1.14.

The behavior is pretty inconsistent between versions; e.g. installing via pip does not work (I think this is due to how conda handles CUDA).

To reproduce: use DenseNet from Keras applications and train a classifier.

Setup a)

conda install tensorflow-gpu=1
conda install keras-gpu

This works, but the memory-saving gradients only give a 3-4x improvement (batch size 4 still runs out of memory at my image size, while batch size 1 works without memory-saving gradients).

Setup b) If you do not install Keras and switch all imports to tensorflow.keras, you get the error reported by the OP. If you change

import memory_saving_gradients as gc
from tensorflow.python.ops import gradients as tf_gradients
tf_gradients.gradients = gc.gradients_speed

to

import tensorflow as tf
import memory_saving_gradients
# monkey patch tf.gradients to point to our custom version, with automatic checkpoint selection
tf.__dict__["gradients"] = memory_saving_gradients.gradients_speed

there are no errors, but there is also no effect on memory usage. This probably means the patched function is not used.
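One plausible reason (an assumption; the thread does not confirm it) is that code which obtained the original gradients function before the patch keeps its own reference to it, so rebinding tf.__dict__["gradients"] only changes the name on the tf module. The plain-Python sketch below, using a hypothetical stand-in module `lib`, shows this general name-binding behavior.

```python
import types

# Hypothetical stand-in for a library module that defines `gradients`.
lib = types.ModuleType("lib")
lib.gradients = lambda: "original"

# A consumer that did `from lib import gradients` at import time holds
# its own reference to the original function object.
consumer_gradients = lib.gradients

# Monkey patching rebinds the attribute on the module only.
lib.gradients = lambda: "patched"

print(lib.gradients())       # patched
print(consumer_gradients())  # original
```

If that is what happens here, the patch has to be applied to the name the training code actually looks up, or before that code is imported.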

shsshs commented 4 years ago

@yaroslavvb could you have a look and see what could be done? (I don't want to switch back to standalone Keras, since it will stop being updated in April 2020.)

yaroslavvb commented 4 years ago

@shsshs sorry, I haven't kept up with TF backend changes (I mostly stay in PyTorch-land nowadays), but I'll be happy to merge any PRs that fix it.

purpleyun commented 4 years ago

It seems that changing “/read” to “/Read” on lines 90 and 92 of memory_saving_gradients.py works. Tested on tensorflow.python.keras.applications.InceptionV3 with TF 1.14.

    fwd_ops = [op for op in fwd_ops if not '/Read' in op.name]
    ts_all = ge.filter_ts(fwd_ops, True) # get the tensors
    ts_all = [t for t in ts_all if '/Read' not in t.name]
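A variant of this fix could check both spellings, assuming a graph may contain classic (ref) variables, whose read op is an Identity named ".../read", as well as resource variables, whose reads show up as "/Read..." or "/ReadVariableOp..." in op names. The helper `is_variable_read` is hypothetical, and a namedtuple stands in for TensorFlow ops; the same check works on tensor names.

```python
from collections import namedtuple

# "/read": Identity op of a classic (ref) variable.
# "/Read": resource-variable reads, as in the fix above.
READ_MARKERS = ("/read", "/Read")


def is_variable_read(name, markers=READ_MARKERS):
    """True if an op or tensor name looks like a variable read (heuristic)."""
    return any(marker in name for marker in markers)


FakeOp = namedtuple("FakeOp", ["name"])
fwd_ops = [
    FakeOp("dense/kernel/read"),
    FakeOp("bn/cond/ReadVariableOp_1"),
    FakeOp("dense/MatMul"),
]
fwd_ops = [op for op in fwd_ops if not is_variable_read(op.name)]
print([op.name for op in fwd_ops])  # ['dense/MatMul']
```

Like the original filter, this still matches on names, so the type-based filtering suggested earlier in the thread would be the more robust option.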