tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

MultivariateNormalTriL Layer appears to be incompatible with tf.keras in tf 2.16.1 and tfp 0.24 #1809

Open reichardtj opened 2 months ago

reichardtj commented 2 months ago

There appears to be a breaking change in how MultivariateNormalTriL works with tf.keras in tf 2.16.1 and tfp 0.24.0 (tf_keras version 2.16.0).

I'm using Python 3.11.8 on a Mac M3, but I can also reproduce the issue on a Linux machine. I did not try a different Python version.

conda create -n p311 python=3.11
conda activate p311
pip install tensorflow
pip install tensorflow-probability
pip install tf_keras

The last version I tested where this problem does not occur is tf 2.14.0 with tfp 0.22.0 on Python 3.10.13; I did not test intermediate versions.

Here is a minimal example that reproduces the issue; it simply implements the example from the documentation (https://www.tensorflow.org/probability/api_docs/python/tfp/layers/MultivariateNormalTriL):

import tensorflow as tf
import tensorflow_probability as tfp

print('tf: ', tf.__version__, 'tfp:', tfp.__version__)

# Create model.
d = 5
dist_param_dim = tfp.layers.MultivariateNormalTriL.params_size(d)
model = tf.keras.Sequential([
    tf.keras.layers.Dense(dist_param_dim),
    tfp.layers.MultivariateNormalTriL(d),
])

This raises

ValueError: Only instances of `keras.Layer` can be added to a Sequential model.

Received: <tensorflow_probability.python.layers.distribution_layer.MultivariateNormalTriL object at 0x380069b90> 
(of type <class 'tensorflow_probability.python.layers.distribution_layer.MultivariateNormalTriL'>)

Using the Keras functional API instead results in a related but different error:

inp = tf.keras.layers.Input(shape=(5,))

x = tf.keras.layers.Dense(dist_param_dim)(inp)
output = tfp.layers.MultivariateNormalTriL(d)(x)

model = tf.keras.Model(inputs=inp, outputs=output)

This raises

ValueError: Exception encountered when calling layer 'multivariate_normal_tri_l_3' (type MultivariateNormalTriL).

A KerasTensor cannot be used as input to a TensorFlow function. 
A KerasTensor is a symbolic placeholder for a shape and dtype, used when constructing Keras Functional models or Keras Functions. You can only use it as input to a Keras layer or a Keras operation (from the namespaces `keras.layers` and `keras.operations`). You are likely doing something like:

x = Input(...)
...
tf_fn(x)  # Invalid.

What you should do instead is wrap `tf_fn` in a layer:

class MyLayer(Layer):
    def call(self, x):
        return tf_fn(x)

x = MyLayer()(x)

Call arguments received by layer 'multivariate_normal_tri_l_3' (type MultivariateNormalTriL):
  • inputs=<KerasTensor shape=(None, 20), dtype=float32, sparse=False, name=keras_tensor_15>
  • args=<class 'inspect._empty'>
  • kwargs={'training': 'None'}
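For reference, the `shape=(None, 20)` in the call arguments above is simply the length of the parameter vector that `params_size(d)` asks the Dense layer to produce. Assuming the usual parameterization of a multivariate normal with a lower-triangular scale (d location entries plus the d(d+1)/2 entries on and below the diagonal of the scale matrix), d = 5 gives 5 + 15 = 20. A pure-Python sketch of that count; `mvn_tril_params_size` is a hypothetical name, not a TFP API:

```python
def mvn_tril_params_size(d: int) -> int:
    """Hypothetical helper: assumed to match the count returned by
    tfp.layers.MultivariateNormalTriL.params_size(d); not TFP's own code."""
    loc_size = d                         # one mean per event dimension
    scale_tril_size = d * (d + 1) // 2   # entries on and below the diagonal
    return loc_size + scale_tril_size


print(mvn_tril_params_size(5))  # 20, matching shape=(None, 20) above
```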

Maybe that helps identify the problem. It appears that MultivariateNormalTriL is no longer recognised as a layer, but as a distribution function.

Running the example from https://www.tensorflow.org/probability/api_docs/python/tfp/layers/DistributionLambda seems to support this assumption: it, too, raises `ValueError: Only instances of keras.Layer can be added to a Sequential model.`
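A likely explanation, stated here as an assumption based on the TensorFlow 2.16 release notes rather than anything confirmed in this thread: from TF 2.16 on, `tf.keras` resolves to Keras 3, while the TFP 0.24 layers subclass the Keras 2 `tf_keras` Layer class, so the Keras 3 `Sequential` and functional API reject them. TensorFlow is meant to route `tf.keras` back to tf_keras when the environment variable `TF_USE_LEGACY_KERAS` is set to `1` before TensorFlow is imported. The hypothetical helper below only encodes that rule, to show which of the tested setups would be affected:

```python
import os
from typing import Mapping, Optional


def tf_keras_is_keras3(tf_version: str,
                       env: Optional[Mapping[str, str]] = None) -> bool:
    """Hypothetical helper: guess whether `tf.keras` is Keras 3.

    Assumes the TF 2.16 behavior: Keras 3 by default, unless
    TF_USE_LEGACY_KERAS is "1"/"true"/"True" when TensorFlow is imported.
    """
    env = os.environ if env is None else env
    major, minor = (int(part) for part in tf_version.split(".")[:2])
    if (major, minor) < (2, 16):
        return False  # older TF: tf.keras is still Keras 2
    legacy = env.get("TF_USE_LEGACY_KERAS", "") in ("1", "true", "True")
    return not legacy


print(tf_keras_is_keras3("2.14.0", {}))  # False: the working setup
print(tf_keras_is_keras3("2.16.1", {}))  # True: the failing setup
```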

Any help is greatly appreciated!

reichardtj commented 2 months ago

It appears that simply

import tf_keras as tfk

and using tfk.* instead of tf.keras.* fixes the issue. Note, however, that all the documentation pages cited above refer to tf.keras.*.

So maybe the documentation should be changed accordingly.

jburnim commented 2 months ago

Thanks for the report!

I believe we have changed all of our documentation and examples in the TFP codebase to use tf_keras instead of tf.keras. But these changes don't appear to have propagated to tensorflow.org; e.g., I still see tf.keras at https://www.tensorflow.org/probability/api_docs/python/tfp/layers/DistributionLambda .

I'll investigate to see why the documentation at tensorflow.org has not been updated.