keras-team / tf-keras

The TensorFlow-specific implementation of the Keras API, which was the default Keras from 2019 to 2023.
Apache License 2.0

get_weights()/set_weights() take too long #551

Closed MatthewWiens101 closed 1 year ago

MatthewWiens101 commented 2 years ago

System information.

Describe the problem.

I am running some code which calls layer.get_weights() and layer.set_weights() on every training iteration. The callback containing these calls takes about 0.009 s, compared to the 0.003 s taken to run the batch, and as such roughly quadruples the time per training step. I assumed that these operations simply move tensors around (ideally staying on the GPU) and thus should not take time comparable to the large matrix multiplications occurring during the batch iteration. I have reviewed the source code and, to the best of my understanding, this is what is happening, yet the calls take far longer than expected. Does anyone have any idea why this happens, or any approaches to reduce the time taken to call set_weights() and get_weights()? This abnormally long runtime may be due to the structure of the get_weights()/set_weights() functions, which is why I am raising this issue as a bug.

My intuition is that it may be due to data being sent to the CPU and back, or converted from tensors to numpy. Or, perhaps, upon calling set_weights, tensorflow rebuilds the entire graph from scratch or something similar.
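For reference, get_weights() is documented to return NumPy arrays, so each call produces host-side copies of the variables. Below is a minimal sketch of that behaviour next to an in-place update through the variable's assign() method, which stays within TensorFlow ops. The 3x4 layer and the mask are made-up values for illustration, and whether this round trip explains the whole slowdown is an assumption, not something established in this thread.

```python
import numpy as np
import tensorflow as tf

layer = tf.keras.layers.Dense(4)
layer.build((None, 3))  # creates layer.kernel (3x4) and layer.bias (4,)

# get_weights() returns NumPy arrays, i.e. copies of the variables on the host.
weights = layer.get_weights()
print(type(weights[0]))

# Hypothetical mask for illustration: keep the first two columns, zero the rest.
mask = tf.constant(np.array([[1, 1, 0, 0]] * 3, dtype="float32"))

# Alternative: update the variable in place with a TensorFlow op,
# without going through get_weights()/set_weights().
layer.kernel.assign(layer.kernel * mask)
```

Timing both variants inside the callback would show how much of the overhead the NumPy conversion accounts for in a given setup.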

One thing I noticed is that keras has their own pruning functionality shown here and this functionality incidentally also has a long callback runtime (see below). Perhaps this is related?

3/422 [..............................] - ETA: 12s - loss: 0.0628 - accuracy: 0.9896  
WARNING:tensorflow:Callback method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0075s vs `on_train_batch_end` time: 0.0076s). Check your callbacks.

Describe the current behavior.

The on_train_batch_end() callback in the code below calls get_weights() twice and set_weights() once per layer, and takes three times as long to run as the batch update (0.0090 s vs 0.0030 s):

Epoch 1/40
  1/629 [..............................] - ETA: 0s - loss: 2.3555 - accuracy: 0.0000e+00
  WARNING:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0030s vs `on_train_batch_end` time: 0.0090s). Check your callbacks.

This is explicitly due to calling get_weights() and set_weights(), as their removal from the callback reduces runtime of the callback to negligible amounts.

Describe the expected behavior.

Ideally, I would like to achieve iterative magnitude pruning with the lowest possible runtime.

Standalone code to reproduce the issue.

import sys
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import OneHotEncoder
from sklearn.model_selection import train_test_split
from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.datasets import mnist

### PRUNE WEIGHTS CALLBACK ###
class pruneModelCallback(Callback):
    def __init__(self, init_weight_dict=None, mask_dict=None):
        super().__init__()
        self.n_batches = 0
        self.init_weight_dict = init_weight_dict
        self.mask_dict = mask_dict

    def on_train_batch_begin(self, batch, logs=None):
        # save weights at initialization
        if self.n_batches == 0:
            if self.init_weight_dict is not None:
                for layer_i in range(len(self.model.layers)):
                    w = self.init_weight_dict['w_'+str(layer_i+1)]
                    b = self.init_weight_dict['b_'+str(layer_i+1)]

                    self.model.layers[layer_i].set_weights([w,b])
            else:
                self.init_weight_dict = {}
                for layer_i in range(len(self.model.layers)):
                    w = self.model.layers[layer_i].get_weights()[0]
                    b = self.model.layers[layer_i].get_weights()[1]

                    self.init_weight_dict['w_'+str(layer_i+1)] = w
                    self.init_weight_dict['b_'+str(layer_i+1)] = b

        self.n_batches = self.n_batches + 1

    # This is the problematic function, runs every training iteration batch
    def on_train_batch_end(self, batch, logs=None):
        # zero out pruned weights
        if self.mask_dict is not None:
            for layer_i in range(len(self.model.layers)):
                # removing these slightly improves runtime
                w = self.model.layers[layer_i].get_weights()[0]
                b = self.model.layers[layer_i].get_weights()[1]

                w_mask = self.mask_dict['w_'+str(layer_i+1)]

                # this multiplication takes no time comparably and removing it 
                # does not influence time taken
                w_pruned = w * w_mask

                # removing this function call significantly speeds up the runtime
                self.model.layers[layer_i].set_weights([w_pruned,b])

class pruneWeights():
    def __init__(self, model, percentile, pruning_type="IMP"):
        # generate pruned mask
        if pruning_type == "IMP":
            # _IMP stores its result in self.mask_dict; __init__ should not return a value
            self._IMP(model, percentile)
        else:
            raise ValueError("Unknown pruning_type {}".format(pruning_type))

    def _IMP(self, model, percentile):
        mask_dict = {}
        w_list = None
        for layer_i in range(len(model.layers)):
            w = model.layers[layer_i].get_weights()[0]
            w_shape = tf.shape(w)
            full_shape = tf.math.reduce_prod(w_shape)
            w_flat = tf.reshape(w, full_shape)
            if w_list is None:
                w_list = w_flat
            else:
                w_list = tf.concat([w_list, w_flat], axis=0)
        w_list = tf.math.abs(w_list)
        thresh = tfp.stats.percentile(w_list, percentile*100)
        test_mask = tf.cast(tf.math.greater(w_list, thresh), tf.float32)
        for layer_i in range(len(model.layers)):
            w = model.layers[layer_i].get_weights()[0]
            w = tf.math.abs(w)
            mask = tf.cast(tf.math.greater(w, thresh), tf.float32)
            mask_dict['w_'+str(layer_i+1)] = mask
        self.mask_dict = mask_dict

def pruning_breakdown(mask_dict, model):
    for layer_i in range(len(model.layers)):
        mask = mask_dict['w_'+str(layer_i+1)]
        print("w_"+str(layer_i+1)+": "+str(tf.math.reduce_mean(mask)))

def main():
    ### LOAD MNIST DATASET ###

    (x_train , y_train), (x_test , y_test) = mnist.load_data()
    x = np.concatenate((x_train, x_test), axis=0)
    y = np.concatenate((y_train, y_test), axis=0)
    x= x.astype("float32") / 255
    x= np.reshape(x, (np.shape(x)[0], 784))
    scaler = StandardScaler(with_std=False)
    scaler.fit(x)
    x_t= scaler.transform(x)
    ohe = OneHotEncoder()
    y_t= ohe.fit_transform(y.reshape(-1, 1)).toarray()
    input_dim = np.shape(x_t)[1:]
    output_dim = np.shape(y_t)[1:]

    del x_train, x_test, y_train, y_test, x, y

    ### SPLIT TRAINING DATA ###

    X_train, X_test, y_train, y_test = train_test_split(x_t, y_t, test_size=0.33, random_state=42)

    idxs = tf.range(tf.shape(X_train)[0])

    ### MODEL INIT ###

    model = Sequential([
        Dense(300, input_dim=input_dim[0], activation='relu'),
        Dense(100, activation='relu'),
        Dense(50, activation='relu'),
        Dense(output_dim[0], activation='softmax')
    ])

    model.compile(
        optimizer = keras.optimizers.Adam(learning_rate=1.2e-4),
        loss = tf.keras.losses.CategoricalCrossentropy(),
        metrics = ['accuracy']
    )

    ### TRAIN MODEL ###

    epochs = 20

    pr = pruneModelCallback()
    es = tf.keras.callbacks.EarlyStopping(patience=5, restore_best_weights=True,)

    history = model.fit(
        x=X_train,
        y=y_train,
        batch_size=64,
        epochs=epochs,
        validation_data=(X_test,y_test),
        callbacks = [pr, es],
    )

    ### ITERATIVELY PRUNE MODEL ###

    percentile_per_it = 0.20
    it = 14

    for i in range(it):
        total_percentile = 1-tf.math.pow(1-percentile_per_it, i+1)
        print(total_percentile)
        pruned_weights = pruneWeights(model, total_percentile)
        pruning_breakdown(pruned_weights.mask_dict, model)

        model = Sequential([
            Dense(300, input_dim=input_dim[0], activation='relu'),
            Dense(100, activation='relu'),
            Dense(50, activation='relu'),
            Dense(output_dim[0], activation='softmax')
        ])

        model.compile(
            optimizer = keras.optimizers.Adam(learning_rate=1.2e-4),
            loss = tf.keras.losses.CategoricalCrossentropy(),
            metrics = ['accuracy']
        )

        ### TRAIN MODEL ###

        epochs = 40

        pr_next = pruneModelCallback(init_weight_dict=pr.init_weight_dict, mask_dict=pruned_weights.mask_dict)
        es = tf.keras.callbacks.EarlyStopping(patience=5, restore_best_weights=True,)

        history = model.fit(
            x=X_train,
            y=y_train,
            batch_size=64,
            epochs=epochs,
            validation_data=(X_test,y_test),
            callbacks = [pr_next, es],
        )

        pr = pr_next

if __name__ == "__main__":
    main()
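
The cumulative pruning fraction used in the loop above, 1 - (1 - percentile_per_it)^(i + 1), can be tabulated in plain Python. This sketch only reproduces the arithmetic with the values from main() (0.20 per iteration, 14 iterations); the function name is made up for illustration.

```python
def cumulative_pruned_fraction(per_iteration, iterations):
    """Return the total fraction of weights pruned after each round."""
    return [1 - (1 - per_iteration) ** (i + 1) for i in range(iterations)]


fractions = cumulative_pruned_fraction(0.20, 14)
for i, frac in enumerate(fractions, start=1):
    print(f"round {i:2d}: {frac:.4f} of weights pruned")
```

With these settings the schedule starts at 20% pruned and ends at roughly 95.6% pruned after the 14th round.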

sushreebarsa commented 2 years ago

@MatthewWiens101 TF v2.3 is not actively supported; we recommend upgrading to the latest TF version. I tried to replicate this issue on Colab. Could you please look at the gist here and confirm the same? Thank you!

MatthewWiens101 commented 2 years ago

@sushreebarsa Sorry, there were some typos in the shared code. I have updated the gist here and it is running fine in TF v2.8. It is still producing the issue with the warning:

5/733 [..............................] - ETA: 11s - loss: 1.4617 - accuracy: 0.7594 
WARNING:tensorflow:Callback method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0024s vs `on_train_batch_end` time: 0.0122s). Check your callbacks.

sushreebarsa commented 2 years ago

@MatthewWiens101 Could you have a look at this gist and let us know whether it reflects the current behaviour of the reported issue? Thank you!

MatthewWiens101 commented 2 years ago

@sushreebarsa Yes the behavior seen in that gist matches the issue I have reported.

MatthewWiens101 commented 2 years ago

@sachinprasadhs any update on this issue? Still looking for a faster workaround.

azd-rzzd commented 1 year ago

I have the same issue with tf.GradientTape(). I use this to watch the gradients on_epoch_end and each iteration takes around 50 minutes, while training itself is less than 5 minutes.

mattdangerw commented 1 year ago

Looking at the code here, it seems like you have a mask_dict which is static in the context of an individual model.fit() call. Is that right?

If that is the case, you would probably see much better performance by making a custom layer called MaskedDense perhaps, that implements the logic you want here, and passing the static mask to that layer. Hard to say exactly what the right structure would be without digging more into the use case, but the overall goal should be to remove the on_train_batch_end and make simple layers that do everything you need inside of call.
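A minimal sketch of what such a layer could look like, assuming the mask is a fixed 0/1 array with the same shape as the kernel. The class name MaskedDense comes from the suggestion above; everything else (constructor arguments, attribute names, the example shapes) is hypothetical and not an existing Keras layer.

```python
import numpy as np
import tensorflow as tf


class MaskedDense(tf.keras.layers.Layer):
    """Illustrative dense layer whose kernel is multiplied by a fixed 0/1 mask."""

    def __init__(self, units, mask, activation=None, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.activation = tf.keras.activations.get(activation)
        # The mask is static data, not a trainable weight.
        self.mask_value = np.asarray(mask, dtype="float32")

    def build(self, input_shape):
        in_dim = int(input_shape[-1])
        self.kernel = self.add_weight(
            name="kernel", shape=(in_dim, self.units),
            initializer="glorot_uniform", trainable=True)
        self.bias = self.add_weight(
            name="bias", shape=(self.units,),
            initializer="zeros", trainable=True)

    def call(self, inputs):
        # Masking happens inside call, so it is part of the compiled step.
        masked_kernel = self.kernel * tf.constant(self.mask_value)
        return self.activation(tf.matmul(inputs, masked_kernel) + self.bias)


# Usage: prune every weight feeding the second output unit.
mask = np.array([[1, 0]] * 3, dtype="float32")
layer = MaskedDense(2, mask)
out = layer(tf.ones((5, 3)))
```

Because the mask multiplies the kernel in the forward pass, pruned entries never contribute to the output, regardless of what values the underlying variable holds.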

In general, Keras will achieve best performance with your model when compiling everything into a tf.function. This guide might be a useful reference. You don't need to do anything fancy to get this working with Keras. Just make a model, compile it as normal, and you are running with tf.function under the hood.

However, attempting to override the weights for every layer eagerly between each train step of your model (the compiled, fast part) will be far slower than bringing this w_mask logic into the actual compiled train step of your model.
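A different route to the same goal, not the custom layer proposed above, is a weight constraint: Keras applies a variable's constraint after each optimizer update, so the re-zeroing happens as part of the train step rather than in a callback. The sketch below assumes a fixed 0/1 mask; MaskConstraint, the 3x2 shapes, and the random data are made up for illustration.

```python
import numpy as np
import tensorflow as tf


class MaskConstraint(tf.keras.constraints.Constraint):
    """Illustrative constraint: re-zero pruned entries after each update."""

    def __init__(self, mask):
        self.mask = np.asarray(mask, dtype="float32")

    def __call__(self, w):
        return w * self.mask


mask = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], dtype="float32")
model = tf.keras.Sequential([
    tf.keras.layers.Dense(2, kernel_constraint=MaskConstraint(mask)),
])
model.compile(optimizer="sgd", loss="mse")

# Placeholder data standing in for a real training set.
x = np.random.rand(8, 3).astype("float32")
y = np.random.rand(8, 2).astype("float32")
model.fit(x, y, batch_size=4, epochs=1, verbose=0)

kernel = model.layers[0].get_weights()[0]
print(kernel)
```

Note that the constraint only takes effect once an update has run, so the initial kernel is not masked until the first training step completes.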

Hope that helps!

google-ml-butler[bot] commented 1 year ago

This issue has been automatically marked as stale because it has no recent activity. It will be closed if no further activity occurs. Thank you.

google-ml-butler[bot] commented 1 year ago

Closing as stale. Please reopen if you'd like to work on this further.

mvjwiens-libang commented 1 year ago

@mattdangerw thanks for your reply and suggestions. With more experience now working with keras, I definitely agree with you that a custom layer for managing masked weights would work much better. The approach I provided was a quick workaround at the time, and complemented some other querying I was performing on the weights. As to how exactly to prevent the gradient from updating certain weights while keeping them within the weight matrix, I will do some digging and post a solution if I come up with one.
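On the gradient question, one observation is that when the forward pass uses kernel * mask, the gradient with respect to the kernel is itself multiplied by the mask, so pruned entries receive a zero gradient. A small check with plain TensorFlow ops, using made-up 2x2 values:

```python
import tensorflow as tf

kernel = tf.Variable([[1.0, 2.0], [3.0, 4.0]])
mask = tf.constant([[1.0, 0.0], [0.0, 1.0]])  # 0 marks a pruned weight
x = tf.constant([[1.0, 1.0]])

with tf.GradientTape() as tape:
    y = tf.matmul(x, kernel * mask)
    loss = tf.reduce_sum(y)

# d(loss)/d(kernel[i, j]) = x[i] * mask[i, j], so masked entries get 0.
grad = tape.gradient(loss, kernel)
print(grad.numpy())
```

A zero gradient keeps a pruned weight fixed under plain gradient descent; optimizers with weight decay or leftover momentum can still move the stored value, although the effective weight kernel * mask stays zero.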

I should also mention, some of that "other querying" I mentioned is very similar to that of @azd-rzzd in their reply. Namely, every nth iteration I grab a copy of the gradient from gradient tape and do some statistical analysis. I do this only so often as it is quite computationally expensive. I haven't done any digging to see if this is due to grabbing the gradients or the statistics, but it would be nice to know if there is an easy way to grab the gradients from a batch, or if they disappear in every intermediary step. I have attached an example below (copied and trimmed, untested):

class GetGradients(Callback):
    def __init__(self, x_train, y_train, n_steps=1):
        super().__init__()
        self.batch_size = tf.shape(x_train)[0]
        self.n_batches = 0
        self.n_steps = n_steps
        self.x_train = tf.cast(x_train, tf.float32)
        self.y_train = y_train
        if len(tf.shape(self.x_train)) == 4:
            self.x_train_res = self.x_train - tf.expand_dims(tf.expand_dims(tf.expand_dims(tf.math.reduce_mean(self.x_train, axis=[0,1,2]), axis=0), axis=0), axis=0)
        elif len(tf.shape(self.x_train)) == 2:
            self.x_train_res = self.x_train - tf.expand_dims(tf.math.reduce_mean(self.x_train, axis=0), axis=0)
        else:
            raise ValueError("x_train has shape of length {}, should be either 2 (dense input) or 4 (image input)".format(len(tf.shape(self.x_train))))

    def on_train_batch_begin(self, batch, logs=None):
        # will only work on sequential models with 2d convolutions or dense layers
        self.n_batches = self.n_batches + 1
        if not ((self.n_batches-1) % self.n_steps == 0):
            return
        with tf.device('/gpu:0'):
            with tf.GradientTape(persistent=True) as tape:
                tape.watch(self.x_train)
                out_intermediate = []
                cargo = self.model.layers[0](self.x_train)
                tape.watch(cargo)
                out_intermediate.append(cargo)
                for layer in self.model.layers[1:]:
                    cargo = layer(cargo)
                    tape.watch(cargo)
                    out_intermediate.append(cargo)
                loss = self.model.loss(self.y_train, cargo)

            del cargo

            # first layer #
            prev_layer_res = self.x_train_res
            grad = tape.gradient(loss, self.x_train)
            # perform some analysis/saving here

            del grad

            for layer in range(len(self.model.layers)):
                grad = tape.gradient(loss, out_intermediate[layer])
                # perform some analysis/saving here

                del grad
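
For comparison, a minimal sketch of grabbing gradients for a single batch with one non-persistent tape. It differs from the callback above in that it takes gradients with respect to the trainable variables rather than the intermediate activations, and the model, batch size, and random data are placeholders chosen for illustration.

```python
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(3, activation="softmax"),
])
loss_fn = tf.keras.losses.CategoricalCrossentropy()

# Placeholder batch standing in for real training data.
x_batch = np.random.rand(16, 4).astype("float32")
y_batch = tf.one_hot(np.random.randint(0, 3, size=16), depth=3)

with tf.GradientTape() as tape:
    preds = model(x_batch, training=True)
    loss = loss_fn(y_batch, preds)

# One gradient tensor per trainable variable, in the same order.
grads = tape.gradient(loss, model.trainable_variables)
for var, grad in zip(model.trainable_variables, grads):
    print(var.name, grad.shape)
```

Running this on one batch every nth step, instead of on the full training set, would keep the cost of the analysis proportional to the batch size.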

mvjwiens-libang commented 1 year ago

It would also be nice to understand more about how get_weights/set_weights being called eagerly in between computation cycles takes so long. Does it have to do with the graph, or where the memory is stored? If you have any recommended reading or insight, it would be greatly appreciated.

As I mentioned in my original post, I find the question of long runtimes when pruning in keras quite interesting, especially as keras' own pruning functionality is affected by slow "on_train_batch_end" performance.