nahaharo closed this issue 4 weeks ago.
Hello.
I'm currently using jaxtyping with mypy.
When I run the test code below, mypy reports that `jaxtyped` is untyped.
Here are my test code and mypy.ini.
I think the solution for this issue is to add type annotations to `jaxtyped`.
```python
from datetime import datetime
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jaxtyping import Array, Float, jaxtyped
from typeguard import typechecked as typechecker


@jaxtyped(typechecker=typechecker)
@partial(jax.jit, static_argnames=["scale"])
def test_linalg(
    a: Float[Array, "m n"], b: Float[Array, "n k"], scale: float
) -> Float[Array, "m k-1"]:
    return scale * (a @ b)[:, :-1]


if __name__ == "__main__":
    a = jnp.array(np.random.randn(5, 5))
    b = jnp.array(np.random.randn(5, 6))

    # First call includes JIT compilation. block_until_ready() waits for
    # JAX's asynchronous dispatch, so the timing covers the real computation.
    start_time = datetime.now()
    c = test_linalg(a, b, 2).block_until_ready()
    end_time = datetime.now()
    print('Duration: {}'.format(end_time - start_time))

    # Second call reuses the compiled function.
    start_time = datetime.now()
    d = test_linalg(a, b, 2).block_until_ready()
    end_time = datetime.now()
    print('Duration: {}'.format(end_time - start_time))
```
```ini
[mypy]
python_version = 3.10
plugins = numpy.typing.mypy_plugin
cache_dir = .mypy_cache/strict
allow_redefinition = True
strict_optional = True
show_error_codes = True
show_column_numbers = True
warn_no_return = True
disallow_any_unimported = True
strict = True
implicit_reexport = False
warn_unused_ignores = False
```
After reviewing my code, I concluded that this is not something this library needs to handle.