Looks like the issue is that you don't set the _force_seed parameter on the decorated calls. Modifying the code to the following produces equal gradients whether you set ckpt to True or False.
from itertools import repeat
...
ckpt = False
seeds = repeat(0)
with tf.GradientTape() as g:
    g.watch(initial)
    l1_, l2_ = part_one(initial, _checkpoint=ckpt, _force_seed=seeds)
    opt_ = part_two(l1_, l2_, _checkpoint=ckpt, _force_seed=seeds)
    loss = tf.reduce_mean(opt_)
gd = g.gradient(loss, initial)
print(f'loss is {loss} and grad is {np.sum(gd)} with ckpt = {ckpt}')
You can just pass True to _force_seed as well, and it'll make a random seed for each call from Python's RNG. It's not a great solution, so I'm open to suggestions tbh.
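To spell that out, here is a minimal sketch of the two options, reusing the part_one call from the snippet above:

from itertools import repeat

# Option 1: repeat a fixed seed, so checkpointed and non-checkpointed runs
# see identical randomness (this is what makes the gradients match exactly).
seeds = repeat(0)
l1_, l2_ = part_one(initial, _checkpoint=True, _force_seed=seeds)

# Option 2: pass True, and a fresh seed is drawn from Python's RNG on each call
# (results will then differ between runs).
l1_, l2_ = part_one(initial, _checkpoint=True, _force_seed=True)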
As for the memory, are you just looking at how much TensorFlow allocates? I don't know anything about the policy it uses for memory growth, so I don't know whether you'll always see a measurable gain using that as the metric. For example, if TF doesn't free the activations from part_one before allocating the ones for part_two, I imagine the peak memory would be unchanged.
Hi @davisyoshida, thank you for your reply! Yes, I managed to get the same returns by setting _force_seed=constant.
btw, is there any reason that you only allow float32 tensors to be computed for gradients? Can I use this in mixed_precision for float16? Thank you!
I don't think anything should break, but I haven't tested with mixed precision stuff. You can just modify the relevant list comprehensions to enable that.
Thanks, I think modifying the following line in your code should do the trick?
flat_inputs = [x for x in flat_inputs if x.dtype == tf.float32 or x.dtype == tf.float16]
this is line 29 in your code
That's right!
@davisyoshida actually I have one more question. When we compute gradients, we need to compute gradients for all trainable variables so we can update them. However, in my toy code, like the part of your reply quoted below:
(quoting the earlier reply explaining _force_seed, together with its code snippet and the note about measuring memory)
Here we only compute the gradient for the variable 'initial', since we watch it. However, if we want to actually train this network, we need to compute gradients for all trainable variables, right? But how do I obtain those variables? Thank you!
There's an example in the README:
layer = SomeKerasLayer()
wrapped_layer = checkpointable(layer)
with tf.GradientTape() as g:
    g.watch(layer.trainable_variables)
    output = wrapped_layer(*args, **kwargs, _checkpoint=True, _watch_vars=layer.trainable_variables)
print(g.gradient(output, layer.trainable_variables))
@davisyoshida yes, I saw this example, and actually that is the part that confused me. My question is: what exactly is that SomeKerasLayer()? Is it a layer like Conv3D() or MaxPool3D(), or can it be the part_one or part_two function I defined above?
I feel a bit confused now because I usually use keras.Model() to create my model and do training and saving. However, it looks to me like I now need to explicitly write my layers inside tf.GradientTape(). So, how do I save my model later? ><
Sorry for the many questions. If you could provide a working example showing a complete training process, that would be very helpful. Many thanks!
btw I am working on models like U-Net, with many skip connections, so they are hard to cut into small models the way Sequential models can be...
Ah so this is more of a TF question than a question about my script.
I would recommend creating a subclass of keras.Layer or keras.Model outside of your training loop. There are a bunch of convenience functions to allow saving and loading of weights, and it also exposes the trainable_variables attribute, which will give you all the variables declared in __init__, including those from any sublayers and their sublayers and so on.
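For instance, a minimal sketch of that pattern (the layer sizes and checkpoint path here are just for illustration):

import tensorflow as tf
from tensorflow import keras

class SmallNet(keras.Model):
    def __init__(self):
        super().__init__()
        self.dense1 = keras.layers.Dense(32, activation='relu')
        self.dense2 = keras.layers.Dense(1)

    def call(self, inputs):
        return self.dense2(self.dense1(inputs))

model = SmallNet()
_ = model(tf.ones([2, 8]))                   # dummy call so the sublayers create their variables
print(len(model.trainable_variables))        # includes variables from both sublayers
model.save_weights('small_net_checkpoint')   # convenience saving...
model.load_weights('small_net_checkpoint')   # ...and loading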
@davisyoshida ha ok I see. I guess where I got confused is at the following example that you provide before (which helps me a lot with how checkpoint working but also give me a bit confused on how to train)
import tensorflow as tf
from checkpointing import checkpointable

@checkpointable
def f(x, y, some_str, some_bool, z=None):
    for _ in range(200):
        x += y * z
    return x

initial = tf.ones(100000, dtype=tf.float32)
y = tf.ones(100000, dtype=tf.float32) + 1e-7
z = tf.ones(100000, dtype=tf.float32) + 1e-7

with tf.GradientTape() as g:
    g.watch(initial)
    x = initial
    for _ in range(200):
        x = f(x, y, 'a', True, z=z, _checkpoint=True)
    loss = tf.reduce_sum(x)
print(g.gradient(loss, x))
So, inside with tf.GradientTape() as g:, there are many variables that need to be trained, but we only watch the variable initial. Could you demonstrate how you would obtain all trainable variables here? The 'variable' x is being rewritten all the time...
You can call it on the object you create. And you can and should define the model outside of the GradientTape context. Like so:
class MyLayer(Layer):
    def __init__(self):
        super().__init__()  # required so Keras tracks the variable
        self.some_var = tf.Variable(3.0)

    def call(self, inputs):
        return self.some_var * inputs

def main():
    layer = MyLayer()
    checkpointable_layer = checkpointable(layer)
    for example in dataset:
        with tf.GradientTape() as g:
            g.watch(layer.trainable_variables)
            loss = tf.reduce_sum(checkpointable_layer(example, _checkpoint=True, _watch_vars=layer.trainable_variables))
Again I think this is less about the checkpointing script, and more about Keras, so if there's any more confusion you could check out some of the tutorials on the tensorflow site.
@davisyoshida thank you sir. Yeah, I am new to Keras (I was a pure TF boy before). Will try it. Thanks a lot ^^
Yeah if you want to do it in pure TF you'll need to manage the variables yourself and pass them in to the _watch_vars arg. It's doable but might be tedious.
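For instance, a rough sketch of that pure-TF route (the variables and the little relu block here are made up for illustration):

import tensorflow as tf
from checkpointing import checkpointable

# Manage the variables yourself...
w = tf.Variable(tf.random.normal([8, 8]))
b = tf.Variable(tf.zeros([8]))
my_vars = [w, b]

@checkpointable
def block(x):
    return tf.nn.relu(x @ w + b)

with tf.GradientTape() as g:
    g.watch(my_vars)
    out = block(tf.ones([4, 8]), _checkpoint=True, _watch_vars=my_vars)
    loss = tf.reduce_sum(out)
# ...and pass the same list both to _watch_vars and to the gradient call.
grads = g.gradient(loss, my_vars)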
@davisyoshida right, that's part of the reason I'm trying to move to Keras, especially since I need to repeatedly use callable layers like in the example below.
class MyLayer(keras.layers.Layer):
    def __init__(self):
        super(MyLayer, self).__init__()
        self.dense = layers.Dense(32, activation='relu')

    def call(self, inputs):
        return self.dense(inputs)

def main():
    layer = MyLayer()
    checkpointable_layer = checkpointable(layer)
    for example in dataset:
        with tf.GradientTape() as g:
            g.watch(layer.trainable_variables)
            x = checkpointable_layer(example, _checkpoint=True, _watch_vars=layer.trainable_variables)
            output = checkpointable_layer(x, _checkpoint=True, _watch_vars=layer.trainable_variables)
            loss = tf.reduce_sum(output)
I know the above example won't work: if I want to use MyLayer multiple times (think of MyLayer as a DenseBlock; in a DenseNet many of them are stacked together), then I need to define MyLayer multiple times, right? Like MyLayer1, MyLayer2, and so on, and then ask g to watch layer1.trainable_variables + layer2.trainable_variables.
I think this is the part I keep getting confused about... I understand that checkpointing should work on each of those MyLayers to save memory (I would stack many dense layers inside each MyLayer). Hence, I need to explicitly define each MyLayer, decorate it, and put it inside the GradientTape, rather than just defining one function and calling it several times... This whole process is a bit tedious, so I'm wondering if there is something I'm missing?
trainable_variables captures variables from sublayers, so you could do something like this:
class Network(keras.Model):  # subclassing keras.Model gives us trainable_variables
    def __init__(self):
        super().__init__()
        self.blocks = []
        for _ in range(5):
            self.blocks.append(keras.layers.Dense(10))
        self.checkpoint_layers = [checkpointable(l) for l in self.blocks]

    def call(self, inputs):
        result = inputs
        for layer in self.checkpoint_layers:
            result = layer(result, _watch_vars=self.trainable_variables, _checkpoint=True)
        return result
I see, will try. Thank you very much!
@davisyoshida Hi. I tried playing with Keras a bit and finally came up with the following. Could you please have a look at whether I am using checkpointable correctly? Because I am not getting the memory savings I expected.
input = keras.Input(shape=(16, 16, 16, 1))

class MyLayer(keras.Model):
    def __init__(self, flts, kers):
        super(MyLayer, self).__init__()
        self.ops = [layers.Conv3D(flt, ker, use_bias=False, activation='relu', padding='same')
                    for flt, ker in zip(flts, kers)]

    def call(self, x):
        for layer in self.ops:
            x = layer(x)
        return x

ckpt = True
flts_ = (4,) * 6
kers_ = (3,) * 6

dense_layer1 = MyLayer(flts_, kers_)
dense_layer1.build((None, 16, 16, 16, 1))
checkpointable_dense1 = checkpointable(dense_layer1)
x = checkpointable_dense1(input, _watch_vars=dense_layer1.trainable_variables, _checkpoint=ckpt)

dense_layer2 = MyLayer(flts_, kers_)
dense_layer2.build(x.shape)
checkpointable_dense2 = checkpointable(dense_layer2)
x2 = checkpointable_dense2(x, _watch_vars=dense_layer2.trainable_variables, _checkpoint=ckpt)

output = tf.reduce_sum(x2)
model = keras.Model(input, output, name='toy')
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

# @tf.function
def train():
    with tf.GradientTape() as tape:
        loss = model(tf.ones([1, 16, 16, 16]))
    grads = tape.gradient(loss, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))

ts = time.time()
for epoch in tqdm(range(100)):
    train()
print(f'training finished in {time.time()-ts}s')
Can you explain what procedure you're using to benchmark your memory usage?
I'm not sure if this will play nicely with Keras's keras.Model(input, output) paradigm, or how well it will interact with tf.function.
Try dropping the model bit for now, and move the call logic directly inside your train loop. If you want to make everything be inside one class so you can use convenience functions/attributes, just make a layer and define an explicit call function.
Can you explain what procedure you're using to benchmark your memory usage?
I usually do this
gpus = tf.config.experimental.list_physical_devices('GPU')
# Currently, memory growth needs to be the same across GPUs
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
and watch nvidia-smi to see how much GPU memory the program is using. I know this is not a perfect method, but at least it gives me some idea.
Also, I usually first set _checkpoint=False and slowly enlarge my model, for example by increasing the input size bit by bit until OOM. Then I switch to _checkpoint=True to see whether I can avoid the OOM. If I can, that means the checkpointing worked.
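A more direct alternative, if your TF version has it (roughly 2.5+), is the experimental peak-memory counter; a rough sketch, using the train() step from the snippet above and assuming a single GPU:

tf.config.experimental.reset_memory_stats('GPU:0')    # reset the 'peak' counter
train()                                               # run one training step
info = tf.config.experimental.get_memory_info('GPU:0')
print(f"peak GPU memory this step: {info['peak'] / 2**20:.1f} MiB")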
@davisyoshida I changed to the following and it worked. I think keras.Model wrapped those layers again and somehow canceled the checkpoint setting (I also found that keras.Model would produce an error if used with tf.recompute_grad).
input = keras.Input(shape=(368, 368, 224, 1))

class MyLayer(keras.Model):
    def __init__(self, flts, kers):
        super(MyLayer, self).__init__()
        self.ops = [layers.Conv3D(flt, ker, use_bias=False, activation='relu', padding='same')
                    for flt, ker in zip(flts, kers)]

    def call(self, x):
        for layer in self.ops:
            x = layer(x)
        return x

ckpt = True
flts_ = (4,) * 12
kers_ = (3,) * 12

layer1 = MyLayer(flts_, kers_)
layer1.build((None, 368, 368, 224, 1))
checkpointable_layer1 = checkpointable(layer1)

layer2 = MyLayer(flts_, kers_)
layer2.build((None, 368, 368, 224, 4))
checkpointable_layer2 = checkpointable(layer2)

layer3 = MyLayer(flts_, kers_)
layer3.build((None, 368, 368, 224, 4))
checkpointable_layer3 = checkpointable(layer3)

layer4 = MyLayer(flts_, kers_)
layer4.build((None, 368, 368, 224, 4))
checkpointable_layer4 = checkpointable(layer4)

layerf = MyLayer(flts_, kers_)
layerf.build((None, 368, 368, 224, 4))
checkpointable_layerf = checkpointable(layerf)

trainable_weights = \
    layer1.trainable_variables + layer2.trainable_variables \
    + layer3.trainable_variables + layer4.trainable_variables + layerf.trainable_variables

initial = tf.ones([1, 368, 368, 224, 1])
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

for epoch in tqdm(range(100)):
    with tf.GradientTape() as tape:
        x1 = checkpointable_layer1(initial, _watch_vars=layer1.trainable_variables, _checkpoint=ckpt)
        x2 = checkpointable_layer2(x1, _watch_vars=layer2.trainable_variables, _checkpoint=ckpt)
        x3 = checkpointable_layer3(x2, _watch_vars=layer3.trainable_variables, _checkpoint=ckpt)
        x4 = checkpointable_layer4(x3, _watch_vars=layer4.trainable_variables, _checkpoint=ckpt)
        xf = checkpointable_layerf(x4, _watch_vars=layerf.trainable_variables, _checkpoint=ckpt)
        output = tf.reduce_sum(xf)
        # loss = model()
    grads = tape.gradient(output, trainable_weights)
    optimizer.apply_gradients(zip(grads, trainable_weights))
The above model results in OOM if I set ckpt=False, and runs fine if ckpt=True. So I guess, implicitly, checkpointable worked. It looks to me like those checkpointable layers have to be placed directly under tf.GradientTape()?
One last thing: now that we put everything under tf.GradientTape(), how do we save the model? The reason I wrapped everything with keras.Model again was to save it. I can of course save each checkpointable_layer, but in the actual model there are going to be many of them, so saving each one would not be very practical.
You definitely do not need to construct the layers under the GradientTape. I'm pretty sure stuff like tf.function and the keras.Model construction trace the code and then end up "compiling" it in a way which removes the gradient checkpointing. I've personally never constructed my layers inside the GradientTape context. Again, I would recommend explicitly writing a layer's call function, rather than using the implicit model construction. There's no reason that constructing the layers inside the GradientTape context would be necessary, as initializing variables isn't a differentiable operation.
Please see here for an example: https://github.com/davisyoshida/tf2-gradient-checkpointing/issues/2#issuecomment-803509770
(quoting the earlier reply showing the Network example where each Dense layer is wrapped with checkpointable)
Right! I think I understand more now. btw, in this example, you actually decorate each individual layer rather than a chunk of layers? I think I read somewhere else in your repo that decorating every single layer won't save anything...
Yeah, this example is just to show the architecture; replace each of the Dense layers with some block of operations that is actually worth checkpointing.
Right, here is what I did in the end: one keras.Model class (CkptLayer(keras.Model)) whose call function is wrapped by checkpointable, and another one (class IntLayer(keras.Model)) to integrate/use it and save models. Is this the correct way? I can now squeeze this model into a 24GB GPU by setting ckpt=True, whereas it results in OOM if ckpt=False.
class CkptLayer(keras.Model):
    def __init__(self, flts, kers):
        super(CkptLayer, self).__init__()
        self.ops = [layers.Conv3D(flt, ker, use_bias=False, activation='relu', padding='same')
                    for flt, ker in zip(flts, kers)]
        # self.ckpt_layers = [checkpointable(l) for l in self.ops]

    def call(self, ly, ckpt_on=False):
        @checkpointable
        def run_block(block, inputs):
            layer_output = inputs
            for layer in block:
                layer_output = layer(layer_output)
            return layer_output

        ly = run_block(self.ops, ly, _checkpoint=ckpt_on, _watch_vars=self.trainable_variables)
        return ly

ckpt = True
flts_ = (4,) * 12
kers_ = (3,) * 12
base_size = (368, 368, 224, 4)
build_size = (None,) + base_size

class IntLayer(keras.Model):
    def __init__(self, CkptLayer, rpt):
        super(IntLayer, self).__init__()
        self.ops = [CkptLayer(flts_, kers_) for _ in range(rpt)]
        self.init = [l_.build(build_size) for l_ in self.ops]

    def call(self, ly, ckpt_on=False):
        for op in self.ops:
            ly = op(ly, ckpt_on=ckpt_on)
        return ly
That works fine. It may be slightly slower than factoring run_block out into a separate sublayer, since then you could decorate it with tf.function. Right now you can't do that, because a new tf.function would be created every time you call the layer, triggering a slow compilation each time.
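One way to read that suggestion is to define run_block once, outside call (or as its own sublayer), so it is not re-created and re-decorated on every forward pass; a rough sketch, keeping the same CkptLayer structure:

@checkpointable
def run_block(block, inputs):
    # Defined once at module level, so it is only decorated a single time.
    layer_output = inputs
    for layer in block:
        layer_output = layer(layer_output)
    return layer_output

class CkptLayer(keras.Model):
    def __init__(self, flts, kers):
        super(CkptLayer, self).__init__()
        self.ops = [layers.Conv3D(flt, ker, use_bias=False, activation='relu', padding='same')
                    for flt, ker in zip(flts, kers)]

    def call(self, ly, ckpt_on=False):
        return run_block(self.ops, ly, _checkpoint=ckpt_on, _watch_vars=self.trainable_variables)

Whether to additionally wrap things in tf.function is something to experiment with, given the earlier caveat that tracing can interfere with the checkpointing.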
Hi @davisyoshida. I was trying out this gradient-checkpointing technique so that I could fit a larger input size into my model. Here is the toy example I tried:
The example gives me weird results. The gradient I obtained changed (actually quite a lot) when I switched _checkpoint from True to False, and I also didn't see any reduction in memory usage.
Is there anything wrong with my implementation? Any help would be very welcome!