tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

WARNING:tensorflow:@custom_gradient grad_fn has 'variables' in signature, but no ResourceVariables were used on the forward pass. #1606

Status: Open. bobosfw opened this issue 2 years ago.

bobosfw commented 2 years ago

I am trying to train a very small, simple model with TensorFlow and TensorFlow Probability. The code is as follows:

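import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions


def neg_log_likelihood(y_true, y_hat):
    # y_hat is the distribution object produced by the DistributionLambda layer
    return -y_hat.log_prob(y_true)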

inputs = tf.keras.layers.Input(shape=(X_train.shape[1],))
h1 = tf.keras.layers.BatchNormalization()(inputs)
h2 = tf.keras.layers.Dense(4, activation=tf.nn.relu, name="layer_1")(h1)
rate = tf.keras.layers.Dense(1 + 1, activation=tf.exp, name="layer_2")(h2)
y = tfp.layers.DistributionLambda(
    lambda t: tfd.Skellam(rate1=t[..., 0:1], rate2=t[..., 1:]),
    name="prob",
)(rate)

model = tf.keras.Model(inputs=inputs, outputs=y)
model.compile(tf.keras.optimizers.Adam(learning_rate=0.01), loss=neg_log_likelihood)
history = model.fit(
    x=X_train,
    y=y_train,
    validation_data=(X_val, y_val),
    epochs=100,
    batch_size=128,
    shuffle=True,
)
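For reference, the loss above asks TFP for the Skellam log-probability of each target. The sketch below is a plain-Python re-implementation of the standard Skellam log-pmf (the distribution of N1 - N2 for independent counts N1 ~ Poisson(rate1) and N2 ~ Poisson(rate2)), which can be handy for sanity-checking loss values outside the model. It is not TFP's implementation; the helper names `bessel_i`, `skellam_log_prob` and `mean_neg_log_likelihood` are hypothetical, and it assumes integer targets and strictly positive rates (the exp activation on layer_2 guarantees the latter).

```python
import math


def bessel_i(n: int, x: float, rel_tol: float = 1e-17, max_terms: int = 1000) -> float:
    """Modified Bessel function of the first kind, I_n(x), for integer n >= 0 and x >= 0."""
    # Power series: I_n(x) = sum_m (x/2)^(2m+n) / (m! (m+n)!).
    # Each term is derived from the previous one to avoid huge intermediate powers.
    half = x / 2.0
    term = half ** n / math.factorial(n)
    total = term
    m = 0
    while m < max_terms and term > rel_tol * total:
        m += 1
        term *= half * half / (m * (m + n))
        total += term
    return total


def skellam_log_prob(k: int, rate1: float, rate2: float) -> float:
    """log P(N1 - N2 = k) for independent N1 ~ Poisson(rate1), N2 ~ Poisson(rate2)."""
    # log P(k) = -(rate1 + rate2) + (k/2) log(rate1/rate2) + log I_|k|(2 sqrt(rate1 rate2))
    return (
        -(rate1 + rate2)
        + 0.5 * k * math.log(rate1 / rate2)
        + math.log(bessel_i(abs(k), 2.0 * math.sqrt(rate1 * rate2)))
    )


def mean_neg_log_likelihood(ys, rates1, rates2) -> float:
    """Average of -log P(y) over a batch of targets and predicted rate pairs."""
    losses = [-skellam_log_prob(y, r1, r2) for y, r1, r2 in zip(ys, rates1, rates2)]
    return sum(losses) / len(losses)


print(mean_neg_log_likelihood([0, 1, -2], [1.0, 2.0, 0.5], [1.0, 0.5, 2.0]))
```

The mean mirrors Keras's default behaviour of averaging the per-example values returned by neg_log_likelihood over the batch.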

When I start training the model, TensorFlow logs this warning:

WARNING:tensorflow:@custom_gradient grad_fn has 'variables' in signature, but no ResourceVariables were used on the forward pass.
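If the repeated warning clutters the training log while you investigate, it can be hidden. This only silences the message; it does not change how gradients are computed and will not make training faster. A minimal sketch using the standard logging module, assuming the message is emitted through Python's "tensorflow" logger (the logger that tf.get_logger() returns, as the "WARNING:tensorflow:" prefix suggests):

```python
import logging

# TensorFlow's Python-side messages go through the stdlib logger named "tensorflow".
# Raising its level to ERROR drops WARNING (and lower) records but keeps errors.
tf_logger = logging.getLogger("tensorflow")
tf_logger.setLevel(logging.ERROR)
```

With TensorFlow imported, tf.get_logger().setLevel("ERROR") should have the same effect.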

I ignored the warning and continued training, but training is very slow: the ETA is almost 2 minutes per epoch. The training set has shape (121562, 24).
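To put that timing in per-step terms: with batch_size=128, Keras runs ceil(121562 / 128) = 950 training steps per epoch, so about 2 minutes per epoch works out to roughly 126 ms per step, for a network with only 206 parameters (96 in BatchNormalization, half of them non-trainable moving statistics, plus 100 in layer_1 and 10 in layer_2). The sketch below reproduces that arithmetic; it treats "almost 2 min" as 120 s (an assumption), and the function names are illustrative.

```python
import math


def steps_per_epoch(num_examples: int, batch_size: int) -> int:
    # The last batch may be smaller than batch_size, hence the ceiling.
    return math.ceil(num_examples / batch_size)


def ms_per_step(epoch_seconds: float, num_examples: int, batch_size: int) -> float:
    return 1000.0 * epoch_seconds / steps_per_epoch(num_examples, batch_size)


def dense_params(n_in: int, n_out: int) -> int:
    # Weight matrix plus one bias per output unit.
    return n_in * n_out + n_out


def model_param_count(n_features: int) -> int:
    # BatchNormalization keeps gamma, beta, moving mean and moving variance per feature.
    batch_norm = 4 * n_features
    return batch_norm + dense_params(n_features, 4) + dense_params(4, 2)


print(steps_per_epoch(121562, 128))               # 950
print(round(ms_per_step(120.0, 121562, 128), 1))  # 126.3
print(model_param_count(24))                      # 206
```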

How can I fix this? Thanks.

TensorFlow version: 2.8.0
TensorFlow Probability version: 0.16.0

jedisom commented 2 years ago

It seems to be related to the distribution used inside tfp.layers.DistributionLambda. See my similar issue with the Beta distribution here: https://github.com/tensorflow/probability/issues/1626. Other distribution classes seem to work fine.

jedisom commented 1 year ago

@bobosfw, have you tried installing the latest nightly builds with pip install -U tf-nightly tfp-nightly? The latest nightly build resolved the error I was seeing in issue 1626.
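Before and after running that pip command, it can help to confirm which TensorFlow distributions are actually installed, since the stable packages (tensorflow, tensorflow-probability) and the nightly ones (tf-nightly, tfp-nightly) provide the same importable modules. A small sketch using only the standard library's importlib.metadata (Python 3.8+); the helper name installed_versions is hypothetical:

```python
from importlib import metadata

# PyPI distribution names: stable releases and the nightly builds suggested above.
PACKAGES = ("tensorflow", "tf-nightly", "tensorflow-probability", "tfp-nightly")


def installed_versions(names=PACKAGES):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```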