davisyoshida / tf2-gradient-checkpointing

Simple gradient checkpointing for eager mode execution
MIT License
46 stars, 7 forks

Gradient checkpointing usage #1

Open pidajay opened 4 years ago

pidajay commented 4 years ago

Hi @davisyoshida. I was looking into implementing my own version of gradient checkpointing in TF when I stumbled upon your repo. I tried to test your implementation but I was running into out of memory errors. Just wondering if I was using it as intended. Here is the code snippet. This is with the TF 2.2 nightly build.

import tensorflow as tf
from tensorflow.keras import layers, optimizers
from checkpointing import checkpointable  # decorator from this repo's checkpointing.py

# mnist_dataset and compute_loss are user-defined helpers not shown here

@checkpointable
def get_model():
    model = tf.keras.Sequential()
    model.add(layers.Reshape(target_shape=(28 * 28,), input_shape=(28, 28)))
    for i in range(8):  # 9 pushes to OOM
        model.add(layers.Dense(10000, activation='relu'))
    model.add(layers.Dense(10))
    return model

def testRecompute(bs):
    optimizer = optimizers.Adam()
    model = get_model()
    train_ds = mnist_dataset(bs)
    for step, (x, y) in enumerate(train_ds):
        with tf.GradientTape() as tape:
            logits = model(x)
            loss = compute_loss(logits, y)
            print('loss', loss)
        # compute gradient tf 2.x style
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

Kokonut133 commented 4 years ago

Same question, formulated differently. Neither version below seems to slow down training, and GPU memory usage doesn't seem to have changed. Please let us know how to apply this.


def conv2d(input, filters, batch_norm, k_size=4):
    d = checkpointable(Conv2D(filters, kernel_size=k_size, strides=2, padding="same"))(input)
    return d

or

@checkpointable
def conv2d(input, filters, batch_norm, k_size=4):
    d = Conv2D(filters, kernel_size=k_size, strides=2, padding="same")(input)
    return d
davisyoshida commented 4 years ago

@pidajay The issue here is that get_model is just called once. You want to apply the @checkpointable decorator to a callable that will actually be run during the training loop.

To elaborate on that and help with @Kokonut133's issue, what you should do is apply the decorator to a callable that has a decent amount of internal state you want to avoid storing.

Let's say I have a layer that takes in X and outputs Y, but doesn't really have any internal activations (I believe Conv2D falls into this category). Then all that applying the @checkpointable decorator will do is save X and rerun the layer during backprop, so we don't get any memory savings.

If I have 100 layers, though, with inputs X1, X2, ..., X100 and outputs Y1, Y2, ..., Y100, you can break them up into blocks of 10 so that Y10 = block_of_layers(X1). Then, applying @checkpointable to block_of_layers will save X1 and throw away the intermediate activations Y1 through Y9 (which are the same tensors as X2 through X10). This is where the memory savings come from.
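To make the block argument concrete, here is a hypothetical, framework-free sketch (not the repo's code, and the function names and weights are made up): a chain of 100 scalar layers y = tanh(w * x), differentiated once with ordinary backprop and once with blocks of 10 whose inputs are the only values saved on the forward pass. Both should give the same gradient, while the checkpointed version holds at most 20 activations at once instead of 100.

```python
import math

# Hypothetical toy (not the repo's code): 100 scalar "layers", y = tanh(w * x).
# Plain Python, no TensorFlow, so every saved activation is easy to count.

def layer(w, x):
    return math.tanh(w * x)

def layer_grad(w, x):
    # d/dx tanh(w * x) = w * (1 - tanh(w * x)**2)
    y = math.tanh(w * x)
    return w * (1.0 - y * y)

def grad_store_all(weights, x):
    """Ordinary backprop: keep the input of every layer from the forward pass."""
    inputs = []
    for w in weights:
        inputs.append(x)
        x = layer(w, x)
    g = 1.0
    for w, x_in in zip(reversed(weights), reversed(inputs)):
        g *= layer_grad(w, x_in)
    return g, len(inputs)

def grad_checkpointed(weights, x, block_size):
    """Block checkpointing: keep only each block's input (X1, X11, ...) and
    recompute the block's intermediate activations during backprop."""
    blocks = [weights[i:i + block_size] for i in range(0, len(weights), block_size)]
    block_inputs = []
    for block in blocks:
        block_inputs.append(x)
        for w in block:
            x = layer(w, x)          # intermediates inside the block are dropped
    g, peak = 1.0, 0
    for block, x_in in zip(reversed(blocks), reversed(block_inputs)):
        inner = []                   # recomputed inputs of this block's layers
        for w in block:
            inner.append(x_in)
            x_in = layer(w, x_in)
        # saved block inputs + the one block currently being recomputed
        peak = max(peak, len(block_inputs) + len(inner))
        for w, xi in zip(reversed(block), reversed(inner)):
            g *= layer_grad(w, xi)
    return g, peak

weights = [0.9 + 0.001 * i for i in range(100)]
g_full, stored_full = grad_store_all(weights, 0.5)
g_ckpt, stored_ckpt = grad_checkpointed(weights, 0.5, block_size=10)
print(stored_full, stored_ckpt)      # 100 20
print(math.isclose(g_full, g_ckpt))  # True
```

The recomputation repeats exactly the same floating-point operations, so the two gradients agree; the price is one extra forward evaluation per layer.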

Please let me know if I can clarify this further.

pidajay commented 4 years ago

@davisyoshida Thanks for the clarification. I figured something like that needed to be done. I implemented a version where the user just has to specify a single decorator (like in my sample above), and a closure grad function then does the recompute for every layer during backprop (no activations stored). I'll try to create a TF PR next week. Let's see how it goes. I will keep you in the loop.

Kokonut133 commented 4 years ago

@pidajay I would be happy to see your implementation of this. If you could share a link, that would be great.

@davisyoshida So, as I understand it, I applied it correctly in both examples? I would prefer to use the bottom version. But it won't lead to memory savings, because Conv2D layers have few internal activations to store? Do you know how to reduce memory usage with Conv2D layers? Thanks in advance.

davisyoshida commented 4 years ago

@pidajay For graph mode, it's definitely possible to do with a one-off decorator, but in eager mode it wasn't clear to me that there's a nice way to know which outputs to save. If you discard all layer outputs during the forward pass, you won't actually end up saving any memory during the backward pass.

davisyoshida commented 4 years ago

@Kokonut133 There are a few things to fix. First, you need to pass _checkpoint=True to the decorated callable (so my_decorated_func(inputs, _checkpoint=True)). Second, for functions using variables that aren't passed as parameters, you'll need to pass the argument _watch_vars=list_of_variables. When using a Keras model, this is as easy as _watch_vars=self.variables.

In the examples you gave, the constructed Conv2D layer isn't saved, so there's no reference that lets us access its variables. I'd recommend something like the following (by the way, this could be done a bit more cleanly in the context of a Keras Model):

First construct your layers:

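from tensorflow.keras.layers import Conv2D
from checkpointing import checkpointable  # this repo's checkpointing.py

filters, k_size = 64, 4  # example values; use your own

layer_blocks = []
for i in range(10):
    block = []
    for j in range(10):
        block.append(Conv2D(filters, kernel_size=k_size, strides=2, padding="same"))
    layer_blocks.append(block)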

Then make a function which will execute the block:

@checkpointable
def run_block(block, inputs):
    layer_output = inputs
    for layer in block:
        layer_output = layer(layer_output)
    return layer_output

Now you can execute your whole network:

block_output = inputs
for block in layer_blocks:
    # In a keras Model, we could just use self.trainable_variables
    watch = [v for layer in block for v in layer.trainable_variables]
    block_output = run_block(block, block_output, _checkpoint=True, _watch_vars=watch)

pidajay commented 4 years ago

@davisyoshida From what I see, it is actually possible in eager mode, but a bit more involved. You just have to recompute outside the gradient tape so nothing gets saved. In more detail: during the backward pass you recompute the current layer's and the previous layer's outputs, calculate the gradients, and feed them as output gradients to the next layer in the backward pass. Essentially you are doing full backprop by yourself. As I said, it's a bit more involved, but it spares the user from having to partition the model manually. I'll try my best to find some time to push my code by the end of this week, and I'll keep @Kokonut133 in the loop as well. Thanks!

davisyoshida commented 4 years ago

@pidajay My concern is about knowing what to save during the forward pass. For example, if I have a 100-layer network, I can save every 10th layer. If I have a 900-layer network, I can save every 30th. The problem is that in eager mode you can't tell which of those two situations you're in, so you won't know which layers should be saved and which should be dropped. Alternatively, you could do a first forward pass that saves nothing (just to find out the depth), then a second forward pass where you save outputs, but that will take three forward passes instead of the two that gradient checkpointing usually takes.

On the other hand, if you mean recomputing the previous layer's output by running the network from the beginning, the runtime will be quadratic in the depth of the network instead of linear, which is likely not desirable.
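The costs in the last two comments can be put into rough numbers. The sketch below is simple bookkeeping under assumed unit costs (one layer evaluation is one unit, one stored layer output is one activation), and the helper names are hypothetical. It reproduces the spacing rule above (every 10th layer for 100 layers, every 30th for 900) and shows how recomputing from the start grows quadratically while segment checkpointing stays at about two forward passes.

```python
import math

# Hypothetical bookkeeping helpers; the unit costs are assumptions, not measurements.

def spacing_for(n_layers):
    """Checkpoint spacing s minimizing n/s + s (saved checkpoints plus one
    recomputed segment); the minimum is at s = sqrt(n)."""
    return max(1, math.isqrt(n_layers))

def peak_activations(n_layers, spacing):
    # checkpoints kept for the whole backward pass + one segment being recomputed
    return math.ceil(n_layers / spacing) + spacing

def forward_layer_evals(n_layers, strategy):
    """Single-layer forward evaluations per training step (backward work excluded)."""
    if strategy == "store_all":
        return n_layers                    # one forward pass, keep everything
    if strategy == "checkpointed":
        return 2 * n_layers                # forward pass + each segment recomputed once (approx.)
    if strategy == "count_depth_first":
        return 3 * n_layers                # an extra pass just to learn the depth
    if strategy == "recompute_from_start":
        # layer k's input is rebuilt from scratch: 0 + 1 + ... + (n - 1) extra evals
        return n_layers + n_layers * (n_layers - 1) // 2
    raise ValueError(f"unknown strategy: {strategy}")

for n in (100, 900):
    s = spacing_for(n)
    print(n, s, peak_activations(n, s),
          forward_layer_evals(n, "checkpointed"),
          forward_layer_evals(n, "recompute_from_start"))
# 100 10 20 200 5050
# 900 30 60 1800 405450
```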

Kokonut133 commented 4 years ago

@davisyoshida

So I tried implementing your solution and got to this:

    def build_discriminator(self, input_shape):
        @checkpointable
        def discriminator_layer(input, filters, batch_norm, f_size=4, _checkpoint=True):
            d = Conv2D(filters, kernel_size=f_size, strides=2, padding="same", activation="relu")(input)
            return d

        img = Input(shape=input_shape)
        n = input_shape[0]

        d1 = discriminator_layer(input=img, filters=n, batch_norm=False)
        d2 = discriminator_layer(input=d1, filters=n * 2, batch_norm=True)
        d3 = discriminator_layer(input=d2, filters=n * 4, batch_norm=True)
        d4 = discriminator_layer(input=d3, filters=n * 8, batch_norm=True)

        validity = Conv2D(1, kernel_size=4, strides=1, padding="same")(d4)

        layers = [d1, d2, d3, d4]
        watch = [v for layer in layers for v in layer.graph.trainable_variables]
        return Model([img_A, img_B], validity, _watch_vars=watch)

checkpointable and _checkpoint should be in place now, but I can't figure out where to put the trainable-variables argument. I can't refer to self inside this, since the entire function that builds the network is nested in another class. I'm sorry if I'm asking about basic Python, but it would be quite helpful.

davisyoshida commented 4 years ago

@Kokonut133 I think this comes down to a slight misunderstanding of how the decorator works. The important thing is that you want to decorate a function that will be called during the training loop. However, the discriminator_layer function here is only called once, during model construction.

I'll illustrate this below, using a single Conv2D layer as an example, but just to reiterate, there's no point in checkpointing a single layer, since you won't actually be saving any memory by doing so.

1) Start with an undecorated callable that you'll call on every training iteration (e.g. a function or a Keras Layer):
   my_layer = Conv2D(...)
2) Apply the decorator:
   my_decorated_callable = checkpointable(my_layer)
   NOTE: This adds the _checkpoint and _watch_vars arguments to the callable. Specifying them yourself in the definition of the original won't do anything, as they'll be ignored.
3) Call the decorated callable on every training iteration:

for batch in my_dataset:
    output = my_decorated_callable(
        batch,
        _checkpoint=True,
        _watch_vars=my_layer.trainable_variables
    )
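The NOTE in step 2 can be illustrated without TensorFlow. The sketch below is hypothetical: checkpointable_sketch and block are made-up names, and the recomputation that the real decorator in checkpointing.py performs is omitted. It only shows the Python mechanism: a wrapper that declares _checkpoint and _watch_vars as its own keyword-only arguments captures them, so a default you put in the wrapped function's signature is never consulted. The default of _checkpoint=False here is an assumption, consistent with having to pass _checkpoint=True explicitly.

```python
import functools

def checkpointable_sketch(fn):
    """Hypothetical stand-in for the repo's decorator (recomputation omitted)."""
    @functools.wraps(fn)
    def wrapper(*args, _checkpoint=False, _watch_vars=None, **kwargs):
        # The wrapper's keyword-only arguments capture these two flags, so the
        # wrapped function never receives them.
        wrapper.calls.append((_checkpoint, list(_watch_vars or [])))
        # A real implementation would, when _checkpoint is True, run fn without
        # keeping its intermediate activations and recompute it during the
        # backward pass, differentiating w.r.t. the inputs and _watch_vars.
        return fn(*args, **kwargs)
    wrapper.calls = []
    return wrapper

@checkpointable_sketch
def block(x, _checkpoint=True):   # this default is never consulted by the wrapper
    return 2 * x

block(3)                                              # wrapper sees _checkpoint=False
block(3, _checkpoint=True, _watch_vars=["w1", "w2"])
print(block.calls)  # [(False, []), (True, ['w1', 'w2'])]
```

So the flag has to be passed at the call site, every iteration, as in step 3.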

Now if you want to do this with blocks of layers, look back at my earlier comment. The important thing is that I've decorated the function that executes the block of layers, not the part of the code that creates those layers to begin with.

By the way, since this just grew out of something I was using for personal projects, I'm certainly open to suggestions for improvements to both usability and documentation.

pidajay commented 4 years ago

@davisyoshida and @Kokonut133, I have created my pull request here: https://github.com/tensorflow/addons/pull/1600. @davisyoshida, maybe this answers the questions you had for me. Let me know. Note: I have only implemented the recompute functionality so far; I have yet to do the checkpointing, but with this approach I don't think that should be complicated. Anyway, a huge thanks for creating this repo! It helped clarify a lot of things for me. I have added this repo as a reference in my PR.

Kokonut133 commented 4 years ago

OK. I think I understand more now, but I'm still not sure exactly how to apply it.

I have a Model (a TF model) that I train with a custom train function. Inside this train function, I call self.model.train_on_batch(x, y, z) (a Keras function). Do I have to add @checkpointable before the function's definition in the file where train_on_batch is defined, or is there a way to locally overwrite train_on_batch with a checkpointable version?

I tried my_model.train_on_batch() = checkpointable(my_model.train_on_batch()) unsuccessfully. Thank you for your help so far.

davisyoshida commented 4 years ago

@Kokonut133 To use the built-in Keras training options, the best thing to do would be to make a custom Layer that is actually 10 convolutional layers bundled together, or something similar. Have a function that executes all 10 layers in a row, and mark that function as checkpointable. In the custom Layer's call function, call the function that executes the sub-layers, and pass _checkpoint=True and _watch_vars=[v for sublayer in self.my_conv_layers for v in sublayer.trainable_variables].

Then, you can make a Sequential model from 10 of those compound layers (for a total of 100 conv layers).

You definitely don't need to edit any Keras code.

davisyoshida commented 4 years ago

@pidajay What's the policy it uses for selecting which layer outputs to save?

pidajay commented 4 years ago

@davisyoshida The current implementation does not do checkpointing yet, i.e. save the outputs of specific layers. It just recomputes every layer during the backward pass. I am working on the saving part, but the first version should be straightforward. The user passes in the number of checkpoints (say num_checkpoints=10 for a 100-layer network). I space the checkpoints evenly across the network (so every 10th layer is saved). During the backward pass I just recompute from the closest checkpoint, so when I am at layer 77, I recompute from layer 70.
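The evenly spaced scheme described above can be sketched in a few lines. This is a hypothetical illustration, not the code from the PR: the names are made up, layers are 0-indexed, and a checkpoint at position c is assumed to store the input of layer c. The thread does not say whether recomputed activations are cached within a segment, so total_recompute shows the uncached case, where each layer's input is rebuilt independently from its checkpoint.

```python
import bisect

# Hypothetical sketch of evenly spaced checkpoints; not the PR's implementation.

def checkpoint_positions(num_layers, num_checkpoints):
    """Evenly spaced layers (0-indexed) whose inputs are saved on the forward pass."""
    spacing = max(1, num_layers // num_checkpoints)
    return list(range(0, num_layers, spacing))

def nearest_checkpoint(layer, positions):
    """Closest saved position at or before `layer` (positions sorted, starting at 0)."""
    return positions[bisect.bisect_right(positions, layer) - 1]

def total_recompute(num_layers, num_checkpoints):
    """Layer evaluations if each layer's input is rebuilt independently from its
    nearest checkpoint during the backward pass (no caching within a segment)."""
    positions = checkpoint_positions(num_layers, num_checkpoints)
    return sum(layer - nearest_checkpoint(layer, positions) for layer in range(num_layers))

positions = checkpoint_positions(100, 10)
print(positions)                          # [0, 10, 20, 30, 40, 50, 60, 70, 80, 90]
print(nearest_checkpoint(77, positions))  # 70
print(total_recompute(100, 10))           # 450
```

Without caching, each segment of 10 costs 0 + 1 + ... + 9 = 45 layer evaluations; recomputing a segment once and keeping its activations until the backward pass leaves it would bring that down to about 10 per segment, at the cost of holding one segment's activations at a time.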

davisyoshida commented 4 years ago

Ah, if it's only going to be for Keras models, that's kind of a bummer. If this is going to end up being the official TensorFlow solution, it would be nice to have something more fine-grained, if possible. Have you looked at the PyTorch checkpoint function?

pidajay commented 4 years ago

The idea is to support various flavors of Keras (Sequential, functional, etc.), provided there is enough traction. I don't see the point of moving outside Keras, though, when TF as a whole is moving toward Keras.

davisyoshida commented 4 years ago

Well, in general, it's much easier to make a Keras wrapper around a feature that isn't Keras-specific than to take something Keras-based and use it when you need to do something lower-level.

hermesonbf commented 3 years ago

I tried to add the checkpointing to my GAN as follows:

gene_wrapped = checkpointable(gene) # gene is the generator Model
disc_wrapped = checkpointable(disc) # disc is the discriminator Model

This runs inside the training loop:

...
with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
        g_tape.watch(gene.trainable_variables)
        d_tape.watch(disc.trainable_variables)

        fake_images = gene_wrapped(noise_images, training=True, _checkpoint=True, _watch_vars=gene.trainable_variables)

        real_logits = disc_wrapped(real_images, training=True, _checkpoint=True, _watch_vars=disc.trainable_variables)
        fake_logits = disc_wrapped(fake_images, training=True, _checkpoint=True, _watch_vars=disc.trainable_variables)

        g_loss = architecture.gene_loss(real_logits, fake_logits)
        d_loss = architecture.disc_loss(real_logits, fake_logits)
...

but there's no memory improvement at all :( Am I doing something wrong? @davisyoshida

davisyoshida commented 3 years ago

@left-brain Can you specify how you're benchmarking memory use?

hermesonbf commented 3 years ago

By the way, my IDE says: "Unused variable 'watch_args'" on this line: https://github.com/davisyoshida/tf2-gradient-checkpointing/blob/544d53f7711acc714caaa5482a2e7c187254e5d5/checkpointing.py#L24 is this an issue?

hermesonbf commented 3 years ago

Actually, for some reason I'm getting the following error: "No gradients provided for any variable" when running the code above...

davisyoshida commented 3 years ago

@left-brain The unused watch_args was just a mistake since I didn't delete that after an earlier refactor. As you can see the input args are watched on this line: https://github.com/davisyoshida/tf2-gradient-checkpointing/blob/544d53f7711acc714caaa5482a2e7c187254e5d5/checkpointing.py#L36

As for not getting any gradients, does this only occur when you are using the checkpointing decorator? I haven't seen that happen before, including when using the decorator with Keras models.

nyngwang commented 2 years ago

@davisyoshida Does this repo work in graph mode?

davisyoshida commented 2 years ago

@nyngwang You can just use tensorflow's built in stuff for graph mode. I made this because, at the time, eager mode wasn't supported. (I'm not sure about the situation now, since I haven't been using TensorFlow since I switched to JAX.)

nyngwang commented 2 years ago

> You can just use tensorflow's built in stuff for graph mode.

@davisyoshida Do you mean the provided @tf.recompute_grad?


And now I'm facing a situation where I need some help from experts: I cannot tell whether these APIs (the official tf.recompute_grad, or the GitHub implementations by you, @pidajay, someone from Google (@ppham27), someone who extracted parts of tf-slim (@mathemakitten), etc.) really save GPU memory as intended, because I don't know the right tool to measure the difference.

These are my questions (all assuming TensorFlow 2.x, not 1.x):

  1. How do I properly profile TensorFlow's GPU memory usage? I tried pidajay's method, which uses pythonprofilers/memory_profiler, but that project is unmaintained and buggy (it can report negative memory; see the linked issue). From a reply by its author, it seems to profile CPU memory only (and pidajay also said, in a TensorFlow PR, that he/she hadn't run the experiment on GPU/TPU; see item 2 at the link).
  2. So instead, I decided to use TensorBoard to profile my GPU memory usage. Following the idea from pidajay's ipynb tutorial, each block/segment to be recomputed (containing many existing tf.keras.layers) should be wrapped in a tf.keras.Sequential for the current @tf.recompute_grad API to work. Then I tried the implementations from the GitHub repos mentioned above (authors and links above) for the API part. Below are two TensorBoard reports that seem to show the official API is working, but I cannot confirm it:

    • Fig. 1: @tf.recompute_grad is not applied. The report has more rows than Fig. 2.
    • Fig. 2: @tf.recompute_grad is applied. However, I don't know whether the Conv2DBackpropFilter entries represent the call signatures and parameters saved during the forward pass for recomputation in the backward pass.
[Fig. 1: TensorBoard report without @tf.recompute_grad]

[Fig. 2: TensorBoard report with @tf.recompute_grad]

davisyoshida commented 2 years ago

@nyngwang I can't really say I'm an expert on GPU profiling, but the way I knew whether these things were or weren't working was whether or not I could use them to run stuff for models which should fit in memory with checkpointing, but fail without. For example, if you use 100 blocks of 100 layers (with shared weights), you should run out of GPU memory if the layers are large enough, but checkpointing each block of 100 layers should cut the required activation memory (roughly 100 saved block inputs plus the 100 activations of the block being recomputed, instead of all 10,000 layer outputs) enough to train such a model in a relatively small amount of memory.
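The "does it fit with checkpointing but not without" test can be imitated on a small CPU-only toy. This is a hypothetical sketch, not a GPU profiler: it uses Python's standard tracemalloc module, which only tracks Python heap allocations, to compare the peak memory of the same toy gradient computation (100 elementwise tanh "layers" over lists of 1,000 floats, sizes chosen arbitrarily) when every layer input is saved (segment=1, which behaves like ordinary backprop memory-wise) versus only every 10th.

```python
import math
import tracemalloc

# Hypothetical CPU-only toy; tracemalloc sees Python allocations, not GPU memory.
WIDTH, DEPTH = 1000, 100   # arbitrary toy sizes

def layer(x):
    # one toy "layer": elementwise tanh over a list of floats
    return [math.tanh(v) for v in x]

def backprop_segment(g, x_in, n_layers):
    # Recompute the segment's outputs from its saved input, then apply the
    # chain rule in reverse: d tanh(v)/dv = 1 - tanh(v)**2.
    outs = []
    for _ in range(n_layers):
        x_in = layer(x_in)
        outs.append(x_in)
    for y in reversed(outs):
        g = [gi * (1.0 - yi * yi) for gi, yi in zip(g, y)]
    return g

def grad_of_sum(x, segment):
    """Gradient of sum(output) w.r.t. x, saving every `segment`-th layer input."""
    saved = []
    for i in range(DEPTH):
        if i % segment == 0:
            saved.append(x)
        x = layer(x)
    g = [1.0] * len(x)
    for start in reversed(range(0, DEPTH, segment)):
        g = backprop_segment(g, saved[start // segment], min(segment, DEPTH - start))
    return g

def with_peak_memory(fn, *args):
    tracemalloc.start()
    try:
        result = fn(*args)
        _, peak = tracemalloc.get_traced_memory()
    finally:
        tracemalloc.stop()
    return result, peak

x0 = [0.001 * i for i in range(WIDTH)]
g_full, peak_full = with_peak_memory(grad_of_sum, x0, 1)
g_ckpt, peak_ckpt = with_peak_memory(grad_of_sum, x0, 10)
print(f"peak, every input saved: {peak_full / 1e6:.1f} MB")
print(f"peak, every 10th saved:  {peak_ckpt / 1e6:.1f} MB")
```

Both runs return the same gradient; only the peak differs, which is the kind of gap you would hope to see, on a much larger scale, on the GPU.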

nyngwang commented 2 years ago

@davisyoshida
> [...] I knew whether these things were or weren't working was whether or not I could use them to run stuff for models which should fit in memory with checkpointing, but fail without. [...]

I did. Unfortunately, it seems that both cases, i.e. with and without checkpointing, use the same amount of GPU memory for me. (I might be applying the related APIs incorrectly, so I also submitted an issue to the official repo.)