brentyi / jax_dataclasses

Pytrees + dataclasses ❤️
MIT License

Add support for jaxtyping #6

Open lucagrementieri opened 1 year ago

lucagrementieri commented 1 year ago

This PR adds support for jaxtyping annotations preserving all the features and checks on tensor dimensions.

The PR doesn't update the README, since it could easily become messy. I'll wait for further guidance before updating it.

Close #5.

brentyi commented 1 year ago

Thanks! Looks reasonable overall, my main concern is the private jaxtyping imports. I assume there's no way to get around this?

lucagrementieri commented 1 year ago

_MetaAbstractArray is the base class of all jaxtyping types and annotations, so the check isinstance(type_hint, _MetaAbstractArray) is the best way to identify jaxtyping annotations. There are surely workarounds that avoid it, but they would be more fragile and less elegant.

For _NamedVariadicDim, I think there is no simple workaround because this class is required to support variadic dimensions, like the batch dimension.

brentyi commented 1 year ago

Okay, makes sense! It's definitely not ideal, but having support for jaxtyping here seems useful enough to warrant it. I like how we don't have to worry about import hooks or @jaxtyped when checking shapes or getting the batch axes.

(cc @patrick-kidger for any warnings, are there any plans to rework the internals of jaxtyping?)

I can handle the rest of the PR. Some TODOs would be:

patrick-kidger commented 1 year ago

There aren't any current plans. But taking a quick glance at your code, I think this will fail for a type hint of the form tuple[Float[Array, "foo"], ...], i.e. one in which the array is nested within another type hint? (I didn't check that closely though, so I might be wrong.)
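
To illustrate the nesting concern, here is a minimal sketch of walking a type hint recursively with the standard library's typing.get_args. FakeArrayHint is a hypothetical stand-in for a jaxtyping annotation such as Float[Array, "foo"], and find_array_hints is an illustrative helper; neither exists in jaxtyping or jax_dataclasses.

```python
from typing import Any, Dict, List, Tuple, get_args


class FakeArrayHint:
    """Hypothetical stand-in for an annotation such as Float[Array, "foo"]."""


def find_array_hints(type_hint: Any) -> List[type]:
    """Collect array-like leaf hints, descending into generic containers."""
    args = get_args(type_hint)
    if args:
        # A parameterized container such as Tuple[X, ...]: look inside it.
        return [leaf for arg in args for leaf in find_array_hints(arg)]
    # A leaf: keep it only if it is a class derived from the stand-in.
    if isinstance(type_hint, type) and issubclass(type_hint, FakeArrayHint):
        return [type_hint]
    return []


# A check applied only to the top-level hint would miss both nested cases.
flat = find_array_hints(FakeArrayHint)
nested = find_array_hints(Tuple[FakeArrayHint, ...])
deep = find_array_hints(Dict[str, Tuple[FakeArrayHint, int]])
```

A runtime type checker such as typeguard or beartype performs this kind of traversal itself, which is the reason for the recommendation that follows.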

Anyway, jaxtyping hints are expected to be validated using a runtime type checker, such as typeguard or beartype. I'd recommend that you simply do the same thing, as they'll handle the details for you: both the nesting above, and avoiding the need to access private jaxtyping functionality.

Side note: if you're working on a project like this then you may find Equinox interesting. In particular equinox.Module is also a dataclass-pytree combo, with most (all?) of the expected bells-and-whistles: serialisation, immutability etc. (I suppose I've not really tested static type checking, as I'm mostly a non-user of that.)

I like the neat syntax of your copy_and_mutate, by the way. (Equinox's equivalent is equinox.tree_at, which is safer but a bit harder to use.)

brentyi commented 1 year ago

Thanks!

I've also been following Equinox; definitely the "how to build pytrees" + tooling compatibility landscapes have improved quite a bit since I started jax_dataclasses in ~late 2020. For now I think mypy compatibility + the Static[] API + copy_and_mutate are still nice enough for me to keep the library around, but sooner or later I should revisit whether the library still makes sense given developments in equinox, flax, etc (especially if a copy_and_mutate-style API is merged into flax https://github.com/google/flax/pull/2735).

I also agree that typeguard or beartype makes sense for asserting that the shapes are correct, but for this PR the main purpose is to support jaxtyping annotations for the dataclass.get_batch_axes() helper that currently works for jax_dataclasses-proprietary shape annotations.

For this we need to figure out which axes in the array shapes correspond to the variadic dimension, which leaves the options of: (a) touching the private bits of jaxtyping, (b) trying to convince @patrick-kidger to expose a public API for reasoning about jaxtyping types*, or (c) not implementing this functionality.

*maybe something like: (jaxtyping type, array) -> labels for each axis in the array. Any chance you're open to something like this? (understand if not)
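
A rough sketch of what such a mapping could look like, written in plain Python. The function name label_axes and its signature are assumptions made for illustration, not an existing or proposed jaxtyping API, and it handles only a simplified subset of the shape syntax: space-separated names with at most one variadic name prefixed by *.

```python
from typing import List, Sequence


def label_axes(dim_spec: str, shape: Sequence[int]) -> List[str]:
    """Assign a label from a spec like "*batch height width" to each axis."""
    names = dim_spec.split()
    variadic = [name for name in names if name.startswith("*")]
    if len(variadic) > 1:
        raise ValueError("at most one variadic dimension is supported")

    # Axes left over after the fixed names are absorbed by the variadic one.
    num_fixed = len(names) - len(variadic)
    num_extra = len(shape) - num_fixed
    if num_extra < 0 or (not variadic and num_extra != 0):
        raise ValueError(f"shape {tuple(shape)} does not match {dim_spec!r}")

    labels: List[str] = []
    for name in names:
        if name.startswith("*"):
            labels.extend([name[1:]] * num_extra)
        else:
            labels.append(name)
    return labels


labels = label_axes("*batch height width", (8, 4, 32, 32))
# The batch axes are the positions that received the variadic label.
batch_axes = [i for i, label in enumerate(labels) if label == "batch"]
```

With per-axis labels like these, a helper along the lines of get_batch_axes() would only need to read off the positions carrying the variadic name.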

patrick-kidger commented 1 year ago

You should be able to replace isinstance(type_hint, _MetaAbstractArray) with issubclass(type_hint, AbstractArray). (Which is public API.)
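
One practical difference between the two checks is that issubclass raises TypeError when its first argument is not a class, whereas isinstance simply returns False. The sketch below shows this with stand-in classes only: _MetaStandIn, AbstractStandIn, FloatStandIn, and is_array_hint are hypothetical names that mimic the relationship between a private metaclass and a public base class, and they are not jaxtyping objects.

```python
from typing import Tuple


class _MetaStandIn(type):
    """Hypothetical stand-in for a private metaclass."""


class AbstractStandIn(metaclass=_MetaStandIn):
    """Hypothetical stand-in for a public base class."""


class FloatStandIn(AbstractStandIn):
    """Hypothetical stand-in for a concrete annotation class."""


def is_array_hint(type_hint: object) -> bool:
    """Check against the public base class, tolerating non-class hints."""
    try:
        return issubclass(type_hint, AbstractStandIn)
    except TypeError:
        # issubclass() rejects non-class arguments such as Tuple[int, ...].
        return False


# For classes, the metaclass check and the base-class check agree.
agrees_on_hint = is_array_hint(FloatStandIn) == isinstance(FloatStandIn, _MetaStandIn)
agrees_on_int = is_array_hint(int) == isinstance(int, _MetaStandIn)
```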

At that point I can see that you'd want to modify its dimensions. I think the best way to do this would be to submit a PR against jaxtyping that records cls and item here:

https://github.com/google/jaxtyping/blob/59e8fb0d18325f990a9d59ee35e90c04b699cab8/jaxtyping/array_types.py#L400

so that you can then look these up, modify these as desired, and then recreate the type hint through the public jaxtyping API (e.g. cls[item] to recreate the same hint).