patrick-kidger / jaxtyping

Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays. https://docs.kidger.site/jaxtyping/

Random instances / Hypothesis-like generation #189

Open srush opened 7 months ago

srush commented 7 months ago

Was just curious whether anyone had built a way to use jaxtyping to generate random instances of the right shape, as specified by the constraints. Alternatively, if I wanted to build that, how might I hook into the constraint system?
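A minimal sketch of the core idea, assuming the dimension strings have already been pulled out of the annotations and contain only named dimensions and fixed integer sizes (no `*batch`, `...`, `#foo`, `_`, or arithmetic like `foo+1`). Each named dimension is sampled once, so arguments that share a name get matching sizes. `sample_shapes` is a hypothetical helper, not part of jaxtyping or Hypothesis.

```python
import random


def sample_shapes(dim_strs, min_size=1, max_size=5, seed=None):
    """Sample one concrete shape per dimension string (hypothetical helper).

    Named dimensions get a single random size shared by every string that
    mentions them; integer tokens are kept as fixed sizes. Variadic,
    broadcastable, anonymous, and symbolic dims are NOT handled.
    """
    rng = random.Random(seed)
    sizes = {}  # name -> sampled size, so "bar" means the same size everywhere
    shapes = []
    for dim_str in dim_strs:
        shape = []
        for token in dim_str.split():
            if token.isdigit():
                shape.append(int(token))  # fixed size, e.g. "3"
            else:
                if token not in sizes:
                    sizes[token] = rng.randint(min_size, max_size)
                shape.append(sizes[token])
        shapes.append(tuple(shape))
    return shapes


# Shapes for f(x: Float[Array, "foo bar"], y: Float[Array, "bar 3"])
x_shape, y_shape = sample_shapes(["foo bar", "bar 3"], seed=0)
print(x_shape, y_shape)
```

Sampling sizes per name rather than per argument is what makes the generated inputs satisfy the cross-argument constraints; filling each shape with random data is then a single call to whatever array library is in use.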

https://hypothesis.readthedocs.io/en/latest/numpy.html#array-api

patrick-kidger commented 7 months ago

I've definitely heard this idea discussed a couple of times, but I don't know that anyone has both done it and published it open-source.

At least with JAX it should be particularly easy, as one could then call jax.eval_shape to evaluate the function abstractly, performing all shape checks without running any actual computation or needing to know which values are legal to pass in.

srush commented 7 months ago

Okay, I might try putting something together. Any recommendations on how best to extract the constraints?

patrick-kidger commented 7 months ago

As in, given an x defined by x = Float[Array, "foo bar"], you're asking how to obtain the "foo bar" from x?

This is accessible as x.dim_str, although that's only semi-public API. You could also look at x.dims to get the parsed result, although that's again only semi-public, as the types it contains are not public.