patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

`tree_at` silently skips `__check_init__` #872

Closed. jeertmans closed this issue 1 month ago.

jeertmans commented 1 month ago

Hi!

I guess this is probably a "feature" of `tree_at`, but I recently discovered that it does not run `__check_init__`. This means `tree_at` can bypass the checks defined there and produce malformed data.

I could not find any reference to this in the documentation, so: (1) is it possible to still have `__check_init__` run after some `tree_at` surgery, and (2) could this behaviour be documented in the docstring of `tree_at`?

As always, thanks for your great tool :-)

patrick-kidger commented 1 month ago

This is intentional! `tree_at` skips `__check_init__` in the same way that it skips `__init__` itself.

As for why: `__check_init__` is commonly used to assert that certain invariants hold (e.g. `assert self.some_value > 0`), but it's also common to use `tree_at` in cases where the tree is used polymorphically with respect to leaf type. For example, when constructing the argument for `vmap(..., in_axes=...)`: `int_tree = tree_map(lambda _: 0, some_tree); int_tree = tree_at(lambda t: t.s.u.v, int_tree, 1)`.

So if we ran `__check_init__`, we wouldn't be able to express this: a leaf replaced by an integer axis such as `0` would fail a check like `self.some_value > 0`.

I'd be happy to take a PR clarifying the documentation on this. (And if you want to run `__check_init__`, I recommend writing your own `tree_at` wrapper that does so. :) )

jeertmans commented 1 month ago

Thanks for your reply @patrick-kidger! I suspected this was intentional; I just wanted clarification, and I also think it deserves to be documented :-)

I opened a small PR to this end.