Describe the bug
Activity Regularizer not working with tf.function in TF2.7
Code
import tensorflow as tf

tfk = tf.keras  # alias assumed; the issue does not show its imports


class TNN(tfk.Model):
    def __init__(self, input_dim):
        super(TNN, self).__init__(name='TNN')
        self.fn_dense_input = tfk.layers.Dense(units=256, input_dim=input_dim)
        self.fn_dense_hidden = tfk.layers.Dense(units=128)
        self.fn_dense_output = tfk.layers.Dense(units=1, activation='sigmoid')
        self.fn_reg_1 = tfk.layers.ActivityRegularization(l1=0.01, l2=0.01)
        self.fn_reg_2 = tfk.layers.ActivityRegularization(l1=0, l2=0.01)
        self.fn_bn_1 = tfk.layers.BatchNormalization()
        self.fn_bn_2 = tfk.layers.BatchNormalization()  # separate instance, not reused
        self.fn_af = tfk.layers.Activation(activation='relu')  # has no weights, so reusing it is fine

    @tf.function  # decorating call() like this is what triggers the error below
    def call(self, inputs):
        x = self.fn_dense_input(inputs)
        x = self.fn_reg_1(x)
        x = self.fn_bn_1(x)
        x = self.fn_af(x)
        x = self.fn_dense_hidden(x)
        x = self.fn_reg_2(x)
        x = self.fn_bn_2(x)
        x = self.fn_af(x)
        outputs = self.fn_dense_output(x)
        return outputs
Error output
InaccessibleTensorError: tf.Graph captured an external symbolic tensor. The symbolic tensor <tf.Tensor 'activity_regularization/ActivityRegularizer/truediv:0' shape=() dtype=float32> is captured by FuncGraph(name=train_function, id=139760322273232), but it is defined at FuncGraph(name=call, id=139760322164240). A tf.Graph is not allowed to capture symoblic tensors from another graph. Use return values, explicit Python locals or TensorFlow collections to access it. Please see https://www.tensorflow.org/guide/function#all_outputs_of_a_tffunction_must_be_return_values for more information.
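Judging from the traceback, the activity-regularization loss (the `ActivityRegularizer/truediv` tensor) is created inside the separate graph traced for the decorated `call`, while `Model.fit` later reads `model.losses` from its own `train_function` graph. One likely workaround is to drop `@tf.function` from `call`: unless `run_eagerly=True` is set, `fit` already wraps the training step in a `tf.function`, so `call` gets traced into that same graph. The sketch below assumes that is all the decorator was meant to achieve; the random toy data is hypothetical, and `input_dim` is left out because a subclassed model infers its input size on the first call.

```python
import numpy as np
import tensorflow as tf

tfk = tf.keras  # same alias as in the issue (assumed)


class TNN(tfk.Model):
    def __init__(self):
        super().__init__(name="TNN")
        self.fn_dense_input = tfk.layers.Dense(units=256)
        self.fn_dense_hidden = tfk.layers.Dense(units=128)
        self.fn_dense_output = tfk.layers.Dense(units=1, activation="sigmoid")
        self.fn_reg_1 = tfk.layers.ActivityRegularization(l1=0.01, l2=0.01)
        self.fn_reg_2 = tfk.layers.ActivityRegularization(l1=0.0, l2=0.01)
        self.fn_bn_1 = tfk.layers.BatchNormalization()
        self.fn_bn_2 = tfk.layers.BatchNormalization()
        self.fn_af = tfk.layers.Activation("relu")

    # No @tf.function here: fit() traces call() inside its own tf.function,
    # so the activity-regularization losses are created in that same graph.
    def call(self, inputs):
        x = self.fn_af(self.fn_bn_1(self.fn_reg_1(self.fn_dense_input(inputs))))
        x = self.fn_af(self.fn_bn_2(self.fn_reg_2(self.fn_dense_hidden(x))))
        return self.fn_dense_output(x)


# Hypothetical toy data: 64 samples, 20 features, binary labels.
rng = np.random.RandomState(0)
x_train = rng.rand(64, 20).astype("float32")
y_train = rng.randint(0, 2, size=(64, 1)).astype("float32")

model = TNN()
model.compile(optimizer="adam", loss="binary_crossentropy")
history = model.fit(x_train, y_train, batch_size=16, epochs=1, verbose=0)

# An eager forward pass records one activity loss per ActivityRegularization layer.
model(x_train[:8])
print(history.history["loss"][0], len(model.losses))
```

If a hand-written training loop is used instead of `fit`, the same idea should apply: put `@tf.function` on the step function that calls the model and reads `model.losses`, not on `call` itself.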
What is the difference between the two approaches below, and how can ActivityRegularization be used with tf.function?
x = tfk.layers.Dense(units=256, kernel_regularizer=tfk.regularizers.l1_l2(l1=0.01, l2=0.01))(x)
# vs:
x = tfk.layers.Dense(units=256)(x)
x = tfk.layers.ActivityRegularization(l1=0.01, l2=0.01)(x)
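On the difference between the two snippets: `kernel_regularizer` penalizes the Dense layer's weight matrix, whereas `ActivityRegularization` penalizes the tensor flowing through it, so the second pair behaves like `Dense(units=256, activity_regularizer=...)` rather than like the `kernel_regularizer` version. A small sketch that checks this on a hypothetical random input (the layer width of 16 is chosen arbitrarily):

```python
import numpy as np
import tensorflow as tf

tfk = tf.keras  # same alias as in the issue (assumed)

rng = np.random.RandomState(0)
x = tf.constant(rng.rand(4, 8).astype("float32"))


def total(losses):
    """Sum a list of scalar loss tensors into a Python float."""
    return float(tf.add_n(losses).numpy())


# (a) kernel_regularizer: the penalty depends only on the layer's weights.
dense_k = tfk.layers.Dense(
    units=16, kernel_regularizer=tfk.regularizers.l1_l2(l1=0.01, l2=0.01))
dense_k(x)
kernel_penalty_x = total(dense_k.losses)
dense_k(2.0 * x)
kernel_penalty_2x = total(dense_k.losses)  # same weights, so same penalty

# (b) Activity regularization: the penalty depends on the layer's outputs.
dense_a = tfk.layers.Dense(
    units=16, activity_regularizer=tfk.regularizers.l1_l2(l1=0.01, l2=0.01))
dense_b = tfk.layers.Dense(units=16)
act_reg = tfk.layers.ActivityRegularization(l1=0.01, l2=0.01)

dense_a(x)
dense_b(x)
dense_b.set_weights(dense_a.get_weights())  # make both Dense layers identical

dense_a(x)
activity_a = total(dense_a.losses)          # Dense(activity_regularizer=...)
act_reg(dense_b(x))
activity_b = total(act_reg.losses)          # Dense + ActivityRegularization
act_reg(dense_b(2.0 * x))
activity_b_2x = total(act_reg.losses)

print(kernel_penalty_x, kernel_penalty_2x)
print(activity_a, activity_b, activity_b_2x)
```

Doubling the input leaves the kernel penalty unchanged but increases the activity penalty, since only the latter is computed from the layer's outputs.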
This issue is more related to TensorFlow. If possible, please post it on the tensorflow/tensorflow repo, where it can be resolved faster. Thanks!