patrick-kidger / jaxtyping

Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays. https://docs.kidger.site/jaxtyping/

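Remove `has_annotated_args`/`has_annotated_return` check for when automatically adding @jaxtyped decorator. #205

Closed nimashoghi closed 1 month ago

nimashoghi commented 1 month ago

Removes the `has_annotated_args`/`has_annotated_return` check for when automatically adding the @jaxtyped decorator.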
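To make the change concrete, here is a stdlib-only sketch of the selection rule the title describes, applied to a small source string. The helpers `has_annotated_args` and `has_annotated_return` below are hypothetical re-implementations named after the checks in the title, and `functions_to_decorate` is a made-up stand-in for the automatic-decoration logic. jaxtyping's actual code differs.

```python
import ast

SOURCE = '''
def annotated(x: int) -> int:
    return x

def body_only(x):
    y: int = x
    return y
'''


def has_annotated_args(fn: ast.FunctionDef) -> bool:
    # Hypothetical sketch: True if any parameter carries an annotation.
    args = fn.args
    params = args.posonlyargs + args.args + args.kwonlyargs
    if args.vararg is not None:
        params.append(args.vararg)
    if args.kwarg is not None:
        params.append(args.kwarg)
    return any(p.annotation is not None for p in params)


def has_annotated_return(fn: ast.FunctionDef) -> bool:
    # Hypothetical sketch: True if the function has a `-> ...` annotation.
    return fn.returns is not None


def functions_to_decorate(source: str, check_annotations: bool) -> list[str]:
    # Collect the names of functions that would receive a decorator.
    names = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            annotated = has_annotated_args(node) or has_annotated_return(node)
            if check_annotations and not annotated:
                continue  # old rule: skip functions with no arg/return annotations
            names.append(node.name)
    return names


print(functions_to_decorate(SOURCE, check_annotations=True))   # ['annotated']
print(functions_to_decorate(SOURCE, check_annotations=False))  # ['annotated', 'body_only']
```

With the check in place, a function whose only annotations appear in its body is skipped. Without the check, every function is selected.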

patrick-kidger commented 1 month ago

This looks reasonable to me! I'd be happy to merge this as a non-draft PR.

nimashoghi commented 1 month ago

@patrick-kidger Thanks! I was holding off on this so I could add a simple test, which I've just pushed.

I was hoping to add one more test for inline (function-body) annotations, but it seems jaxtyping doesn't catch these kinds of errors (I'm not sure about this). I wasn't able to get a test like the one below working:

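```python
import jax.numpy as jnp
import pytest
from jaxtyping import Float32

from .helpers import ParamError  # jaxtyping test-suite helper

#
# Test that body annotations (but no arg/return annotations) are checked

def body_annot_test(x):
    y: Float32[jnp.ndarray, " b"] = x
    _ = y

body_annot_test(jnp.array([1.0]))  # float32, shape (1,): should pass
with pytest.raises(ParamError):
    body_annot_test(jnp.array(1))  # int32 scalar: should fail, but isn't caught
```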
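For context, here is a plain-Python sketch (no jaxtyping involved) of one likely reason such a test can't pass. Under PEP 526, annotations on local variables inside a function body are never evaluated at runtime and are not stored in the function's `__annotations__`. A decorator that inspects annotations therefore never sees them. The function name below is hypothetical.

```python
def local_annotation_demo(x):
    # `UndefinedName` does not exist, yet calling this function does not
    # raise NameError: annotations on local variables are never evaluated.
    y: UndefinedName = x  # noqa: F821
    return y


print(local_annotation_demo(1))             # 1
print(local_annotation_demo.__annotations__)  # {} (no parameter/return annotations)
```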

Aside from this, though, this should be good to go.

patrick-kidger commented 1 month ago

This looks great to me! Thank you for the fix, and merged :)

As for inline annotations, see https://github.com/patrick-kidger/jaxtyping/issues/153#issuecomment-2028020335 and https://github.com/patrick-kidger/jaxtyping/issues/92#issuecomment-2030070762!