tensorlayer / TensorLayer

Deep Learning and Reinforcement Learning Library for Scientists and Engineers
http://tensorlayerx.com

Problem with the 2nd order derivative using TL activations #1147

Open seraph0012 opened 2 years ago

seraph0012 commented 2 years ago

Hi! I'm working on a model where I need to calculate the 2nd order derivative of the output with respect to some of the input variables. If I use the existing TensorFlow activations (e.g. tf.nn.tanh), the 2nd order derivatives are calculated correctly. However, if I replace them with any of the TensorLayer activations, the 1st order derivative is still calculated, but the 2nd order derivative is just None.
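To see where the None comes from without Keras in the way, here is a minimal sketch comparing tf.nn.leaky_relu with a leaky ReLU written as tf.maximum(x, alpha * x). It assumes the TL activation behaves like this tf.maximum version; leaky_relu_max and second_derivative are hypothetical helpers, not TensorLayer or TensorFlow APIs. The gradient of tf.maximum is built from an element-wise comparison mask, which TensorFlow does not differentiate, so the 1st derivative is not connected to x and the outer tape returns None. The gradient op behind tf.nn.leaky_relu has its own registered gradient, which returns zeros with respect to x.

```python
import tensorflow as tf


def leaky_relu_max(x, alpha=0.2):
    # Hypothetical helper: leaky ReLU built from tf.maximum instead of tf.nn.leaky_relu.
    return tf.maximum(x, alpha * x)


def second_derivative(fn, x):
    # Nested tapes: the inner tape computes dy/dx, the outer tape differentiates that result.
    with tf.GradientTape() as outer:
        outer.watch(x)
        with tf.GradientTape() as inner:
            inner.watch(x)
            y = fn(x)
        dy_dx = inner.gradient(y, x)
    return outer.gradient(dy_dx, x)


x = tf.constant([-1.0, 0.5, 2.0])
print(second_derivative(leaky_relu_max, x))    # None: dy/dx depends on x only through a mask
print(second_derivative(tf.nn.leaky_relu, x))  # a tensor of zeros
```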

I'm using Python 3.8.10, TensorFlow 2.5.0, and TensorLayer 2.2.3. Here's the code to reproduce the problem:

import numpy as np
import tensorflow as tf
import tensorlayer
from tensorflow.keras.layers import Dense

layers = [2, 10, 10, 10, 1]
inputs = tf.keras.layers.Input(shape=(layers[0],))
x = inputs
for width in layers[1:-1]:
    x = Dense(width,
              activation=tensorlayer.activation.leaky_relu,
              kernel_initializer='glorot_normal')(x)
outputs = Dense(layers[-1])(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

data = tf.convert_to_tensor(np.random.rand(10, 2))
with tf.GradientTape(persistent=True) as tape:
    tape.watch(data)
    y = model(data)
    # 1st derivative, taken inside the context so the tape also records it
    y_x = tape.gradient(y, data)
# 2nd derivative: differentiate the recorded 1st derivative w.r.t. the inputs
y_xx = tape.gradient(y_x, data)
print(y_xx)

This prints None. If I replace tensorlayer.activation.leaky_relu with tf.nn.leaky_relu, it prints numbers instead. I would like to use hard_tanh in my model, which TensorFlow doesn't provide, but the 2nd order derivative calculation does not work with TL activations.
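One possible workaround, sketched under assumptions: build hard tanh from tf.nn.relu using the identity hard_tanh(x) = relu(x + 1) - relu(x - 1) - 1. The gradient op of tf.nn.relu has a registered gradient of its own, so the 2nd derivative comes back as a tensor instead of None. hard_tanh_relu is a hypothetical name, not a TensorLayer or TensorFlow API. Because hard tanh is piecewise linear, its exact 2nd derivative is zero everywhere except at x = -1 and x = 1, so zeros are the mathematically expected result (the same reasoning suggests the numbers printed with tf.nn.leaky_relu are zeros as well).

```python
import tensorflow as tf


def hard_tanh_relu(x):
    # Hypothetical helper: hard tanh written as relu(x + 1) - relu(x - 1) - 1.
    # For x < -1 this is -1, for -1 <= x <= 1 it is x, and for x > 1 it is 1.
    return tf.nn.relu(x + 1.0) - tf.nn.relu(x - 1.0) - 1.0


x = tf.constant([-2.0, -0.5, 0.25, 1.5])
with tf.GradientTape() as outer:
    outer.watch(x)
    with tf.GradientTape() as inner:
        inner.watch(x)
        y = hard_tanh_relu(x)
    # 1st derivative: 1 inside [-1, 1], 0 outside
    dy_dx = inner.gradient(y, x)
# 2nd derivative: a tensor (of zeros) rather than None
d2y_dx2 = outer.gradient(dy_dx, x)

print(y.numpy())        # values: -1, -0.5, 0.25, 1
print(dy_dx.numpy())    # values: 0, 1, 1, 0
print(d2y_dx2.numpy())  # all zeros
```

Since Keras accepts any callable as an activation, such a function could be passed as Dense(width, activation=hard_tanh_relu). If changing the activation is not an option, passing unconnected_gradients=tf.UnconnectedGradients.ZERO to tape.gradient makes TensorFlow return zeros instead of None when the 1st derivative does not depend on the inputs.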