Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
Typing improvement: Preserve the wrapped callable's signature in `jax.jit` #23719
Open
lebrice opened 2 days ago
Hello there. Simple feature request / bug report:
Currently, `jax.jit` drops the typing signature of the wrapped callable. For example, this code shows no warnings in a code editor:

The same goes for functions or methods decorated with a `functools.partial` of `jax.jit`: the signature of the wrapped callable is dropped:

Is there something I'm not aware of that might make this undesirable for some reason?
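A minimal, JAX-free sketch of the kind of problem described above. `untyped_jit` is a hypothetical stand-in, not JAX's API. It is annotated to return `Callable[..., Any]`, so a static checker such as mypy or pyright accepts any arguments to the decorated function. Note that `functools.wraps` still preserves the *runtime* signature via `__wrapped__`. Only the static type is lost.

```python
import functools
import inspect
from typing import Any, Callable


def untyped_jit(fun: Callable[..., Any]) -> Callable[..., Any]:
    """Hypothetical signature-erasing decorator (illustration only, not jax.jit)."""

    @functools.wraps(fun)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        # A real JIT would trace and compile here; this sketch just forwards the call.
        return fun(*args, **kwargs)

    return wrapper


@untyped_jit
def scale(x: float, factor: float) -> float:
    return x * factor


# Because the decorator returns `Callable[..., Any]`, a type checker accepts this
# call even though `x` is annotated as `float`. It even runs, since `"ab" * 2` is
# valid Python, so the mistake goes unnoticed.
result = scale("ab", 2)

# At runtime, `inspect.signature` follows `__wrapped__` and still reports the
# original parameters; the information is only missing for static analysis.
runtime_sig = inspect.signature(scale)
```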
In the meantime, I made #23720 to address this. Let me know what you think :)