jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Design suggestion: Indicate has-aux using a box #12948

Open NeilGirdhar opened 2 years ago

NeilGirdhar commented 2 years ago

Currently, we have

def value_and_grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
                   has_aux: bool = False, holomorphic: bool = False,
                   allow_int: bool = False, reduce_axes: Sequence[AxisName] = ()
  ) -> Callable[..., Tuple[Any, Any]]:

which leaves the return value a mostly untyped Callable.

If instead we indicated has-aux with a box, we could do

def value_and_grad(fun: Callable[..., Y] | HasAux[Y, Aux],
                   argnums: Union[int, Sequence[int]] = 0,
                   holomorphic: bool = False, allow_int: bool = False,
                   reduce_axes: Sequence[AxisName] = ()
  ) -> Callable[..., Tuple[Y, Any]] | Callable[..., Tuple[Tuple[Y, Aux], Any]]:

This way, the output value of type Y and the auxiliary data of type Aux would both be typed. typing.overload could be used to automatically select between the two return types.
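A minimal sketch of how such a box and its overloads might look. This is not JAX's API: `HasAux` is the hypothetical class from the proposal, and the `value_and_grad` below is a toy stand-in that takes one float argument and uses a central finite difference, only so that the example runs on its own.

```python
from dataclasses import dataclass
from typing import Any, Callable, Generic, Tuple, TypeVar, overload

Y = TypeVar("Y")
Aux = TypeVar("Aux")


@dataclass(frozen=True)
class HasAux(Generic[Y, Aux]):
    """Hypothetical box marking a function that returns (value, aux)."""
    fun: Callable[..., Tuple[Y, Aux]]


@overload
def value_and_grad(
    fun: HasAux[Y, Aux],
) -> Callable[..., Tuple[Tuple[Y, Aux], Any]]: ...
@overload
def value_and_grad(fun: Callable[..., Y]) -> Callable[..., Tuple[Y, Any]]: ...
def value_and_grad(fun):
    """Toy stand-in (not jax.value_and_grad): one float argument only."""
    eps = 1e-6

    if isinstance(fun, HasAux):
        inner = fun.fun

        def wrapped_aux(x):
            value, aux = inner(x)
            # Differentiate only the first element of the (value, aux) pair.
            grad = (inner(x + eps)[0] - inner(x - eps)[0]) / (2 * eps)
            return (value, aux), grad

        return wrapped_aux

    def wrapped(x):
        grad = (fun(x + eps) - fun(x - eps)) / (2 * eps)
        return fun(x), grad

    return wrapped


# The box, not a flag, selects the overload and hence the return type.
(value, aux), grad = value_and_grad(HasAux(lambda x: (x * x, "info")))(3.0)
plain_value, plain_grad = value_and_grad(lambda x: x * x)(3.0)
```

Because the box is not itself callable in this sketch, the two overloads do not overlap, so a type checker should be able to pick exactly one of them from the argument type.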

I proposed a similar idea to jax-opt, where I illustrated how such a box could be designed.

apaszke commented 2 years ago

I don't follow the purpose of the HasAux union element in the fun input annotation. Still, I think that this might actually create more friction than necessary. Going over the JEP for type annotation design, I see two points that might go against this proposal:

  1. Err towards simplicity: The new type signatures are much more complicated than the old ones
  2. Outputs should be strictly typed: the old annotation indicates that the result is always a function that returns a pair of things. The new one says that once you apply it, you are left with Tuple[Y, Any] | Tuple[Tuple[Y, Aux], Any]. In particular, if you have a value of a union type, a type checker can complain if you just assume that the value corresponds to a concrete case without checking. But the case is not arbitrary! It is always implied by the has_aux argument!
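To illustrate the second point, here is a small sketch of what a union-typed result forces on the caller. The names `Plain`, `WithAux`, `toy_value_and_grad`, and `doubled_value` are hypothetical and exist only for this example; the comments describe what a checker such as mypy is expected to report, not verified output.

```python
from typing import Callable, Tuple, Union

# Hypothetical aliases for the two result shapes under discussion.
Plain = Tuple[float, float]                # (value, grad)
WithAux = Tuple[Tuple[float, str], float]  # ((value, aux), grad)


def toy_value_and_grad(has_aux: bool) -> Callable[[float], Union[Plain, WithAux]]:
    """Toy stand-in for f(x) = x**2 whose result is annotated as a union."""

    def run(x: float) -> Union[Plain, WithAux]:
        if has_aux:
            return (x * x, "aux"), 2 * x
        return x * x, 2 * x

    return run


def doubled_value(result: Union[Plain, WithAux]) -> float:
    head = result[0]
    # head is seen as float | tuple[float, str], so `head * 2` on its own
    # would likely be flagged; the caller has to narrow the type first.
    if isinstance(head, tuple):
        return head[0] * 2
    return head * 2
```

The caller already knows which case applies from the flag it passed, yet the annotation alone cannot express that, which is why the check looks redundant.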

Perhaps a more interesting technique would be to create three overloads for value_and_grad: one with has_aux having type Literal[True], one of type Literal[False] and one of type bool. The first two overloads would return a more precise annotation of the return type, while the last one would have exactly the same generic signature we provide today. WDYT? This is exactly the use case from mypy docs.
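A rough sketch of the three overloads described above, following the Literal pattern from the mypy docs. The function body is a toy finite-difference stand-in rather than JAX's implementation, and the other parameters (argnums, holomorphic, and so on) are omitted for brevity.

```python
from typing import Any, Callable, Literal, Tuple, TypeVar, overload

Y = TypeVar("Y")
Aux = TypeVar("Aux")


@overload
def value_and_grad(
    fun: Callable[..., Y], has_aux: Literal[False] = ...
) -> Callable[..., Tuple[Y, Any]]: ...
@overload
def value_and_grad(
    fun: Callable[..., Tuple[Y, Aux]], has_aux: Literal[True]
) -> Callable[..., Tuple[Tuple[Y, Aux], Any]]: ...
@overload
def value_and_grad(
    fun: Callable[..., Any], has_aux: bool
) -> Callable[..., Tuple[Any, Any]]: ...
def value_and_grad(fun, has_aux=False):
    """Toy stand-in (not jax.value_and_grad): one float argument only."""
    eps = 1e-6

    def wrapped(x):
        out = fun(x)
        if has_aux:
            # fun returns (value, aux); differentiate the value only.
            grad = (fun(x + eps)[0] - fun(x - eps)[0]) / (2 * eps)
        else:
            grad = (fun(x + eps) - fun(x - eps)) / (2 * eps)
        return out, grad

    return wrapped


# A literal flag selects one of the two precise overloads.
value, grad = value_and_grad(lambda x: x * x)(3.0)
(value2, aux), grad2 = value_and_grad(lambda x: (x * x, "info"), has_aux=True)(3.0)
```

A call that passes a runtime bool falls through to the third overload and gets the same loose signature as today.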

NeilGirdhar commented 2 years ago

> Err towards simplicity: The new type signatures are much more complicated than the old ones

Yeah, I'm with you. I'm not convinced this is a big win for Jax. I was more proposing this for jax-opt, where the benefits are significantly greater.

I'm not sure it's "much more complicated", though. It just replaces one flag with a boxing class. Jax already does this with custom_vjp and custom_jvp, which are also classes that wrap a callable.

> In particular, if you have a value of a union type, a type checker can complain if you just assume that the value corresponds to a concrete case without checking

I don't think this is a problem. As I mentioned in the issue, you would just have overloads so that exactly one of the cases is chosen by the type checker. In fact, the overloads can be added regardless of the boxing class idea.

> Perhaps a more interesting technique would be to create three overloads for value_and_grad: one with has_aux having type Literal[True], one of type Literal[False] and one of type bool

Exactly, that should definitely be done regardless.

This was more of an issue with jax-opt because overloads are impossible there. When I proposed it there, they suggested it would be nice if their design matched Jax's, so I thought of proposing something here.

apaszke commented 2 years ago

I think that adding overloads for literals is a good idea. But I'm not 100% sold on the boxing class, since in most cases literal overloads should do the trick (how many people use a dynamic value of has_aux, after all?).