patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

Add Generic Type Hints to tree_at Function #890

Closed jsonmona closed 2 weeks ago

jsonmona commented 2 weeks ago

Current Implementation

The tree_at function currently has this type signature:

def tree_at(
    where: Callable[[PyTree], Union[_Node, Sequence[_Node]]],
    pytree: PyTree,
    replace: Union[Any, Sequence[Any]] = sentinel,
    replace_fn: Callable[[_Node], Any] = sentinel,
    is_leaf: Optional[Callable[[Any], bool]] = None,
) -> Any:
# Note: I wrote out the implicit "-> Any" return type explicitly.

Proposed Change

Add a generic type variable to improve type hints:

from typing import TypeVar

SpecificPyTree = TypeVar("SpecificPyTree")

def tree_at(
    where: Callable[[SpecificPyTree], Union[_Node, Sequence[_Node]]],
    pytree: SpecificPyTree,
    replace: Union[Any, Sequence[Any]] = sentinel,
    replace_fn: Callable[[_Node], Any] = sentinel,
    is_leaf: Optional[Callable[[Any], bool]] = None,
) -> SpecificPyTree:

This way, IDE users get type hints when writing the lambda for where, and when using the value returned by tree_at (currently Any).

Type Soundness

  1. The PyTree type is an alias for Any, so SpecificPyTree needs no bound.
  2. tree_at definitely returns the same type as pytree.
  3. The where callback receives the same type as pytree, duck typing aside.
  4. This should cause no issues for Equinox developers, since the function internals already use Any.

Potential Concerns

There might be existing usages that would trigger type errors with this change, though I couldn't come up with any.

jsonmona commented 2 weeks ago

I think I understand why this isn't possible. Users may replace a PyTree node with a node of another type, e.g. replacing BatchNorm with LayerNorm.

I'm going to leave this open in case I was wrong, but please feel free to close it.

patrick-kidger commented 2 weeks ago

Yup, you've exactly hit on the problem. The input may be a pytree that is a tuple[int], but the returned value could become a tuple[str], for example.
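As a rough illustration of the point above, here is a minimal pure-Python sketch. `generic_replace_first` is a hypothetical stand-in (not part of Equinox) for a function annotated to return the same type as its input. It only handles tuples, but it shows how the static annotation and the runtime value can disagree.

```python
from typing import Any, TypeVar

T = TypeVar("T")


def generic_replace_first(pytree: T, replace: Any) -> T:
    # Hypothetical stand-in for a tree_at annotated with "-> T".
    # It only supports tuples and swaps out element 0.
    assert isinstance(pytree, tuple)
    result = (replace,) + pytree[1:]
    return result  # type: ignore[return-value]


ints: tuple[int] = (1,)
out = generic_replace_first(ints, "a")

# A type checker would infer `out: tuple[int]` from the annotation,
# yet the runtime value contains a str.
print(out)                     # ('a',)
print(type(out[0]).__name__)   # str
```

The `type: ignore` is needed precisely because the annotation promises something the body cannot guarantee.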

I think on that basis I'm going to close this issue, but thank you for thinking of this!

jsonmona commented 2 weeks ago

On second thought, that objection only applies to the return value. Annotating the where callback should still be valid.

Like this:

from typing import TypeVar

SpecificPyTree = TypeVar("SpecificPyTree")

def tree_at(
    where: Callable[[SpecificPyTree], Union[_Node, Sequence[_Node]]],
    pytree: SpecificPyTree,
    replace: Union[Any, Sequence[Any]] = sentinel,
    replace_fn: Callable[[_Node], Any] = sentinel,
    is_leaf: Optional[Callable[[Any], bool]] = None,
) -> Any:

Notice that the return type is now Any, but the where callback still takes SpecificPyTree.

patrick-kidger commented 2 weeks ago

Unfortunately not even that:

import equinox as eqx

pytree: list[str] = ['hi']

def where(pytree):
    print(pytree)
    return pytree[0]

eqx.tree_at(where, pytree, 'bye')
# ['hi']
# [<equinox._tree._LeafWrapper object at 0x100de0990>]

The second printed value is definitely not a list[str].

(To explain: JAX treats pytrees as 'structural only', meaning they can hold any leaf type. A common example is vmap(..., in_axes=...), where in_axes has the same structure as the input but leaves of type None | int. eqx.tree_at relies on this flexibility to do what it does.)
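To make the mechanism concrete, here is a toy sketch in plain Python. It is not Equinox's actual implementation: `LeafWrapper` and `toy_tree_at` are hypothetical names, the function only handles flat lists, and it calls `where` just once. The idea it illustrates is that each leaf is swapped for a unique wrapper object so the selected node can be found by identity, which means `where` receives a tree with the same structure but a different leaf type.

```python
from typing import Any, Callable


class LeafWrapper:
    """Hypothetical marker object that stands in for a leaf during lookup."""

    def __init__(self, value: Any) -> None:
        self.value = value


def toy_tree_at(where: Callable[[list], Any], pytree: list, replace: Any) -> list:
    # Toy version for flat lists only. Each leaf gets its own wrapper so that
    # equal leaves (e.g. two identical strings) can still be told apart.
    wrapped = [LeafWrapper(leaf) for leaf in pytree]
    target = where(wrapped)  # `where` sees wrappers, not the original leaves
    return [replace if w is target else w.value for w in wrapped]


seen = []


def where(tree):
    # Record what kind of leaf the callback actually receives.
    seen.append(type(tree[0]).__name__)
    return tree[0]


result = toy_tree_at(where, ["hi", "hi"], "bye")
print(result)  # ['bye', 'hi']
print(seen)    # ['LeafWrapper']
```

With two identical "hi" leaves, only the first is replaced, because the lookup is by wrapper identity rather than by value. A `where` callback annotated as taking `list[str]` would therefore be mis-typed for this call.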

jsonmona commented 2 weeks ago

That's disappointing :(

Thank you for the explanation, though!