jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Typing improvement: Preserve the wrapped callable's signature in `jax.jit` #23719

Open · lebrice opened this issue 2 days ago

lebrice commented 2 days ago

Hello there. Simple feature request / bug report:

Currently, jax.jit drops the type signature of the wrapped callable. For example, the following code produces no warning in a code editor, even though foo has no parameter named bob:

import jax

@jax.jit
def foo(a: jax.Array) -> jax.Array:
    return a

foo(bob=123)  # type-checker should display a warning!

The same goes for functions or methods decorated with a functools.partial of jax.jit: the signature of the wrapped callable is dropped:

import jax
import functools
from typing import Any

@functools.partial(jax.jit, static_argnames=["some_static_arg"])
def foo_with_static_arg(a: jax.Array, some_static_arg: Any) -> jax.Array:
    return a

foo_with_static_arg(bob=123)  # type-checker should also display a warning here!

Is there a reason I'm not aware of that would make this undesirable?

In the meantime, I made #23720 to address this. Let me know what you think :)

jakevdp commented 2 days ago

Crossref #14688, which was a previous attempt to solve this.