tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

Error initializing tfb.AutoregressiveNetwork using jax substrate #1699

Open gileshd opened 1 year ago

gileshd commented 1 year ago

Trying to initialize an instance of tfb.AutoregressiveNetwork using the jax substrate fails with an AttributeError.

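With the example usage from the docs (https://www.tensorflow.org/probability/api_docs/python/tfp/substrates/jax/bijectors/AutoregressiveNetwork):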

from tensorflow_probability.substrates import jax as tfp
tfb = tfp.bijectors

tfb.AutoregressiveNetwork(params=2, hidden_units=[10,10])

raises the error:

File .../site-packages/tensorflow_probability/substrates/jax/bijectors/masked_autoregressive.py:967, in AutoregressiveNetwork.__init__(self, params, event_shape, conditional, conditional_event_shape, conditional_input_layers, hidden_units, input_order, hidden_degrees, activation, use_bias, kernel_initializer, bias_initializer, kernel_regularizer, bias_regularizer, kernel_constraint, bias_constraint, validate_args, **kwargs)
    965 self._kernel_regularizer = kernel_regularizer
    966 self._bias_regularizer = bias_regularizer
--> 967 self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
    968 self._bias_constraint = bias_constraint
    969 self._validate_args = validate_args

AttributeError: module 'tensorflow_probability.python.internal.backend.jax.numpy_keras' has no attribute 'constraints'

Tested with tfp version 0.19.0 and jax versions 0.4.4 and 0.3.25.

brianwa84 commented 1 year ago

Bijector code that uses Keras from TF is not going to work well in JAX (other examples would be Glow and PixelCNN). You can look at masked_autoregressive_test.py to see which tests are disabled with JAX; it's many of them. But I think you could probably use the bijector as part of a flax module, possibly even as the return value of __call__.

If you wanted to contribute some kind of stateless AutoregressiveNetwork for JAX I think it would be a nice PR.
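For anyone picking this up, here is a rough sketch of what a stateless, Keras-free masked autoregressive network could look like in plain JAX. It is not TFP code: the names made_masks, init_params and made_apply are hypothetical, the cyclic assignment of hidden degrees is an assumption chosen for brevity, and only the [..., event_size, params] output layout is meant to mirror tfb.AutoregressiveNetwork.

```python
import jax
import jax.numpy as jnp


def made_masks(event_size, hidden_units, params):
    """Binary masks that make output i depend only on inputs before i."""
    # Input j gets degree j + 1; hidden units cycle through 1..event_size-1.
    in_deg = jnp.arange(1, event_size + 1)
    degrees = [in_deg]
    for width in hidden_units:
        degrees.append(jnp.arange(width) % max(event_size - 1, 1) + 1)
    # Into a hidden layer: connect when source degree <= target degree.
    masks = [
        (d_in[:, None] <= d_out[None, :]).astype(jnp.float32)
        for d_in, d_out in zip(degrees[:-1], degrees[1:])
    ]
    # Into the output layer: strict inequality, `params` outputs per input.
    out_deg = jnp.repeat(in_deg, params)
    masks.append((degrees[-1][:, None] < out_deg[None, :]).astype(jnp.float32))
    return masks


def init_params(key, event_size, hidden_units, params):
    """Plain list of (kernel, bias) pairs; no Keras layers or hidden state."""
    sizes = [event_size, *hidden_units, event_size * params]
    keys = jax.random.split(key, len(sizes) - 1)
    return [
        (jax.random.normal(k, (m, n)) / jnp.sqrt(m), jnp.zeros(n))
        for k, m, n in zip(keys, sizes[:-1], sizes[1:])
    ]


def made_apply(weights, masks, x, params):
    """Pure function: masked dense layers with ReLU between them."""
    h = x
    for i, ((w, b), mask) in enumerate(zip(weights, masks)):
        h = h @ (w * mask) + b
        if i < len(weights) - 1:
            h = jax.nn.relu(h)
    # Reshape to [..., event_size, params].
    return h.reshape(x.shape[:-1] + (x.shape[-1], params))


event_size, hidden_units, params = 4, [10, 10], 2
masks = made_masks(event_size, hidden_units, params)
weights = init_params(jax.random.PRNGKey(0), event_size, hidden_units, params)
x = jnp.ones(event_size)
out = made_apply(weights, masks, x, params)  # shape (4, 2)
```

Because the weights are passed in explicitly, made_apply can be wrapped in jax.jit or jax.grad, or called from a flax module that owns the weights. A quick sanity check is that the Jacobian of each output parameter with respect to x is strictly lower triangular.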
