tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

JAX backend doesn't use `jax.tree_util` #1785

Open krzysztofrusek opened 5 months ago

krzysztofrusek commented 5 months ago

In this part of the code, `expand_composites` is set to `False`, which prevents `jax.tree_util` from kicking in: https://github.com/tensorflow/probability/blob/64a70d0850f2ab6b69693b7fb728b7c6eb759a7e/tensorflow_probability/python/internal/backend/numpy/nest.py#L314-L320

As a result, we cannot use pytrees that are incompatible with `dm_tree`.
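To illustrate the symptom, here is a minimal sketch. It is not TFP's `nest` code. `Params` is a hypothetical custom pytree standing in for something like an Equinox module, and `dm_tree_style_flatten` is a hypothetical, simplified stand-in for `dm_tree`'s traversal, which only recurses into lists, tuples and dicts and treats any other object as a single leaf. `jax.tree_util` sees the arrays inside the registered class, while the `dm_tree`-style traversal stops at the object itself.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Params:
    """Hypothetical custom pytree (e.g. an Equinox-style module)."""

    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale

    def tree_flatten(self):
        # Children are the array fields; there is no static auxiliary data.
        return (self.loc, self.scale), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def dm_tree_style_flatten(structure):
    """Hypothetical, simplified stand-in for dm_tree's flatten: it only
    recurses into lists, tuples and dicts (dict keys sorted), so any
    other object, including a registered JAX pytree, is one leaf."""
    if isinstance(structure, (list, tuple)):
        return [leaf for item in structure for leaf in dm_tree_style_flatten(item)]
    if isinstance(structure, dict):
        return [leaf for key in sorted(structure)
                for leaf in dm_tree_style_flatten(structure[key])]
    return [structure]


state = {"params": Params(jnp.zeros(3), jnp.ones(3)), "step": jnp.array(0)}

# jax.tree_util knows about the registered class and yields its arrays.
jax_leaves = jax.tree_util.tree_leaves(state)
# The dm_tree-style traversal treats the Params object as an opaque leaf.
dm_leaves = dm_tree_style_flatten(state)

print(len(jax_leaves))  # 3: loc, scale, step
print(len(dm_leaves))   # 2: the Params object, step
```

Any function mapped over `state` through the `dm_tree`-style path therefore receives the whole `Params` object instead of its arrays.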

Here is a notebook reproducing the bug: https://colab.research.google.com/gist/krzysztofrusek/a9fa71ca2bf3952a9f18358309225107/eqx_tfp.ipynb

SiegeLordEx commented 4 months ago

`expand_composites` should probably be set to `True` in that location. There are other locations of a similar nature where it is set to `True` (e.g. https://github.com/tensorflow/probability/blob/23a292a64f255fe7a98b32317d238b3e84f50c7f/tensorflow_probability/python/internal/loop_util.py#L183, https://github.com/tensorflow/probability/blob/23a292a64f255fe7a98b32317d238b3e84f50c7f/tensorflow_probability/python/mcmc/internal/util.py#L116), so I don't see an obvious reason why it shouldn't be `True` here as well.
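A rough sketch of why flipping the flag would help, assuming the JAX backend routes `expand_composites=True` through `jax.tree_util` as the issue describes. `map_structure` below is a hypothetical simplification that only mirrors the `expand_composites` keyword accepted by `tf.nest`-style functions; it is not the TFP implementation, and `Params` is again a hypothetical registered pytree.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Params:
    """Hypothetical custom pytree (e.g. an Equinox-style module)."""

    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale

    def tree_flatten(self):
        return (self.loc, self.scale), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def _dm_tree_style_map(fn, structure):
    # Only lists, tuples and dicts are traversed; everything else is a leaf.
    if isinstance(structure, list):
        return [_dm_tree_style_map(fn, s) for s in structure]
    if isinstance(structure, tuple):
        return tuple(_dm_tree_style_map(fn, s) for s in structure)
    if isinstance(structure, dict):
        return {k: _dm_tree_style_map(fn, v) for k, v in structure.items()}
    return fn(structure)


def map_structure(fn, structure, expand_composites=False):
    """Hypothetical simplified map_structure: with expand_composites=True
    it defers to jax.tree_util, so registered pytrees are traversed."""
    if expand_composites:
        return jax.tree_util.tree_map(fn, structure)
    return _dm_tree_style_map(fn, structure)


params = Params(jnp.zeros(3), jnp.ones(3))
is_array = lambda x: isinstance(x, jax.Array)

# expand_composites=False: fn sees the whole Params object.
print(map_structure(is_array, params, expand_composites=False))  # False
# expand_composites=True: fn sees each array inside Params.
checked = map_structure(is_array, params, expand_composites=True)
print(checked.loc, checked.scale)  # True True

doubled = map_structure(lambda x: 2 * x, params, expand_composites=True)
```

With the flag set, the result is rebuilt as a `Params` instance via `tree_unflatten`, so the custom structure is preserved rather than treated as an opaque value.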