tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

API error in jax substrate? #1709

Open justindomke opened 1 year ago

justindomke commented 1 year ago

Hello,

If I install TensorFlow Probability and follow the instructions for creating a RealNVP flow, I get an import error. (This reproduces across several installs and operating systems, including Linux, macOS, and Google Colab, with Python 3.9 and 3.10. It happens even in a fresh conda environment where the only package installed is tensorflow-probability.)

The error can be reproduced as follows:

from tensorflow_probability.substrates import jax as tfp
tfb = tfp.bijectors
tfb.real_nvp_default_template(hidden_layers=[512, 512])

This is the smallest reproduction of the error I get when running the example from here: https://www.tensorflow.org/probability/api_docs/python/tfp/bijectors/RealNVP

csuter commented 1 year ago

To implement the jax substrate, we basically reimplement the TF API surface using jax under the hood. This lets us leave most TFP code intact, just swapping out imports. This is not something we've done for the Keras portions of the TF API surface, so there is no underlying implementation for things like the real_nvp default MLP. You'd need to provide an alternate implementation using flax or something. I feel like there are examples floating around; I'll try to find one.

justindomke commented 1 year ago

Thank you, that would be super helpful! (Honestly any example of how to use realNVP from JAX would be helpful—I don't mind writing my own alternative template, but I'm not sure what that would involve.)

brianwa84 commented 1 year ago

Hi Justin, I put together a quick gist here: https://colab.research.google.com/gist/brianwa84/dfa3d56cded8e56038184fb17048afc6/rnvp-jax.ipynb Hopefully that's enough to get you going. LMK if you have questions.

jamesheald commented 3 months ago

Thanks for the colab, @brianwa84. I found it very helpful.

If I understand correctly, the shift_and_log_scale_fn expects two arguments: the input x and the number of output variables. The documentation doesn't make clear that this second argument is expected, does it?
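
For anyone with the same question, here is a minimal plain-Python sketch of one affine coupling step in the style of RealNVP. It is deliberately written without TFP or JAX so that it runs anywhere, and it assumes the calling convention described above: the function receives the pass-through slice of the input and the number of outputs to produce. The names coupling_forward and coupling_inverse and the toy conditioner are hypothetical, chosen for illustration; this is not TFP's implementation.

```python
import math


def shift_and_log_scale_fn(x0, output_units):
    """Toy conditioner (hypothetical): maps the masked slice to shift and log_scale.

    x0 is the pass-through slice; output_units is how many values each of
    shift and log_scale must contain (input size minus the masked size).
    """
    total = sum(x0)
    shift = [0.5 * total + j for j in range(output_units)]
    log_scale = [math.tanh(0.1 * total)] * output_units
    return shift, log_scale


def coupling_forward(x, num_masked, fn):
    """Leave x[:num_masked] unchanged; affinely transform the remaining entries."""
    x0, x1 = x[:num_masked], x[num_masked:]
    # The second argument tells fn how many outputs to produce.
    shift, log_scale = fn(x0, len(x) - num_masked)
    y1 = [xi * math.exp(ls) + sh for xi, ls, sh in zip(x1, log_scale, shift)]
    return x0 + y1


def coupling_inverse(y, num_masked, fn):
    """Invert coupling_forward; the masked slice is the same on both sides."""
    y0, y1 = y[:num_masked], y[num_masked:]
    shift, log_scale = fn(y0, len(y) - num_masked)
    x1 = [(yi - sh) * math.exp(-ls) for yi, ls, sh in zip(y1, log_scale, shift)]
    return y0 + x1


x = [0.3, -1.2, 0.7, 2.0, -0.5]
y = coupling_forward(x, num_masked=2, fn=shift_and_log_scale_fn)
x_back = coupling_inverse(y, num_masked=2, fn=shift_and_log_scale_fn)
```

With num_masked=2 and a five-dimensional input, the function is called with a slice of length 2 and output_units=3, which is why it needs the second argument: it cannot infer the output size from x0 alone.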