Open issue: gileshd opened this issue 1 year ago
Bijector code that relies on TF's Keras is not going to work well in JAX
(Glow and PixelCNN are other examples). You can look at
masked_autoregressive_test.py to see which tests are disabled under JAX;
there are many of them.
But I think you could probably use the bijector as part of a flax
module, possibly even as the return value of `__call__`.
If you wanted to contribute some kind of stateless AutoregressiveNetwork for JAX, I think it would make a nice PR.
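To make the "stateless" idea concrete: here is a minimal sketch of what such a network could look like in plain JAX. The weights live in an explicit list of (W, b) pairs, not in a Keras layer, and are applied by a pure function. It assumes a MADE-style masking scheme, the technique AutoregressiveNetwork is documented as implementing. The degree assignment here is a simplification and not necessarily identical to TFP's. All function names (made_masks, init_params, apply, make_shift_and_log_scale_fn) are hypothetical and are not part of TFP.

```python
import jax
import jax.numpy as jnp


def made_masks(event_size, hidden_units, params):
    """Hypothetical helper: MADE-style binary masks, one per dense layer.

    Inputs get degrees 1..event_size; hidden units cycle through
    1..event_size-1 (a simplified degree scheme, assumed for illustration).
    """
    in_deg = jnp.arange(1, event_size + 1)
    degrees = [in_deg]
    for units in hidden_units:
        degrees.append(jnp.arange(units) % max(event_size - 1, 1) + 1)
    masks = []
    for d_in, d_out in zip(degrees[:-1], degrees[1:]):
        # A hidden unit of degree k may see units of degree <= k.
        masks.append((d_out[None, :] >= d_in[:, None]).astype(jnp.float32))
    # Output dimension k may only see units of degree < k (strict), which
    # gives the autoregressive property. Repeat per parameter so column
    # j * params + p belongs to event dimension j.
    out_mask = (in_deg[None, :] > degrees[-1][:, None]).astype(jnp.float32)
    masks.append(jnp.repeat(out_mask, params, axis=1))
    return masks


def init_params(key, event_size, hidden_units, params):
    """Hypothetical initializer: returns a list of (W, b) pairs (a pytree)."""
    sizes = [event_size, *hidden_units, event_size * params]
    keys = jax.random.split(key, len(sizes) - 1)
    layers = []
    for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:]):
        w = jax.random.normal(k, (n_in, n_out)) / jnp.sqrt(n_in)
        layers.append((w, jnp.zeros(n_out)))
    return layers


def apply(layers, masks, x, params):
    """Pure forward pass; output shape is x.shape + (params,)."""
    h = x
    for i, ((w, b), m) in enumerate(zip(layers, masks)):
        h = h @ (w * m) + b  # masked dense layer
        if i < len(layers) - 1:
            h = jax.nn.relu(h)
    return h.reshape(h.shape[:-1] + (x.shape[-1], params))


def make_shift_and_log_scale_fn(layers, masks):
    """Hypothetical adapter returning a (shift, log_scale) callable."""
    def shift_and_log_scale(x, **unused_condition_kwargs):
        out = apply(layers, masks, x, params=2)
        return out[..., 0], out[..., 1]
    return shift_and_log_scale
```

Because the parameters are an ordinary pytree, they could be stored in a flax module or passed through `jax.jit`/`jax.grad` directly. The returned callable could, in principle, be passed as the `shift_and_log_scale_fn` argument of `tfb.MaskedAutoregressiveFlow` in the JAX substrate. That combination has not been checked against the tfp 0.19.0 setup reported in this issue.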
Brian Patton | Software Engineer | @.***
On Tue, Feb 28, 2023 at 9:12 AM Giles Harper-Donnelly <@.***> wrote:
Trying to initialize an instance of tfb.AutoregressiveNetwork using the jax substrate fails with an AttributeError.
With the example usage from the docs (https://www.tensorflow.org/probability/api_docs/python/tfp/substrates/jax/bijectors/AutoregressiveNetwork):
from tensorflow_probability.substrates import jax as tfp
tfb = tfp.bijectors
tfb.AutoregressiveNetwork(params=2, hidden_units=[10, 10])
raises the error:
AttributeError: module 'tensorflow_probability.python.internal.backend.jax.numpy_keras' has no attribute 'constraints'
Tested with tfp version 0.19.0 and jax versions 0.4.4 and 0.3.25.
View the issue on GitHub: https://github.com/tensorflow/probability/issues/1699