jsonmona closed this 2 weeks ago.
I think I understand why this isn't possible. Users may replace a PyTree node with a node of another type, e.g. to replace `BatchNorm` with `LayerNorm`.
I'm going to leave this open in case I was wrong, but please feel free to close it.
Yup, you've hit exactly on the problem. The input may be a pytree that is a `tuple[int]`, but the returned value could become a `tuple[str]`, for example.
I think on that basis I'm going to close this issue, but thank you for thinking of this!
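A minimal pure-Python sketch of that point, using a hypothetical `replace_first_leaf` helper (not part of Equinox) that swaps a leaf without touching the tree's structure. The input is a `tuple[int]`, yet the output is a `tuple[str]`, so a return annotation promising "the same type as the input" cannot be sound.

```python
from typing import Any


def replace_first_leaf(tree: tuple, new_leaf: Any) -> tuple:
    """Hypothetical helper: return a copy of `tree` with its first element swapped.

    The structure (the tuple's length) is preserved, but nothing constrains the
    type of `new_leaf`, so the element type of the result can change.
    """
    return (new_leaf,) + tree[1:]


before: tuple[int] = (1,)
after = replace_first_leaf(before, "one")

print(before, after)  # (1,) ('one',)
print(type(before[0]).__name__, type(after[0]).__name__)  # int str
```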
On second thought, that only applies to the return value. Annotating the `where` callback should still be valid.
Like this:
```python
from typing import Any, Callable, Optional, Sequence, TypeVar, Union

SpecificPyTree = TypeVar("SpecificPyTree")


# `_Node` and `sentinel` are the existing private helpers in Equinox's tree_at module.
def tree_at(
    where: Callable[[SpecificPyTree], Union[_Node, Sequence[_Node]]],
    pytree: SpecificPyTree,
    replace: Union[Any, Sequence[Any]] = sentinel,
    replace_fn: Callable[[_Node], Any] = sentinel,
    is_leaf: Optional[Callable[[Any], bool]] = None,
) -> Any:
    ...
```
Notice that the return type is now `Any`, but the `where` callback is still typed as taking a `SpecificPyTree`.
Unfortunately not even that:
```python
import equinox as eqx
pytree: list[str] = ['hi']
def where(pytree):
    print(pytree)
    return pytree[0]
eqx.tree_at(where, pytree, 'bye')
# ['hi']
# [<equinox._tree._LeafWrapper object at 0x100de0990>]
```
That second one is definitely not a `list[str]`.
(To explain: JAX treats pytrees as 'structural only', meaning they can accept any leaf type. A common example of this is `vmap(..., in_axes=...)`, where `in_axes` uses the same structure as the input but has leaf type `None | int`. We use this flexibility in `eqx.tree_at` to enable it to do what it does.)
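A minimal pure-Python sketch of that 'structural only' idea, assuming nested tuples, lists, and dicts are the only containers. JAX's real pytree registry is more general, and the toy `tree_map` below is not `jax.tree_util.tree_map`. Mapping over a tree keeps its structure but lets the leaf type change, which is how an `in_axes`-style spec with `None | int` leaves can mirror a tree of arrays.

```python
import array
from typing import Any, Callable


def tree_map(fn: Callable[[Any], Any], tree: Any) -> Any:
    """Toy structural map: tuples, lists and dicts are structure, anything else is a leaf.

    The result has the same structure as `tree`; `fn` may return leaves of any type.
    """
    if isinstance(tree, tuple):
        return tuple(tree_map(fn, x) for x in tree)
    if isinstance(tree, list):
        return [tree_map(fn, x) for x in tree]
    if isinstance(tree, dict):
        return {k: tree_map(fn, v) for k, v in tree.items()}
    return fn(tree)


# Data tree: an `array.array` stands in for a batched array, a float for a scalar.
batch = {"x": array.array("d", [1.0, 2.0]), "scale": 3.0}

# An in_axes-style spec: same structure as `batch`, but leaves are `None | int`
# (0 = map over the leading axis, None = pass the leaf through unmapped).
in_axes = tree_map(lambda leaf: 0 if isinstance(leaf, array.array) else None, batch)
print(in_axes)  # {'x': 0, 'scale': None}

# Erasing the leaves shows that both trees share one structure.
print(tree_map(lambda _: None, batch) == tree_map(lambda _: None, in_axes))  # True
```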
That's disappointing :(
Thank you for the explanation, though!
Current Implementation
The `tree_at` function currently has this type signature:
Proposed Change
Add a generic type variable to improve type hints:
This way, IDE users get type hints when writing the lambda function for `where`, and when using the value returned from `tree_at` (which is currently `Any`).

Type Soundness

- The `PyTree` type is an alias for `Any`, so `SpecificPyTree` needs no bounds.
- `tree_at` definitely returns the same type as `pytree`.
- `where` expects the same type as `pytree`, minus duck typing.
- … `Any`.

Potential Concerns
There might be existing usages that would trigger type errors with this change, though I couldn't come up with any.
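Relating this concern to the thread above, a minimal sketch of how a `-> SpecificPyTree` style return annotation could produce such type errors. It uses a hypothetical `replace_at` helper (not Equinox's API) that is annotated the way the proposal annotates `tree_at`: the annotation promises the input type back, but replacing a leaf can change it.

```python
from typing import Any, TypeVar, cast

T = TypeVar("T")


def replace_at(pytree: T, index: int, replace: Any) -> T:
    """Hypothetical helper annotated like the proposal: the result claims to be `T`.

    Sketch assumption: `pytree` is a tuple and `index` selects one of its leaves.
    """
    items = list(cast(tuple, pytree))
    items[index] = replace
    # The cast satisfies the checker; at runtime the leaf type may have changed.
    return cast(T, tuple(items))


ints: tuple[int] = (1,)
out = replace_at(ints, 0, "one")

# Because of `-> T`, a static checker infers `out: tuple[int]`, so a line like
# `out[0].upper()` would likely be reported as an error, even though it runs fine.
print(out)  # ('one',)
print(out[0].upper())  # ONE
```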