nimashoghi closed this 1 month ago
This looks reasonable to me! I'd be happy to merge this as a non-draft PR.
@patrick-kidger Thanks! I was holding off on this so that I could add a simple test, which I've just pushed.
I was hoping to add one more test for inline (function-body) annotations, but it seems that jaxtyping
doesn't catch these kinds of errors (though I'm not sure about this). I wasn't able to get a test like the one below working:
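import jax.numpy as jnp
import pytest
from jaxtyping import Float32

# ParamError: the exception type the test suite expects for a bad argument
# (assumed to come from the tests' helper module; not shown here).

# Test that body annotations (but no arg/return annotations) are checked
def body_annot_test(x):
    y: Float32[jnp.ndarray, " b"] = x
    _ = y

body_annot_test(jnp.array([1.0]))
with pytest.raises(ParamError):
    body_annot_test(jnp.array(1))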
Aside from that, this should be good to go.
This looks great to me! Thank you for the fix, and merged :)
As for inline annotations, see https://github.com/patrick-kidger/jaxtyping/issues/153#issuecomment-2028020335 and https://github.com/patrick-kidger/jaxtyping/issues/92#issuecomment-2030070762!
…omatically adding @jaxtyped decorator.