@MatthewWiens101 TF v2.3 is no longer actively supported; we recommend upgrading to the latest TF version. I tried to replicate this issue on Colab. Could you please find the gist here and confirm the same? Thank you!
@sushreebarsa Sorry, there were some typos in the shared code. I have updated the gist here and it runs fine in TF v2.8. It still reproduces the issue, with the warning:
5/733 [..............................] - ETA: 11s - loss: 1.4617 - accuracy: 0.7594
WARNING:tensorflow:Callback method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0024s vs `on_train_batch_end` time: 0.0122s). Check your callbacks.
@MatthewWiens101 Could you have a look at this gist and let us know whether it reflects the current behaviour of the reported issue? Thank you!
@sushreebarsa Yes the behavior seen in that gist matches the issue I have reported.
@sachinprasadhs Any update on this issue? I am still looking for a faster workaround.
I have the same issue with tf.GradientTape(). I use it to watch the gradients in on_epoch_end, and each iteration takes around 50 minutes, while the training itself takes less than 5 minutes.
Looking at the code here, it seems like you have a mask_dict which is static in the context of an individual model.fit() call. Is that right?

If that is the case, you would probably see much better performance by making a custom layer, called MaskedDense perhaps, that implements the logic you want here, and passing the static mask to that layer. It is hard to say exactly what the right structure would be without digging more into the use case, but the overall goal should be to remove the on_train_batch_end callback and make simple layers that do everything you need inside of call, along the lines of the sketch below.
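Something along these lines, as a rough sketch only (MaskedDense is a hypothetical name rather than an existing Keras layer, and the exact signature would depend on your use case):

import tensorflow as tf

class MaskedDense(tf.keras.layers.Layer):
    """Hypothetical dense layer whose kernel is multiplied by a fixed 0/1 mask."""

    def __init__(self, units, mask, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        # The mask is static for the whole fit() call, so keep it as a constant.
        self.mask = tf.constant(mask, dtype=tf.float32)

    def build(self, input_shape):
        self.kernel = self.add_weight(
            name="kernel", shape=(input_shape[-1], self.units),
            initializer="glorot_uniform", trainable=True)
        self.bias = self.add_weight(
            name="bias", shape=(self.units,), initializer="zeros", trainable=True)

    def call(self, inputs):
        # The mask is applied inside the compiled forward pass, so pruned
        # entries never contribute to the output and receive zero gradient.
        return tf.matmul(inputs, self.kernel * self.mask) + self.bias

You would then use it like any other layer, e.g. MaskedDense(64, mask) in place of Dense(64), where mask has shape (input_dim, 64).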
In general, Keras will achieve the best performance with your model when compiling everything into a tf.function. This guide might be a useful reference. You don't need to do anything fancy to get this working with Keras: just make a model, compile it as normal, and you are running with tf.function under the hood.
However, attempting to override the weights for every layer eagerly between each train step of your model (the compiled, fast part) will be way slower than bringing this w_mask logic into the actual compiled train step of your model, for example as sketched below.
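One way to move that logic into the compiled step is something like the following rough sketch; MaskedModel and the structure of mask_dict (keyed by variable name) are assumptions for illustration, mirroring the dictionary from the workaround above:

import tensorflow as tf

class MaskedModel(tf.keras.Model):
    """Sketch: re-apply a static mask to selected weights inside train_step."""

    def __init__(self, *args, mask_dict=None, **kwargs):
        super().__init__(*args, **kwargs)
        # Hypothetical structure: {variable name: 0/1 mask tensor}.
        self.mask_dict = mask_dict or {}

    def train_step(self, data):
        # Run the standard Keras train step (forward pass, loss, gradients,
        # optimizer update, metrics)...
        results = super().train_step(data)
        # ...then zero out the pruned entries. Because train_step is traced
        # into the model's tf.function, these assigns stay on-device instead
        # of going through eager get_weights()/set_weights() round trips.
        for var in self.trainable_variables:
            mask = self.mask_dict.get(var.name)
            if mask is not None:
                var.assign(var * mask)
        return results

Whether the mask should be applied before or after the optimizer update depends on the pruning scheme; the point is simply that it stays inside the compiled step.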
Hope that helps!
This issue has been automatically marked as stale because it has no recent activity. It will be closed if no further activity occurs. Thank you.
Closing as stale. Please reopen if you'd like to work on this further.
@mattdangerw thanks for your reply and suggestions. With more experience now working with Keras, I definitely agree with you that a custom layer for managing masked weights would work much better. The approach I provided was a quick workaround at the time, and complemented some other querying I was performing on the weights. As for how exactly to prevent the gradient from updating certain weights while keeping them within the weight matrix, I will do some digging and post a solution if I come up with one.
I should also mention that some of that "other querying" is very similar to what @azd-rzzd describes in their reply. Namely, every nth iteration I grab a copy of the gradients from the gradient tape and do some statistical analysis. I only do this every so often because it is quite computationally expensive. I haven't dug into whether the cost comes from grabbing the gradients or from the statistics, but it would be nice to know whether there is an easy way to grab the gradients from a batch, or whether they disappear at every intermediate step. I have attached an example below (copied and trimmed, untested):
import tensorflow as tf
from tensorflow.keras.callbacks import Callback

class GetGradients(Callback):
    def __init__(self, x_train, y_train, n_steps=1):
        super().__init__()
        self.batch_size = tf.shape(x_train)[0]
        self.n_batches = 0
        self.n_steps = n_steps
        self.x_train = tf.cast(x_train, tf.float32)
        self.y_train = y_train
        # Pre-compute the mean-subtracted inputs used in the later analysis.
        if len(tf.shape(self.x_train)) == 4:
            self.x_train_res = self.x_train - tf.expand_dims(tf.expand_dims(tf.expand_dims(
                tf.math.reduce_mean(self.x_train, axis=[0, 1, 2]), axis=0), axis=0), axis=0)
        elif len(tf.shape(self.x_train)) == 2:
            self.x_train_res = self.x_train - tf.expand_dims(
                tf.math.reduce_mean(self.x_train, axis=0), axis=0)
        else:
            raise ValueError(
                "x_train has shape of length {}, should be either 2 (dense input) "
                "or 4 (image input)".format(len(tf.shape(self.x_train))))

    def on_train_batch_begin(self, batch, logs=None):
        # Will only work on sequential models with 2D convolutions or dense layers.
        self.n_batches = self.n_batches + 1
        # Only run the (expensive) gradient analysis every n_steps batches.
        if not ((self.n_batches - 1) % self.n_steps == 0):
            return
        with tf.device('/gpu:0'):
            with tf.GradientTape(persistent=True) as tape:
                tape.watch(self.x_train)
                out_intermediate = []
                # Manual forward pass, watching every intermediate activation.
                cargo = self.model.layers[0](self.x_train)
                tape.watch(cargo)
                out_intermediate.append(cargo)
                for layer in self.model.layers[1:]:
                    cargo = layer(cargo)
                    tape.watch(cargo)
                    out_intermediate.append(cargo)
                # Requires the model to be compiled with a loss object.
                loss = self.model.loss(self.y_train, cargo)
                del cargo
            # Gradient of the loss with respect to the inputs (first layer).
            prev_layer_res = self.x_train_res
            grad = tape.gradient(loss, self.x_train)
            # perform some analysis/saving here
            del grad
            # Gradient of the loss with respect to each intermediate activation.
            for layer in range(len(self.model.layers)):
                grad = tape.gradient(loss, out_intermediate[layer])
                # perform some analysis/saving here
                del grad
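For reference, it gets attached like any other callback (illustrative only; the data, loss, and step interval below are placeholders, and because the sketch calls self.model.loss directly it is safer to compile with a loss object rather than a string):

grad_cb = GetGradients(x_train, y_train, n_steps=100)
model.compile(optimizer="adam",
              loss=tf.keras.losses.SparseCategoricalCrossentropy())
model.fit(x_train, y_train, epochs=5, callbacks=[grad_cb])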
It would also be nice to understand more about why calling get_weights()/set_weights() eagerly in between computation cycles takes so long. Does it have to do with the graph, or with where the memory is stored? If you have any recommended reading or insight, it would be greatly appreciated.
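My current rough understanding (which I would appreciate having confirmed or corrected) is that get_weights()/set_weights() convert every variable to and from NumPy arrays, which forces a device-to-host copy and back outside of any compiled graph, whereas a direct variable assignment can stay on-device and even run inside a tf.function. Roughly:

import tensorflow as tf

layer = tf.keras.layers.Dense(1024)
layer.build((None, 1024))
mask = tf.cast(tf.random.uniform((1024, 1024)) > 0.5, tf.float32)

# Eager round trip: the kernel is copied to a host NumPy array and back.
kernel, bias = layer.get_weights()                 # device to host
layer.set_weights([kernel * mask.numpy(), bias])   # host to device

# Staying in TensorFlow: the update runs on-device, inside a tf.function.
@tf.function
def apply_mask():
    layer.kernel.assign(layer.kernel * mask)

apply_mask()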
As I mentioned in my original post, I find the question of long runtimes when pruning in Keras quite interesting, especially as Keras's own pruning functionality is affected by slow on_train_batch_end performance.
System information.
Describe the problem.
I am running some code which repeatedly (every training iteration) calls layer.get_weights() and layer.set_weights(). The callback containing these calls takes 0.009 s, compared to the 0.003 s taken to run the batch itself, and as such more than triples the required training time. I would assume that this operation simply moves tensors around (which should stay on the GPU) and so should not take time comparable to the large matrix multiplications occurring during the batch iteration; having reviewed the source code, that is, to the best of my understanding, what is happening. However, it is obviously taking an extraordinarily long time. Does anyone have any idea why this happens, or any approaches to reduce the time taken by get_weights() and set_weights()? Since this abnormally long runtime may be due to the structure of the get_weights()/set_weights() functions themselves, I am raising this issue as a bug.
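Schematically, the pattern is something like the following (a simplified, hypothetical sketch for illustration only; the actual code is under "Standalone code to reproduce the issue" below):

import tensorflow as tf

class ApplyMaskCallback(tf.keras.callbacks.Callback):
    """Simplified sketch of per-batch masking via get_weights()/set_weights()."""

    def __init__(self, mask_dict):
        super().__init__()
        self.mask_dict = mask_dict  # hypothetical: {layer name: 0/1 kernel mask}

    def on_train_batch_end(self, batch, logs=None):
        for layer in self.model.layers:
            mask = self.mask_dict.get(layer.name)
            if mask is None:
                continue
            weights = layer.get_weights()     # eager copy of the weights to NumPy
            weights[0] = weights[0] * mask    # re-apply the static mask to the kernel
            layer.set_weights(weights)        # eager copy back to the device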
My intuition is that it may be due to data being sent to the CPU and back, or being converted between tensors and NumPy arrays. Or perhaps, upon calling set_weights(), TensorFlow rebuilds the entire graph from scratch, or something similar.
One thing I noticed is that Keras has its own pruning functionality, shown here, and that functionality incidentally also suffers from a long callback runtime (see below). Perhaps this is related?
Describe the current behavior.
The callback to on_train_batch_end() in the code below calls get_weights() twice and set_weights() once, and takes twice as long to run as the batch update:
This is explicitly due to calling get_weights() and set_weights(): removing them from the callback reduces its runtime to a negligible amount.
Describe the expected behavior.
Ideally, I would like to achieve iterative magnitude pruning with the lowest possible runtime.
Standalone code to reproduce the issue.