davisyoshida / tf2-gradient-checkpointing

Simple gradient checkpointing for eager mode execution
MIT License

Inconsistent gradient values with gradient checkpointing #2

Closed: WingsOfPanda closed this issue 3 years ago

WingsOfPanda commented 3 years ago

Hi @davisyoshida. I am trying to use this gradient-checkpointing technique so that I can fit a larger input into my model. Here is the toy example I tried:

from gradient_checkpointing.checkpointing import checkpointable
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
from tqdm import tqdm

gpus = tf.config.experimental.list_physical_devices('GPU')

for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

def forward_conv(x, filters, kernels, name='forward', padding='same'):
    i = 0
    for flt, kernel in zip(filters, kernels):
        x = layers.Conv3D(flt, kernel, activation='relu', padding=padding, dilation_rate=(1, 1, 1),
                          use_bias=False, name=str(i) + '_' + name)(x)
        x = layers.BatchNormalization(name=str(i) + '_bn_' + name)(x)
        i += 1
    return x

@checkpointable
def part_one(ipt):
    l1 = forward_conv(ipt, (4,), (3,), name='enc1')
    d2 = layers.MaxPool3D(pool_size=(2, 2, 2))(l1)
    l2 = forward_conv(d2, (4,), (3,), name='enc2')
    return l1, l2

@checkpointable
def part_two(ipt1, ipt2):
    l2 = forward_conv(ipt2, (4,), (3,), name='dec2')
    u1 = layers.UpSampling3D(size=[2, 2, 2])(l2)
    r1 = forward_conv(ipt1 + u1, (4,), (3,), name='dec1')
    return r1

initial = tf.ones([1, 16, 16, 16, 1], dtype=tf.float32)

tf.random.set_seed(1)
ckpt = False
with tf.GradientTape() as g:
    g.watch(initial)
    l1_, l2_ = part_one(initial, _checkpoint=ckpt)
    opt_ = part_two(l1_, l2_, _checkpoint=ckpt)
    loss = tf.reduce_mean(opt_)
gd = g.gradient(loss, initial)
print(f'loss is {loss} and grad is {np.sum(gd)} with ckpt = {ckpt}')

The example gives me weird results: the gradient changes (quite a lot) when I switch _checkpoint between True and False, and I don't see any reduction in memory usage either.

Is there anything wrong with my implementation? Any help would be really welcome!

davisyoshida commented 3 years ago

Looks like the issue is that you don't pass the _force_seed parameter to the decorated calls. Modifying the code as follows produces equal gradients whether you set ckpt to True or False.

from itertools import repeat
...
ckpt = False
seeds = repeat(0)
with tf.GradientTape() as g:
    g.watch(initial)
    l1_, l2_ = part_one(initial, _checkpoint=ckpt, _force_seed=seeds)
    opt_ = part_two(l1_, l2_, _checkpoint=ckpt, _force_seed=seeds)
    loss = tf.reduce_mean(opt_)
gd = g.gradient(loss, initial)
print(f'loss is {loss} and grad is {np.sum(gd)} with ckpt = {ckpt}')

You can also just pass True to _force_seed, and it'll draw a random seed for each call from Python's RNG. It's not a great solution, so I'm open to suggestions, tbh.

As for the memory, are you just looking at how much TensorFlow allocates? I don't know the policy it uses for memory growth, so I'm not sure you'll always see a measurable gain using that as the metric. For example, if TF doesn't free the activations from part_one before allocating the ones for part_two, I imagine the peak memory would be unchanged.
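A checkpointed function is run twice: once in the forward pass and once more when its activations are recomputed during the backward pass. Because of this, any randomness inside it must repeat exactly. In the toy example above, part_one and part_two construct new Conv3D layers on every call, so their random initial weights are presumably drawn again on recomputation unless the seed is pinned.

The plain-Python toy model below illustrates this. It is not the library's implementation. block and forward_and_recompute are hypothetical names. The seeds argument only mimics passing an iterator such as itertools.repeat(0) to _force_seed.

```python
import random
from itertools import repeat


def block(x, rng):
    # Hypothetical block that draws a fresh "weight" every time it runs,
    # like a function that constructs new randomly initialised layers.
    w = rng.uniform(-1.0, 1.0)
    return w * x


def forward_and_recompute(x, seeds=None):
    """Toy model of checkpointing: run the block once for the forward pass,
    then run it again as the backward pass would when recomputing.

    `seeds` mimics an iterator like itertools.repeat(0): each run takes the
    next seed from it. With seeds=None every run is unseeded.
    """
    seed_iter = iter(seeds) if seeds is not None else None

    def new_rng():
        if seed_iter is None:
            return random.Random()  # seeded from OS entropy
        return random.Random(next(seed_iter))

    forward = block(x, new_rng())
    recomputed = block(x, new_rng())
    return forward, recomputed


unseeded = forward_and_recompute(2.0)           # the two runs almost surely differ
seeded = forward_and_recompute(2.0, repeat(0))  # both runs use seed 0
print(seeded[0] == seeded[1])                   # True
```

If the recomputed forward pass differs from the original one, the gradients are computed for a different function than the one that produced the loss. That mismatch is consistent with what the original report describes. An alternative design is to create the layers once, outside the checkpointed function, so that recomputation reuses the same weights.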

WingsOfPanda commented 3 years ago

Hi @davisyoshida, thank you for your reply! Yes, I managed to get identical results by passing a constant seed to _force_seed.

BTW, is there any reason you only allow float32 tensors to be differentiated? Can I use this with mixed precision (float16)? Thank you!

davisyoshida commented 3 years ago

I don't think anything should break, but I haven't tested with mixed precision stuff. You can just modify the relevant list comprehensions to enable that.

WingsOfPanda commented 3 years ago

Thanks. I think modifying the following line in your code should do the trick?

flat_inputs = [x for x in flat_inputs if x.dtype == tf.float32 or x.dtype == tf.float16]

This is line 29 in your code.

davisyoshida commented 3 years ago

That's right!
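Listing dtypes one by one works. A slightly more general filter keeps any floating-point dtype instead. TensorFlow's tf.DType has an is_floating property, which is True for float16, bfloat16, float32 and float64, so x.dtype.is_floating would also cover bfloat16.

The sketch below uses small stand-in objects so it runs without TensorFlow. floating_only is a hypothetical helper, not part of the library, and this has not been tested against the library's code.

```python
from types import SimpleNamespace


def floating_only(flat_inputs):
    """Keep inputs whose dtype is any floating-point type.

    With TensorFlow tensors, x.dtype.is_floating is True for float16,
    bfloat16, float32 and float64. Objects without a dtype attribute
    (e.g. Python strings or bools) are dropped rather than raising.
    """
    return [x for x in flat_inputs
            if getattr(getattr(x, "dtype", None), "is_floating", False)]


# Stand-ins for tensors so the sketch runs without TensorFlow installed.
f16 = SimpleNamespace(dtype=SimpleNamespace(is_floating=True, name="float16"))
i32 = SimpleNamespace(dtype=SimpleNamespace(is_floating=False, name="int32"))
print([x.dtype.name for x in floating_only([f16, i32, "a", True])])  # ['float16']
```

Inside the library, the equivalent one-line change would presumably be: flat_inputs = [x for x in flat_inputs if x.dtype.is_floating]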

WingsOfPanda commented 3 years ago

@davisyoshida Actually, I have one more question. When we compute gradients, we need them for all trainable variables so that we can update those variables. However, in the snippet from your earlier reply (the one using _force_seed=seeds), we only compute the gradient for initial, since that is the tensor we watch. If we actually want to train this network, we need the gradients for all trainable variables, right? But how do I obtain those variables? Thank you!

davisyoshida commented 3 years ago

There's an example in the README:

layer = SomeKerasLayer()
wrapped_layer = checkpointable(layer)

with tf.GradientTape() as g:
    g.watch(layer.trainable_variables)
    output = wrapped_layer(*args, **kwargs, _checkpoint=True, _watch_vars=layer.trainable_variables)
print(g.gradient(output, layer.trainable_variables))

WingsOfPanda commented 3 years ago

@davisyoshida Yes, I saw this example, and that is actually the part that confused me. My question is: what exactly is 'SomeKerasLayer()'? Is it a layer like 'Conv3D()' or 'MaxPool3D()', or can it be the 'part_one' or 'part_two' function I defined above?

I'm a bit confused now because I usually use keras.Model() to create, train and save my model. However, it looks to me like I now need to explicitly write my layers inside tf.GradientTape(). So how do I save my model later? ><

Sorry for all the questions. If you could provide a working example that shows a complete training process, that would be very helpful. Many thanks!

BTW, I am working on U-Net-like models. They have many skip connections, so they are hard to cut into small sub-models the way Sequential models can be...

davisyoshida commented 3 years ago

Ah, so this is more of a TF question than a question about my script. I would recommend creating a subclass of keras.layers.Layer or keras.Model outside of your training loop. There are a bunch of convenience functions for saving and loading weights. It also exposes the trainable_variables attribute, which gives you all the variables declared in __init__, including those from any sublayers, their sublayers, and so on.
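The recursive behaviour of trainable_variables is what lets one top-level model hand every variable, at any nesting depth, to the tape and the optimizer.

The pure-Python toy below mimics that behaviour. It is not Keras's code. ToyLayer, Block and UNetLike are hypothetical names. The toy registers sublayers when they are assigned to attributes (directly or inside a list) and walks the resulting tree.

```python
class ToyLayer:
    """Toy illustration of recursive variable tracking (not Keras's code)."""

    def __init__(self, name, n_vars=0):
        object.__setattr__(self, "_sublayers", [])
        self.name = name
        self.own_vars = [f"{name}/w{i}" for i in range(n_vars)]

    def __setattr__(self, key, value):
        # Register sublayers assigned directly or inside a list/tuple.
        items = value if isinstance(value, (list, tuple)) else [value]
        self._sublayers.extend(v for v in items if isinstance(v, ToyLayer))
        object.__setattr__(self, key, value)

    @property
    def trainable_variables(self):
        # Own variables first, then each sublayer's, recursively.
        found = list(self.own_vars)
        for sub in self._sublayers:
            found.extend(sub.trainable_variables)
        return found


class Block(ToyLayer):
    def __init__(self, name):
        super().__init__(name)
        self.convs = [ToyLayer(f"{name}/conv{i}", n_vars=1) for i in range(2)]


class UNetLike(ToyLayer):
    def __init__(self):
        super().__init__("net")
        self.enc = Block("enc")
        self.dec = Block("dec")


print(UNetLike().trainable_variables)
# ['enc/conv0/w0', 'enc/conv1/w0', 'dec/conv0/w0', 'dec/conv1/w0']
```

In real Keras, the same idea means a single model object's trainable_variables can be passed both to _watch_vars and to tape.gradient, however deeply the blocks are nested.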

WingsOfPanda commented 3 years ago

@davisyoshida Ha, OK, I see. I guess where I got confused is the following example that you provided before. It helped me a lot with understanding how checkpointing works, but it also left me a bit confused about how to train:

import tensorflow as tf

from checkpointing import checkpointable

@checkpointable
def f(x, y, some_str, some_bool, z=None):
    for _ in range(200):
        x += y * z
    return x

initial = tf.ones(100000, dtype=tf.float32)
y = tf.ones(100000, dtype=tf.float32) + 1e-7
z = tf.ones(100000, dtype=tf.float32) + 1e-7
with tf.GradientTape() as g:
    g.watch(initial)
    x = initial
    for _ in range(200):
        x = f(x, y, 'a', True, z=z, _checkpoint=True)
    loss = tf.reduce_sum(x)
print(g.gradient(loss, initial))

So, inside the with tf.GradientTape() as g: block there are many variables that need to be trained, but we only watch the variable initial. Could you demonstrate how you would obtain all the trainable variables here? The 'variable' x is being overwritten all the time...

davisyoshida commented 3 years ago

You can read trainable_variables from the layer object you create. You can, and should, define the model outside of the GradientTape context, like so:

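import tensorflow as tf
from tensorflow.keras.layers import Layer

from checkpointing import checkpointable

class MyLayer(Layer):
    def __init__(self):
        super().__init__()  # initialise Keras layer bookkeeping before adding attributes
        self.some_var = tf.Variable(3.0)

    def call(self, inputs):
        return self.some_var * inputs

def main(dataset):
    # Create the layer (and its variables) once, outside the tape.
    layer = MyLayer()
    checkpointable_layer = checkpointable(layer)
    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

    for example in dataset:
        with tf.GradientTape() as g:
            g.watch(layer.trainable_variables)
            loss = tf.reduce_sum(checkpointable_layer(
                example, _checkpoint=True, _watch_vars=layer.trainable_variables))
        grads = g.gradient(loss, layer.trainable_variables)
        optimizer.apply_gradients(zip(grads, layer.trainable_variables))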

Again I think this is less about the checkpointing script, and more about Keras, so if there's any more confusion you could check out some of the tutorials on the tensorflow site.

WingsOfPanda commented 3 years ago

@davisyoshida Thank you, sir. Yeah, I am new to Keras (I was a pure TF person before). Will try it, thanks a lot ^^

davisyoshida commented 3 years ago

Yeah, if you want to do it in pure TF, you'll need to manage the variables yourself and pass them in via the _watch_vars arg. It's doable, but might be tedious.

WingsOfPanda commented 3 years ago

@davisyoshida Right, that's part of the reason I'm trying to move to Keras, especially when I need to reuse callable layers repeatedly, as in the example below.

class MyLayer(keras.layers.Layer):
    def __init__(self):
        super(MyLayer, self).__init__()
        self.dense = layers.Dense(32, activation='relu')

    def call(self, inputs):
        return self.dense(inputs)

def main():
    layer = MyLayer()
    checkpointable_layer = checkpointable(layer)

    for example in dataset:
        with tf.GradientTape() as g:
            g.watch(layer.trainable_variables)
            x = checkpointable_layer(example, _checkpoint=True, _watch_vars=layer.trainable_variables)
            output = checkpointable_layer(x, _checkpoint=True, _watch_vars=layer.trainable_variables)
            loss = tf.reduce_sum(output)

I know the above example won't work. If I want to use MyLayer multiple times (think of MyLayer as a DenseBlock; in DenseNet many of them are stacked together), then I need to instantiate MyLayer multiple times, right? Something like layer1 = MyLayer(), layer2 = MyLayer(), and so on, and then ask g to watch layer1.trainable_variables + layer2.trainable_variables.

I think this is the part that has confused me all along... I understand that checkpointing should be applied to each of those MyLayer blocks to save memory (I would stack many dense layers inside each MyLayer). Hence, I need to explicitly define each MyLayer, decorate it, and call it inside the GradientTape, rather than defining one function and calling it several times... This whole process is a bit tedious, so I'm wondering whether I'm missing something?

davisyoshida commented 3 years ago

trainable_variables captures variables from sublayers, so you could do something like this:


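from tensorflow import keras

from checkpointing import checkpointable

class Network(keras.layers.Layer):
    def __init__(self):
        super().__init__()
        # Named "blocks" to avoid clashing with keras.Model's read-only `layers` property.
        self.blocks = [keras.layers.Dense(10) for _ in range(5)]
        self.checkpoint_blocks = [checkpointable(b) for b in self.blocks]

    def call(self, inputs):
        result = inputs
        for block in self.checkpoint_blocks:
            # trainable_variables only includes sublayers that are already built,
            # so build (or call) the network once before training with it.
            result = block(result, _watch_vars=self.trainable_variables, _checkpoint=True)
        return result

WingsOfPanda commented 3 years ago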

I see, will try. Thank you very much!

WingsOfPanda commented 3 years ago

@davisyoshida Hi. I played with Keras a bit and finally came up with the following. Could you please check whether I am using checkpointable correctly? I ask because I am not getting the memory savings I expected.


    input = keras.Input(shape=(16, 16, 16, 1))

    class MyLayer(keras.Model):
        def __init__(self, flts, kers):
            super(MyLayer, self).__init__()
            self.ops = [layers.Conv3D(flt, ker, use_bias=False, activation='relu', padding='same')
                        for flt, ker in zip(flts, kers)]

        def call(self, x):
            for layer in self.ops:
                x = layer(x)
            return x

    ckpt = True

    flts_ = (4,) * 6
    kers_ = (3,) * 6

    dense_layer1 = MyLayer(flts_, kers_)
    dense_layer1.build((None, 16, 16, 16, 1))
    checkpointable_dense1 = checkpointable(dense_layer1)
    x = checkpointable_dense1(input, _watch_vars=dense_layer1.trainable_variables, _checkpoint=ckpt)

    dense_layer2 = MyLayer(flts_, kers_)
    dense_layer2.build(x.shape)
    checkpointable_dense2 = checkpointable(dense_layer2)
    x2 = checkpointable_dense2(x, _watch_vars=dense_layer2.trainable_variables, _checkpoint=ckpt)

    output = tf.reduce_sum(x2)

    model = keras.Model(input, output, name='toy')
    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

    # @tf.function
    def train():
        with tf.GradientTape() as tape:
            loss = model(tf.ones([1, 16, 16, 16, 1]))  # batch of 1, matching Input(shape=(16, 16, 16, 1))

        grads = tape.gradient(loss, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))

    ts = time.time()
    for epoch in tqdm(range(100)):
        train()
    print(f'training finished in {time.time()-ts}s')

davisyoshida commented 3 years ago

Can you explain what procedure you're using to benchmark your memory usage?

davisyoshida commented 3 years ago

I'm not sure whether this will play nicely with Keras's functional keras.Model(inputs, outputs) paradigm, or how well it will interact with tf.function.

Try dropping the model part for now and moving the call logic directly into your training loop. If you want everything inside one class so you can use the convenience functions/attributes, just make a layer and define an explicit call function.

WingsOfPanda commented 3 years ago

Regarding the procedure I use to benchmark memory usage, I usually do this:

gpus = tf.config.experimental.list_physical_devices('GPU')

# Currently, memory growth needs to be the same across GPUs
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

and then watch nvidia-smi to see how much GPU memory the program is using. I know this is not a perfect method, but at least it gives me some idea.

Also, I usually first set _checkpoint=False and slowly enlarge my model, for example by increasing the input size bit by bit until I hit OOM. Then I switch to _checkpoint=True to see whether I can avoid the OOM. If I can, that means the checkpointing worked.
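The "grow until OOM" procedure can be automated as a binary search over input sizes. This assumes memory use grows monotonically with size, so that once a size runs out of memory, every larger size does too. Running the search once with checkpointing off and once with it on gives a direct comparison.

In the sketch below, fits is a hypothetical callback. In practice it might build the model at a given input side length, run one training step, and return False if TensorFlow raises tf.errors.ResourceExhaustedError. Running each probe in a separate process may be more reliable after an OOM. The thresholds in the usage lines are made-up stand-ins, not measurements.

```python
def largest_fitting_size(fits, low, high):
    """Binary search for the largest size in [low, high] with fits(size) True.

    Assumes fits is monotone: if a size does not fit, no larger size fits.
    Returns None if even `low` does not fit.
    """
    if not fits(low):
        return None
    best = low
    while low <= high:
        mid = (low + high) // 2
        if fits(mid):
            best, low = mid, mid + 1  # mid fits: search larger sizes
        else:
            high = mid - 1            # mid OOMs: search smaller sizes
    return best


# Made-up stand-ins for "does one training step fit in GPU memory?"
without_ckpt = lambda side: side <= 96
with_ckpt = lambda side: side <= 200
print(largest_fitting_size(without_ckpt, 16, 512))  # 96
print(largest_fitting_size(with_ckpt, 16, 512))     # 200
```

Each probe costs one model build plus one step, so the search needs only about log2(high - low) probes instead of a linear sweep.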

WingsOfPanda commented 3 years ago

@davisyoshida I changed to the following and it worked. I think keras.Model wrapped those layers again and somehow cancelled the checkpointing. (I have also seen keras.Model produce errors when used with tf.recompute_grad.)


    class MyLayer(keras.Model):
        def __init__(self, flts, kers):
            super(MyLayer, self).__init__()
            self.ops = [layers.Conv3D(flt, ker, use_bias=False, activation='relu', padding='same')
                        for flt, ker in zip(flts, kers)]

        def call(self, x):
            for layer in self.ops:
                x = layer(x)
            return x

    ckpt = True

    flts_ = (4,) * 12
    kers_ = (3,) * 12

    layer1 = MyLayer(flts_, kers_)
    layer1.build((None, 368, 368, 224, 1))
    checkpointable_layer1 = checkpointable(layer1)

    layer2 = MyLayer(flts_, kers_)
    layer2.build((None, 368, 368, 224, 4))
    checkpointable_layer2 = checkpointable(layer2)

    layer3 = MyLayer(flts_, kers_)
    layer3.build((None, 368, 368, 224, 4))
    checkpointable_layer3 = checkpointable(layer3)

    layer4 = MyLayer(flts_, kers_)
    layer4.build((None, 368, 368, 224, 4))
    checkpointable_layer4 = checkpointable(layer4)

    layerf = MyLayer(flts_, kers_)
    layerf.build((None, 368, 368, 224, 4))
    checkpointable_layerf = checkpointable(layerf)

    trainable_weights = \
        layer1.trainable_variables + layer2.trainable_variables \
        + layer3.trainable_variables + layer4.trainable_variables + layerf.trainable_variables

    initial = tf.ones([1, 368, 368, 224, 1])
    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
    for epoch in tqdm(range(100)):
        with tf.GradientTape() as tape:
            x1 = checkpointable_layer1(initial, _watch_vars=layer1.trainable_variables, _checkpoint=ckpt)
            x2 = checkpointable_layer2(x1, _watch_vars=layer2.trainable_variables, _checkpoint=ckpt)
            x3 = checkpointable_layer3(x2, _watch_vars=layer3.trainable_variables, _checkpoint=ckpt)
            x4 = checkpointable_layer4(x3, _watch_vars=layer4.trainable_variables, _checkpoint=ckpt)
            xf = checkpointable_layerf(x4, _watch_vars=layerf.trainable_variables, _checkpoint=ckpt)
            output = tf.reduce_sum(xf)

        grads = tape.gradient(output, trainable_weights)
        optimizer.apply_gradients(zip(grads, trainable_weights))

With ckpt=False the model above runs out of memory (OOM), and with ckpt=True it runs. So I guess checkpointable worked. It looks to me like those checkpointable layers have to be called directly under tf.GradientTape()?

One last thing: now that we put everything under tf.GradientTape(), how do we save the model? The reason I wrapped everything in keras.Model again was to save it. I could save each checkpointable_layer, but the actual model will have many of them, so saving each one would not really be practical.

davisyoshida commented 3 years ago

You definitely do not need to construct the layers under the GradientTape. I'm pretty sure stuff like tf.function and the keras.Model construction trace the code and then end up "compiling" it in a way which removes the gradient checkpointing. I've personally never constructed my layers inside the GradientTape context. Again, I would recommend explicitly writing a layer's call function, rather than using the implicit model construction. There's no reason that constructing the layers inside the GradientTape context would be necessary, as initializing variables isn't a differentiable operation.

Please see here for an example: https://github.com/davisyoshida/tf2-gradient-checkpointing/issues/2#issuecomment-803509770

WingsOfPanda commented 3 years ago

Regarding your earlier Network example (trainable_variables captures variables from sublayers): right! I think I understand more now. BTW, in that example you decorate each individual layer rather than a chunk of layers? I think I read somewhere else in your repo that decorating every single layer won't save anything...

davisyoshida commented 3 years ago

Yeah, this example is just to show the architecture. Replace each of the Dense layers with some block of operations that is actually worth checkpointing.
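An idealised count shows why decorating every single layer saves nothing, while decorating blocks of several layers does. This is the usual segment-checkpointing argument, simplified. The assumptions are: a purely sequential stack, every activation the same size, and weights, gradient buffers and TensorFlow's allocator behaviour all ignored. The function names are hypothetical.

```python
import math


def peak_without_checkpointing(n_layers):
    # The input plus every layer output stays alive until the backward pass.
    return n_layers + 1


def peak_with_checkpointing(n_layers, n_segments):
    """Idealised peak number of stored activations when n_layers sequential
    layers are split into n_segments checkpointed blocks.

    Forward pass: only each segment's input is kept (n_segments tensors,
    including the network input). Backward pass: one segment at a time is
    recomputed, holding at most ceil(n_layers / n_segments) outputs.
    """
    seg_len = math.ceil(n_layers / n_segments)
    return n_segments + seg_len


n = 12
print(peak_without_checkpointing(n))   # 13
print(peak_with_checkpointing(n, 12))  # 13: one layer per segment saves nothing
print(peak_with_checkpointing(n, 1))   # 13: one giant segment saves nothing either
print(peak_with_checkpointing(n, 4))   # 7: blocks of 3 layers
```

The sum n_segments + n_layers / n_segments is smallest at roughly sqrt(n_layers) segments. That is why chunks of a few layers, like the 12-conv blocks above, are a reasonable unit to checkpoint.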

WingsOfPanda commented 3 years ago

Right, here is what I did in the end. One keras.Model class (CkptLayer) runs its layers through a checkpointed block in its call function. Another one (IntLayer) stacks CkptLayers and is what I use to save the model. Is this the correct way? I can now squeeze this model into a 24 GB GPU with ckpt=True, while ckpt=False results in OOM.

    class CkptLayer(keras.Model):
        def __init__(self, flts, kers):
            super(CkptLayer, self).__init__()
            self.ops = [layers.Conv3D(flt, ker, use_bias=False, activation='relu', padding='same')
                        for flt, ker in zip(flts, kers)]
            # self.ckpt_layers = [checkpointable(l) for l in self.ops]

        def call(self, ly, ckpt_on=False):
            @checkpointable
            def run_block(block, inputs):
                layer_output = inputs
                for layer in block:
                    layer_output = layer(layer_output)
                return layer_output

            ly = run_block(self.ops, ly, _checkpoint=ckpt_on, _watch_vars=self.trainable_variables)
            return ly

    ckpt = True

    flts_ = (4,) * 12
    kers_ = (3,) * 12

    base_size = (368, 368, 224, 4)
    build_size = (None,) + base_size

    class IntLayer(keras.Model):
        def __init__(self, CkptLayer, rpt):
            super(IntLayer, self).__init__()
            self.ops = [CkptLayer(flts_, kers_) for _ in range(rpt)]
            # build() returns None, so call it in a loop instead of collecting its results.
            for op in self.ops:
                op.build(build_size)

        def call(self, ly, ckpt_on=False):
            for op in self.ops:
                ly = op(ly, ckpt_on=ckpt_on)
            return ly

davisyoshida commented 3 years ago

That works fine. It may be slightly slower than factoring run_block out into a separate sublayer, since then you could decorate it with tf.function. Right now you can't, because run_block is redefined on every call, so tf.function would trigger a slow recompilation every time you call the layer.
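The cost of redefining run_block inside call can be shown with a toy model. Here toy_function is a hypothetical stand-in for tf.function that only counts "traces", and each wrapper keeps its own cache. A function defined inside call is therefore a brand-new object on every call and is traced every time, while a function wrapped once and reused is traced once.

Real tf.function also retraces on new input signatures and does far more than this. As mentioned earlier in the thread, it is not verified here whether tf.function preserves the library's checkpointing.

```python
TRACES = []  # names of functions "traced" so far


def toy_function(fn):
    """Toy stand-in for tf.function (not TensorFlow code): the first call
    records a 'trace' and later calls reuse it. The cache belongs to this
    wrapper, so every newly created wrapper starts empty."""
    state = {"traced": False}

    def wrapper(*args):
        if not state["traced"]:
            state["traced"] = True
            TRACES.append(fn.__name__)
        return fn(*args)

    return wrapper


class InlineBlock:
    def __call__(self, x):
        # run_block is re-defined and re-wrapped on every call, as in CkptLayer.call.
        @toy_function
        def run_block(v):
            return v + 1

        return run_block(x)


class HoistedBlock:
    def __init__(self):
        # Wrapped once, when the object is created, and reused afterwards.
        self.run_block = toy_function(self._run_block)

    def _run_block(self, v):
        return v + 1

    def __call__(self, x):
        return self.run_block(x)


inline, hoisted = InlineBlock(), HoistedBlock()

TRACES.clear()
for step in range(3):
    inline(step)
inline_traces = len(TRACES)

TRACES.clear()
for step in range(3):
    hoisted(step)
hoisted_traces = len(TRACES)

print(inline_traces, hoisted_traces)  # 3 1
```

In Keras terms, the hoisted version corresponds to moving the block into its own sublayer (or method) that is wrapped once, which is the refactoring suggested above.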