Closed peterroelants closed 1 year ago
It seems to me that to fix this I need to somehow map over PyTreeDefs. I've opened a question on the JAX GitHub: https://github.com/google/jax/discussions/13768
I don't think you can collectively tree_map Pytrees with different static fields (node=False). There are ways to work around this, but the outcome is undefined (which static value should be chosen?).
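One possible workaround, sketched below under assumptions: flatten every tree, stack the leaves, and rebuild the result with the treedef of one explicitly chosen tree, so that the choice of static value is made by the caller rather than left undefined. The `Layer` class is a hypothetical stand-in for a Treeo.Tree with one node field and one static field, written with plain JAX pytree registration; `stack_with_chosen_static` and its `chosen` parameter are illustrative names, not part of JAX or Treeo.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Layer:
    # Hypothetical stand-in for a Treeo.Tree: `w` is a leaf, `name` is static.
    def __init__(self, w, name):
        self.w = w
        self.name = name

    def tree_flatten(self):
        # Children are visited by tree_map; aux_data is stored in the PyTreeDef.
        return (self.w,), self.name

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        (w,) = children
        return cls(w, aux_data)


def stack_with_chosen_static(trees, chosen=0):
    """Stack leaves along a new leading axis; static fields come from trees[chosen]."""
    flat = [jax.tree_util.tree_flatten(t) for t in trees]
    leaves_per_tree = [leaves for leaves, _ in flat]
    if len({len(leaves) for leaves in leaves_per_tree}) != 1:
        raise ValueError("trees must have the same number of leaves")
    treedef = flat[chosen][1]
    # Group the i-th leaf of every tree together and stack each group.
    stacked = [jnp.stack(group) for group in zip(*leaves_per_tree)]
    return jax.tree_util.tree_unflatten(treedef, stacked)


# Four trees whose static `name` fields all differ.
trees = [Layer(jnp.full(3, float(i)), f"layer_{i}") for i in range(4)]
stacked = stack_with_chosen_static(trees, chosen=0)
print(stacked.w.shape, stacked.name)
```

The static values of the other trees are silently discarded here, so this only makes sense when they are known to be irrelevant or equivalent.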
> I don't think you can collectively tree_map Pytrees with different static fields (node=False).
It seems from the discussion at https://github.com/google/jax/discussions/13768 that this is indeed not possible.
I'm running into some issues when trying to stack a list of Treeo.Tree objects into a single object. I've made a short example:

However, this fails with the following exception:
Versions used:
From a certain perspective this is expected, because jax.tree_map does not apply to static (node=False) fields. In that sense, this might not really be an issue with Treeo. However, I'm looking for some guidance on how to stack objects like this when they have static fields. Has anyone tried something similar and come up with a nice solution?
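To illustrate why static fields get in the way, here is a minimal sketch using plain JAX pytree registration rather than Treeo itself. The `Params` class is a hypothetical stand-in, and it is an assumption that Treeo's node=False fields behave like the `aux_data` used here: static data is stored in the PyTreeDef rather than among the leaves, so two objects with different static values have different tree structures.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Params:
    # Hypothetical stand-in for a Treeo.Tree: `w` is a leaf, `name` is static.
    def __init__(self, w, name):
        self.w = w
        self.name = name

    def tree_flatten(self):
        return (self.w,), self.name  # (children, static aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        (w,) = children
        return cls(w, aux_data)


a = Params(jnp.zeros(2), "a")
b = Params(jnp.ones(2), "b")

leaves_a, treedef_a = jax.tree_util.tree_flatten(a)
leaves_b, treedef_b = jax.tree_util.tree_flatten(b)

# Only the array is a leaf; the static string is not in the leaf list.
print(len(leaves_a))

# The static string lives in the treedef, so the two structures compare unequal.
print(treedef_a == treedef_b)

# Mapping over a single tree leaves the static field untouched.
doubled = jax.tree_util.tree_map(lambda x: 2 * x, b)
print(doubled.name)
```

Because a multi-tree tree_map requires all arguments to share the structure of the first one, passing `a` and `b` together is rejected, whereas objects whose static fields agree can be stacked with `jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *trees)`.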