tensorflow / model-optimization

A toolkit to optimize ML models for deployment for Keras and TensorFlow, including quantization and pruning.
https://www.tensorflow.org/model_optimization
Apache License 2.0

How to prune a custom tensor? The tensor is a recursive variable and is initialized with tf.zeros. #875

Open starsky68 opened 3 years ago

starsky68 commented 3 years ago

Prior to filing: check that this should be a bug instead of a feature request. Everything supported, including the compatible versions of TensorFlow, is listed in the overview page of each technique. For example, the overview page of quantization-aware training is here. An issue for anything not supported should be a feature request.

Describe the bug

How can I prune a custom tensor? The tensor is a custom variable that is initialized with `tf.zeros`.

System information

TensorFlow version (installed from source or binary):

TensorFlow Model Optimization version (installed from source or binary):

Python version: 3.8

Describe the expected behavior

Describe the current behavior

How can I prune `b`?

Code to reproduce the issue

```python
import tensorflow as tf
import tensorflow_model_optimization as tfmot


class PruningLayer(tf.keras.layers.Layer, tfmot.sparsity.keras.PrunableLayer):

    def __init__(self, n, d, **kwargs):
        super(PruningLayer, self).__init__(**kwargs)
        self.n = n
        self.d = d

    def build(self, input_shape):
        # Trainable weight created once per layer.
        self.weight = self.add_weight(
            name="weight",
            shape=[1, input_shape[1], self.n, self.d, input_shape[2]],
            initializer="random_normal",
            trainable=True)

    def call(self, x):
        u = tf.matmul(self.weight, x)
        b = self.Rr(u)
        s = tf.multiply(x, b)
        return s

    def get_prunable_weights(self):
        # tfmot's prune_low_magnitude prunes the weights returned here.
        return [self.weight]

    def Rr(self, x):
        input_shape = tf.shape(x)
        # b is a local tensor re-initialized to zero on every call,
        # not a layer variable.
        b = tf.zeros((input_shape[0], input_shape[1], self.n, 1))

        # Three iterations that update b from its previous value.
        for _ in range(3):
            c = tf.nn.softmax(b, axis=2)
            b = b + tf.multiply(x, c)
        return b
```

Screenshots If applicable, add screenshots to help explain your problem.

Additional context

fredrec commented 2 years ago

Hi @starsky68,

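Can you please provide more context? In particular:

  • What is `self.weight` for? It looks like it is never used in the layer.
  • What prevents `b` from being a class member? You could then return it in `get_prunable_weights(self)`.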

starsky68 commented 2 years ago

> Hi @starsky68,
>
> Can you please provide more context? In particular:
>
>   • What is `self.weight` for? It looks like it is never used in the layer.
>   • What prevents `b` from being a class member? You could then return it in `get_prunable_weights(self)`.

I have modified the sample code above again. `self.weight` can be returned from `get_prunable_weights()`, but I don't know whether `b` can also be returned from `get_prunable_weights()`, since `b` is a tensor that is updated iteratively. I would appreciate any help.
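The difference between `self.weight` and `b` can be illustrated with a small, shape-simplified layer. `RoutingLayer` below is hypothetical: it is not the layer from this issue, and it assumes 2-D inputs of shape `(batch, features)`. It keeps the same pattern of one trainable weight plus a `b` that is rebuilt from zeros on every call. Only the weight is tracked in `layer.weights`, which suggests why `b` cannot simply be returned from `get_prunable_weights()`: that method is expected to return layer weights (variables), while `b` is recomputed from the input on each forward pass.

```python
import tensorflow as tf


class RoutingLayer(tf.keras.layers.Layer):
    """Hypothetical, shape-simplified stand-in for PruningLayer (not the
    original model): one trainable weight plus a routing tensor `b` that
    is rebuilt from zeros on every call."""

    def __init__(self, n, iterations=3, **kwargs):
        super().__init__(**kwargs)
        self.n = n
        self.iterations = iterations

    def build(self, input_shape):
        # Assumed input shape: (batch, features).
        self.weight = self.add_weight(
            name="weight",
            shape=(int(input_shape[-1]), self.n),
            initializer="random_normal",
            trainable=True)

    def call(self, x):
        u = tf.matmul(x, self.weight)  # shape (batch, n)
        b = tf.zeros_like(u)           # local tensor, not a tracked variable
        for _ in range(self.iterations):
            c = tf.nn.softmax(b, axis=-1)
            b = b + u * c
        return b


layer = RoutingLayer(n=4)
out = layer(tf.ones((2, 3)))
# Only `weight` is tracked; `b` never appears among the layer's weights.
print([w.name for w in layer.weights])
```

If `b` were turned into a variable so it could be returned, it would need a fixed, batch-independent shape and would no longer be recomputed from the input, which would change what the layer computes.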

starsky68 commented 2 years ago

> Hi @starsky68,
>
> Can you please provide more context? In particular:
>
>   • What is `self.weight` for? It looks like it is never used in the layer.
>   • What prevents `b` from being a class member? You could then return it in `get_prunable_weights(self)`.

When I used TF1, I could apply its pruning interface, `apply_mask`, directly to a tensor, but the interface seems to have changed in TF2, and such operations no longer appear to be supported.
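One possible TF2 substitute is to mask the tensor directly inside `call()`. The sketch below assumes the goal is to zero out the smallest-magnitude entries of `b` on every forward pass. This is activation sparsification, a different technique from tfmot's weight pruning, and `magnitude_prune_tensor` is a hypothetical helper, not a tfmot or TensorFlow API. Unlike TF1's `apply_mask`, which kept a mask that the pruning schedule updated, this recomputes the mask from `b`'s current values each time, since `b` itself is recomputed on each call.

```python
import tensorflow as tf


def magnitude_prune_tensor(t, sparsity):
    """Hypothetical helper (not part of tfmot): zero out roughly the
    `sparsity` fraction of entries of `t` with the smallest magnitude.

    The mask is recomputed on every call. Entries tied with the threshold
    are all kept, so slightly fewer than the requested fraction may be zeroed.
    """
    magnitudes = tf.reshape(tf.abs(t), [-1])
    num_entries = tf.cast(tf.size(magnitudes), tf.float32)
    # Number of entries to keep; always keep at least one.
    k = tf.maximum(
        tf.cast(tf.round(num_entries * (1.0 - sparsity)), tf.int32), 1)
    # top_k returns values in descending order, so the last one is the
    # k-th largest magnitude, used as the keep threshold.
    threshold = tf.math.top_k(magnitudes, k=k).values[-1]
    mask = tf.cast(tf.abs(t) >= threshold, t.dtype)
    return t * mask


b = tf.constant([[0.1, -2.0], [0.5, 3.0]])
# Keeps -2.0 and 3.0 and zeros the two smaller entries.
print(magnitude_prune_tensor(b, sparsity=0.5).numpy())
```

In the layer above, this would mean replacing `b = self.Rr(u)` with something like `b = magnitude_prune_tensor(self.Rr(u), sparsity=0.5)`, where the sparsity value is arbitrary. There is no pruning schedule or stored mask here, so the amount of sparsity is fixed from the first training step.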