NeilGirdhar opened 2 years ago
I don't follow the purpose of the `HasAux` union element in the `fun` input annotation. Still, I think that this might actually create more friction than necessary. Going over the JEP for type annotation design, I see two points that might go against this proposal:
The proposed signature returns the union `Tuple[Y, Any] | Tuple[Tuple[Y, Aux], Any]`. In particular, if you have a value of a union type, a type checker can complain if you just assume that the value corresponds to a concrete case without checking. But the case is not arbitrary! It is always implied by the `has_aux` argument!

Perhaps a more interesting technique would be to create three overloads for `value_and_grad`: one with `has_aux` having type `Literal[True]`, one of type `Literal[False]`, and one of type `bool`. The first two overloads would return a more precise annotation of the return type, while the last one would have exactly the same generic signature we provide today. WDYT? This is exactly the use case from the mypy docs.
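A minimal sketch of the three-overload idea, following the `Literal[True]` / `Literal[False]` / `bool` fallback pattern from the mypy docs. `toy_value_and_grad` is a hypothetical stand-in, not JAX's `value_and_grad`: it only handles scalar `float -> float` functions and approximates the derivative with a central finite difference, so the typing can be shown without depending on JAX.

```python
from typing import Any, Callable, Literal, Tuple, TypeVar, overload

Aux = TypeVar("Aux")

EPS = 1e-6  # step for the toy central-difference derivative (not autodiff)


@overload
def toy_value_and_grad(
    fun: Callable[[float], Tuple[float, Aux]], has_aux: Literal[True]
) -> Callable[[float], Tuple[Tuple[float, Aux], float]]: ...
@overload
def toy_value_and_grad(
    fun: Callable[[float], float], has_aux: Literal[False] = ...
) -> Callable[[float], Tuple[float, float]]: ...
# Fallback for a plain (non-literal) bool, as in the mypy docs example.
@overload
def toy_value_and_grad(
    fun: Callable[[float], Any], has_aux: bool
) -> Callable[[float], Tuple[Any, float]]: ...
def toy_value_and_grad(
    fun: Callable[[float], Any], has_aux: bool = False
) -> Callable[[float], Tuple[Any, float]]:
    def primal(x: float) -> float:
        # Drop the aux data, if any, before differentiating.
        out = fun(x)
        return out[0] if has_aux else out

    def wrapped(x: float) -> Tuple[Any, float]:
        grad = (primal(x + EPS) - primal(x - EPS)) / (2 * EPS)
        return fun(x), grad

    return wrapped
```

With a literal `has_aux=True` a type checker selects the first overload, so the result is typed as `((float, Aux), float)` with no narrowing needed; only a runtime `bool` variable falls through to the generic third signature.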
> Err towards simplicity: The new type signatures are much more complicated than the old ones
Yeah, I'm with you. I'm not convinced this is a big win for Jax. I was more proposing this for jax-opt, where the benefits are significantly greater.
I'm not sure it's "much more complicated", though. It just replaces one flag with a boxing class. Jax already does this with `custom_vjp` and `custom_jvp`, which are also classes that wrap a callable.
> In particular, if you have a value of a union type, a type checker can complain if you just assume that the value corresponds to a concrete case without checking
I don't think this is a problem. As I mentioned in the issue, you would just have overloads so that exactly one of the cases is chosen by the type checker. In fact, the overloads can be added regardless of the boxing class idea.
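For contrast, here is the narrowing a caller would need if only the union return type existed. `primal_value` is a hypothetical helper, and `Y = float`, `Aux = str` are chosen purely for illustration; with per-literal overloads the checker already knows which shape it has, and this runtime check becomes unnecessary.

```python
from typing import Any, Tuple, Union

# The union result type from the discussion, with Y = float and Aux = str.
Result = Union[Tuple[float, Any], Tuple[Tuple[float, str], Any]]


def primal_value(result: Result) -> float:
    # Without overloads the checker only knows `result` is one of two
    # shapes, so the caller has to test at runtime which case it got,
    # even though has_aux already fixed that when the call was made.
    value = result[0]
    if isinstance(value, tuple):  # the ((Y, Aux), grad) case
        return value[0]
    return value  # the (Y, grad) case
```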
> Perhaps a more interesting technique would be to create three overloads for `value_and_grad`: one with `has_aux` having type `Literal[True]`, one of type `Literal[False]` and one of type `bool`
Exactly, that should definitely be done regardless.
This was more of an issue with jax-opt, because overloads are impossible there. When I proposed it there, they suggested it would be nice if their design matched Jax's, so I thought of proposing something here.
I think that adding overloads for literals is a good idea. But I'm not 100% sold on the boxing class, since in most cases literal overloads should do the trick (how many people use a dynamic value of `has_aux`, after all?).
Currently, we have
which leaves the return value as a mostly untyped `Callable`.

If instead we indicated has-aux with a box, we could do
This way, the output value of type `Y` and the auxiliary data of type `Aux` would be typed. `typing.overload` could be used to automatically select between the two return values. I proposed a similar idea to jax-opt, where I illustrated how such a box could be designed.