tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

Setting dtype in Categorical breaks autograph #1436

Open bryorsnef opened 3 years ago

bryorsnef commented 3 years ago

I'm currently working on something where it is easier to keep the output of a categorical distribution as float32 instead of int32, but setting `dtype="float32"` on `Categorical` breaks AutoGraph. Using TensorFlow 2.5 and TFP 0.13.

```python
import keras
import tensorflow as tf
import tensorflow_probability as tfp

def nll(y_true, y_pred):
  # y_pred is the Categorical distribution emitted by DistributionLambda.
  return -y_pred.log_prob(y_true)

# Case 1: default Categorical dtype (int32 samples)
x = tf.random.normal((5, 2))
y = tf.ones((5,))  # float32 by default

mod = keras.Sequential()
mod.add(keras.layers.Dense(units=20, activation='relu'))
mod.add(keras.layers.Dense(units=2))
mod.add(tfp.layers.DistributionLambda(lambda x: tfp.distributions.Categorical(logits=x)))

mod.compile(loss=nll, optimizer="adam")
mod.fit(x, y, epochs=20)  # no issue

# Case 2: dtype passed as the string "float32"
x = tf.random.normal((5, 2))
y = tf.ones((5,), dtype="float32")

mod = keras.Sequential()
mod.add(keras.layers.Dense(units=20, activation='relu'))
mod.add(keras.layers.Dense(units=2))
mod.add(tfp.layers.DistributionLambda(lambda x: tfp.distributions.Categorical(logits=x, dtype="float32")))

mod.compile(loss=nll, optimizer="adam")
mod.fit(x, y, epochs=20)  # breaks autograph
```

```
OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.
```

Frightera commented 3 years ago

Changing `tfp.distributions.Categorical(logits=x, dtype="float32")` to `tfp.distributions.Categorical(logits=x, dtype=tf.float32)` (that is, passing a `tf.DType` object instead of the dtype's string name) worked fine for me.