patrick-kidger / jaxtyping

Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays. https://docs.kidger.site/jaxtyping/

einops-like packing notation #180

Open MilesCranmer opened 9 months ago

MilesCranmer commented 9 months ago

Hey @patrick-kidger,

I'm wondering how hard it would be to support einops-like notation for packed axes. For example,

Float[Array, "B C (H W)"]

would indicate that the last axis is a flattened version of the height and width axes.

This means that if you have the full signature as:

def unpack(x: Float[Array, "B C (H W)"]) -> Float[Array, "B C H W"]:
    ... # magic unpacking
    return y

then jaxtyping would check that y.shape[2] * y.shape[3] == x.shape[2].

Note that in many cases it would not be able to confirm H and W individually. I think that is okay; they're just free variables. But if it can confirm the individual shapes, then it can do the type check.
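To make the proposal concrete, here is a minimal sketch of the check described above, written against plain shape tuples instead of real arrays. `check_unpack_shapes` is a hypothetical helper for illustration, not part of jaxtyping; it treats H and W as free variables that only get pinned down once the output is known.

```python
# Hypothetical sketch of the proposed check, on shape tuples (not jaxtyping code).

def check_unpack_shapes(x_shape: tuple[int, ...], y_shape: tuple[int, ...]) -> bool:
    """Check x: "B C (H W)" against y: "B C H W".

    H and W are free variables: nothing about them is known from x alone,
    but once y is available the packed axis must equal their product.
    """
    if len(x_shape) != 3 or len(y_shape) != 4:
        return False
    b, c, hw = x_shape
    yb, yc, h, w = y_shape
    # B and C must agree between input and output.
    if (b, c) != (yb, yc):
        return False
    # The proposed constraint: y.shape[2] * y.shape[3] == x.shape[2].
    return h * w == hw


print(check_unpack_shapes((2, 3, 12), (2, 3, 3, 4)))  # True
print(check_unpack_shapes((2, 3, 12), (2, 3, 5, 4)))  # False
```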

What do you think? Does this make sense?

patrick-kidger commented 9 months ago

What you can do today

So it's a little less elegant, but you can do this today via

def unpack(x: Float[Array, "B C H_W"]) -> Float[Array, "B C H (H_W//H)"]:
    ... # magic unpacking
    return y

where the parentheses around H_W//H are optional and only there for readability. And this is already runtime checkable!
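A rough model of why this works, assuming (as jaxtyping's docs describe for symbolic dimensions) that plain names bind to sizes and any other entry is a Python expression evaluated against names already bound. `check_dims` is a hypothetical, simplified stand-in, not jaxtyping's actual implementation.

```python
# Simplified, hypothetical model of checking "B C H_W" -> "B C H (H_W//H)".
# This only approximates jaxtyping's behaviour.

def check_dims(spec: str, shape: tuple[int, ...], memo: dict[str, int]) -> bool:
    dims = spec.split()
    if len(dims) != len(shape):
        return False
    for dim, size in zip(dims, shape):
        if dim.isidentifier():
            # Named dimension: bind it on first sight, otherwise compare.
            if memo.setdefault(dim, size) != size:
                return False
        else:
            # Symbolic dimension: evaluate with the names bound so far.
            # (eval is fine for a sketch; real code should be more careful.)
            if eval(dim, {}, dict(memo)) != size:
                return False
    return True


memo: dict[str, int] = {}
print(check_dims("B C H_W", (2, 3, 12), memo))          # True; binds B, C, H_W
print(check_dims("B C H (H_W//H)", (2, 3, 3, 4), memo))  # True; 12 // 3 == 4
```

Because `H` appears before `(H_W//H)` in the output annotation, it is already bound when the expression is evaluated.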

The reverse direction is a bit neater, as the logic (which should typically go on the RHS) is easier to read:

def pack(x: Float[Array, "B C H W"]) -> Float[Array, "B C H*W"]:
    ... # magic packing
    return y

Note that this works because in each case the output shape is a function of the input shape -- rather than the other way around!
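One way to see the asymmetry, using hypothetical helpers on shape tuples: packing maps each input shape to exactly one output shape, while a packed size admits several (H, W) factorisations, so the unpacked shape cannot be recovered from the input alone.

```python
# Hypothetical helpers illustrating why the output should be a function of the input.

def pack_shape(shape: tuple[int, int, int, int]) -> tuple[int, int, int]:
    b, c, h, w = shape
    return (b, c, h * w)  # "B C H W" -> "B C H*W": fully determined


def possible_unpack_shapes(shape: tuple[int, int, int]) -> list[tuple[int, int, int, int]]:
    b, c, hw = shape
    # Every factorisation H * W == HW is consistent with "B C H*W".
    return [(b, c, h, hw // h) for h in range(1, hw + 1) if hw % h == 0]


print(pack_shape((2, 3, 3, 4)))                 # (2, 3, 12)
print(len(possible_unpack_shapes((2, 3, 12))))  # 6
```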

What we could do tomorrow

First of all, just on the syntax: I think if we were to support something like this, then I'd probably suggest using the syntax H*W rather than (H W). This is because it's the syntax we already have!

And in fact, we can actually already write this:

def unpack(x: Float[Array, "B C H*W"]) -> Float[Array, "B C H W"]:
    ... # magic unpacking
    return y

but this would raise an error if you were to do runtime type-checking, as it won't have seen H and W when you first call the function.
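The failure mode can be mimicked with plain `eval`: when only the argument has been checked, the expression `H*W` refers to names that have not been bound yet. `eval_symbolic` is a hypothetical helper for illustration, not jaxtyping's code.

```python
from typing import Optional


def eval_symbolic(expr: str, bound: dict[str, int]) -> Optional[int]:
    """Evaluate a symbolic dimension, or return None if a name is unbound."""
    try:
        return eval(expr, {}, dict(bound))
    except NameError:
        return None  # some dimension in the expression has not been seen yet


# Checking the argument x: only B and C are bound.
print(eval_symbolic("H*W", {"B": 2, "C": 3}))                  # None
# Once the output has bound H and W, the same expression resolves.
print(eval_symbolic("H*W", {"B": 2, "C": 3, "H": 3, "W": 4}))  # 12
```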

So if we were to change anything, I think it would probably be to (a) allow such "incomplete" annotations when checking the arguments, and then to (b) go back and check them all again after the function has finished running and we have its output.
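A minimal sketch of the two steps (a) and (b), on shape tuples. `DeferredChecker` and its methods are hypothetical names, not jaxtyping internals; the sketch only assumes that unresolved symbolic dimensions can be queued while checking arguments and re-evaluated once the return value has been checked.

```python
# Hypothetical two-pass checker: (a) defer symbolic dimensions that mention
# unbound names, (b) re-check them after the output has been seen.

class DeferredChecker:
    def __init__(self) -> None:
        self.memo: dict[str, int] = {}
        self.pending: list[tuple[str, int]] = []  # (expression, observed size)

    def check(self, spec: str, shape: tuple[int, ...]) -> None:
        dims = spec.split()
        if len(dims) != len(shape):
            raise TypeError(f"expected {len(dims)} axes, got {len(shape)}")
        for dim, size in zip(dims, shape):
            if dim.isidentifier():
                # Named dimension: bind on first sight, otherwise compare.
                if self.memo.setdefault(dim, size) != size:
                    raise TypeError(f"{dim}={self.memo[dim]} but got {size}")
            else:
                self._check_expr(dim, size, defer=True)

    def _check_expr(self, expr: str, size: int, defer: bool) -> None:
        try:
            value = eval(expr, {}, dict(self.memo))
        except NameError:
            if defer:
                # Step (a): the annotation is incomplete, so try again later.
                self.pending.append((expr, size))
                return
            raise TypeError(f"could not resolve {expr!r}")
        if value != size:
            raise TypeError(f"{expr}={value} but got {size}")

    def finish(self) -> None:
        # Step (b): all shapes are known now, so nothing may stay unresolved.
        pending, self.pending = self.pending, []
        for expr, size in pending:
            self._check_expr(expr, size, defer=False)


checker = DeferredChecker()
checker.check("B C H*W", (2, 3, 12))    # argument: H*W is deferred
checker.check("B C H W", (2, 3, 3, 4))  # return value binds H and W
checker.finish()                        # re-checks H*W == 12, which passes
```

A mismatched output, for example `(2, 3, 5, 4)`, would make `finish()` raise `TypeError` because `H*W` evaluates to 20 rather than 12.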

WDYT?

MilesCranmer commented 9 months ago

Cool! Thanks, it's great that this is already possible! I think this alone will be useful enough for me. The second option is perhaps a bit nicer if it is not too difficult to add, i.e., writing H*W -> H W is perhaps marginally better than H_W -> H (H_W//H), but if it increases code complexity too much, maybe it's not worth it?

patrick-kidger commented 9 months ago

I'd be happy to add the second notation, I'd just have to ask for a PR on it as I don't have the time to implement this myself :D

If you or anyone else feels strongly about this, then I'd be happy to explain how to tweak the jaxtyping internals to accomplish this.

MilesCranmer commented 9 months ago

For posterity: I'm also short on time at the moment due to teaching. (Anybody reading this thread, feel free to take a stab at this!)