krzysztofrusek opened this issue 5 months ago
expand_composites should probably be set to True in that location. There are other locations of a similar nature where it is set to True (e.g. https://github.com/tensorflow/probability/blob/23a292a64f255fe7a98b32317d238b3e84f50c7f/tensorflow_probability/python/internal/loop_util.py#L183, https://github.com/tensorflow/probability/blob/23a292a64f255fe7a98b32317d238b3e84f50c7f/tensorflow_probability/python/mcmc/internal/util.py#L116), so I don't see an obvious reason why it shouldn't be True here as well.
In this part, expand_composites is set to False, and this prevents jax.tree_util from kicking in: https://github.com/tensorflow/probability/blob/64a70d0850f2ab6b69693b7fb728b7c6eb759a7e/tensorflow_probability/python/internal/backend/numpy/nest.py#L314-L320
The resulting problem is that we cannot use pytrees that are incompatible with dm_tree.
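To illustrate the difference, here is a minimal sketch in JAX. It assumes a hypothetical pytree class Params (registered with jax.tree_util but unknown to dm_tree, similar in spirit to an Equinox module) and a hypothetical flatten helper standing in for the backend's nest code. It is not the actual TFP implementation: the expand_composites=False branch only approximates dm_tree by treating anything other than lists, tuples and dicts as a leaf.

```python
import jax
import jax.numpy as jnp


# Hypothetical container: registered as a JAX pytree, but dm_tree knows
# nothing about it (much like an Equinox module).
@jax.tree_util.register_pytree_node_class
class Params:
    def __init__(self, w, b):
        self.w = w
        self.b = b

    def tree_flatten(self):
        # Children are the two array fields; there is no auxiliary data.
        return (self.w, self.b), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def flatten(structure, expand_composites):
    """Hypothetical stand-in for the backend's flatten.

    expand_composites=True: defer to jax.tree_util, so registered pytrees
    are expanded into their array leaves.
    expand_composites=False: rough approximation of a dm_tree-style flatten
    that only recurses into built-in containers and treats anything else
    (including Params) as a single opaque leaf.
    """
    if expand_composites:
        return jax.tree_util.tree_leaves(structure)
    return jax.tree_util.tree_leaves(
        structure,
        # Stop recursing at any object that is not a list, tuple or dict.
        is_leaf=lambda x: not isinstance(x, (list, tuple, dict)),
    )


params = [Params(jnp.ones(2), jnp.zeros(()))]
expanded = flatten(params, expand_composites=True)   # two array leaves
opaque = flatten(params, expand_composites=False)    # one Params leaf
```

With expansion, the two arrays inside Params are visible to the structure-walking code; without it, the whole Params object is passed along as one opaque leaf, which breaks any downstream operation that expects array leaves.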
Here is a notebook reproducing the bug: https://colab.research.google.com/gist/krzysztofrusek/a9fa71ca2bf3952a9f18358309225107/eqx_tfp.ipynb