tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

Keras VQVAE example no longer works with tfp v0.14.0+ #1576

Closed: reinvantveer closed this issue 2 years ago

reinvantveer commented 2 years ago

Dear devs,

The example at https://keras.io/examples/generative/vq_vae/ only works with tfp up to v0.13.0. https://github.com/keras-team/keras-io/pull/911 addressed this issue, but François asked me how the code should be refactored to work with newer versions. Pinning tfp to 0.13.0 is the easiest fix, but it isn't the prettiest.

If I understand correctly, the 'offending' snippet involves building a model with a tfp sampling layer inside:

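```python
import tensorflow_probability as tfp
from tensorflow import keras
from tensorflow.keras import layers

# Create a mini sampler model.
# pixel_cnn is the trained PixelCNN from the Keras VQ-VAE example.
inputs = layers.Input(shape=pixel_cnn.input_shape[1:])
x = pixel_cnn(inputs, training=False)
dist = tfp.distributions.Categorical(logits=x)
sampled = dist.sample()  # raises TypeError with tfp >= 0.14.0
sampler = keras.Model(inputs, sampled)
```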

This fails at sampled = dist.sample() with: TypeError: Dimension value must be integer or None or have an __index__ method, ....

How should this be resolved in 0.14.0+? Thanks in advance!

csuter commented 2 years ago

When I modify that cell to print x, its type is KerasTensor. I'm not sure how this relates to tf.Tensor, but TFP does not (explicitly) support KerasTensors as inputs. The docstring on KerasTensor says "A representation of a Keras in/output during Functional API construction", which makes it sound like this is not what should be returned from the pixel_cnn(...) call (I'd expect a tf.Tensor, tf.EagerTensor, or simply a np.ndarray here, though I admit I don't understand Keras well).

Is this return type what you'd expect here?

csuter commented 2 years ago

Oh I misunderstood the code block. I guess it's building a new model as an extension of the pixel_cnn.
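The difference between building a model and running it can be seen in a few lines. In this sketch (not from the thread), a Dense layer is a hypothetical stand-in for pixel_cnn. Calling it on a Functional API Input yields a symbolic KerasTensor, while calling it on concrete data yields an ordinary eager tensor. The type names assume TF 2.4 or later with eager execution enabled (the default).

```python
import numpy as np
from tensorflow.keras import layers

# Hypothetical stand-in for pixel_cnn; any Keras layer or model behaves the same way here.
stand_in = layers.Dense(3)

# Functional API construction: Input() is symbolic, so the layer's output is symbolic too.
symbolic_out = stand_in(layers.Input(shape=(4,)))

# Eager call on concrete data: the output is an ordinary (eager) tf.Tensor.
eager_out = stand_in(np.zeros((2, 4), dtype="float32"))

print(type(symbolic_out).__name__)  # KerasTensor
print(type(eager_out).__name__)     # EagerTensor
```

So inside the sampler-model snippet, x is a KerasTensor rather than a concrete tensor, which is the kind of input TFP does not explicitly support.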

reinvantveer commented 2 years ago

Yes, exactly: pixel_cnn is a layer in the model that returns a tf.Tensor when called. The question, then, is what changed in tfp 0.14.0+ such that it no longer accepts Keras layer outputs, and how to work around this.

csuter commented 2 years ago

I think the relevant change from 0.13 to 0.14 was the addition of a bunch of batch-shape inference machinery in TFP. This machinery is likely written with a heavy bias towards assuming that inputs are tf.Tensors (or duck-type well enough as tf.Tensors).

Replacing the explicit distribution construction and sample with a tfp.layers.DistributionLambda appears to work. These layers were designed to work with Keras.

Working snippet:

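```python
import tensorflow_probability as tfp
from tensorflow import keras
from tensorflow.keras import layers

# Create a mini sampler model.
# pixel_cnn is the trained PixelCNN from the Keras VQ-VAE example.
inputs = layers.Input(shape=pixel_cnn.input_shape[1:])
x = pixel_cnn(inputs, training=False)
# DistributionLambda wraps the Categorical so that Keras sees a tensor (a sample).
cat_layer = tfp.layers.DistributionLambda(tfp.distributions.Categorical)
sampled = cat_layer(x)
sampler = keras.Model(inputs, sampled)
```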

DistributionLambda takes as input a function that produces a distribution (here the Categorical constructor suffices, because its first positional argument is logits). It also takes an optional convert_to_tensor_fn, which maps a distribution to a tensor. This could be the mean, some quantile, or anything else, but the default is sample, which is what this snippet wants. We could write the above a bit more explicitly as:

```python
# Create a mini sampler model.
inputs = layers.Input(shape=pixel_cnn.input_shape[1:])
x = pixel_cnn(inputs, training=False)
cat_layer = tfp.layers.DistributionLambda(
    # Build a Categorical over the PixelCNN logits...
    make_distribution_fn=lambda logits: tfp.distributions.Categorical(logits=logits),
    # ...and convert it to a tensor by sampling (this is also the default).
    convert_to_tensor_fn=lambda dist: dist.sample())
sampled = cat_layer(x)
sampler = keras.Model(inputs, sampled)
```

Anyway, with that change the rest of the colab runs fine. HTH!

reinvantveer commented 2 years ago

@csuter Thanks!!! I'll see if I can submit a new PR to Keras with this implementation.

csuter commented 2 years ago

Excellent, glad this is helpful. It's possible we could make the batch shape inference more robust to Keras inputs, but this would be more work.