Open MilesCranmer opened 9 months ago
So it's a little less elegant, but you can do this today via
from jaxtyping import Array, Float

def unpack(x: Float[Array, "B C H_W"]) -> Float[Array, "B C H (H_W//H)"]:
    ...  # magic unpacking
    return y
where the ( ) brackets are just optional for readability. And this is already runtime checkable!
The reverse direction is a bit neater, as the logic (which should typically go on the RHS) is easier to read:
def pack(x: Float[Array, "B C H W"]) -> Float[Array, "B C H*W"]:
    ...  # magic packing
    return y
Note that this works because in each case the output shape is a function of the input shape -- rather than the other way around!
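To make the ordering point concrete, here is a toy, standard-library-only sketch of left-to-right shape checking. It is not jaxtyping's actual implementation: `check_shape` and `memo` are hypothetical names, and the only assumption is that plain axis names bind the first time they are seen, while expressions are evaluated against whatever has been bound so far.

```python
import re


def check_shape(spec: str, shape: tuple, memo: dict) -> None:
    """Toy left-to-right check of one shape against a space-separated spec.

    Plain names bind on first sight; anything else is treated as an
    expression and evaluated against the names bound so far.
    """
    dims = spec.split()
    if len(dims) != len(shape):
        raise ValueError(f"rank mismatch: {spec!r} vs {shape}")
    for dim, size in zip(dims, shape):
        if re.fullmatch(r"[A-Za-z_]\w*", dim):
            # A plain axis name: bind it, or compare with the earlier binding.
            if memo.setdefault(dim, size) != size:
                raise ValueError(f"{dim}={memo[dim]} but got {size}")
        else:
            # A symbolic axis such as H*W or (H_W//H).
            try:
                expected = eval(dim, {"__builtins__": {}}, dict(memo))
            except NameError as err:
                raise ValueError(f"cannot evaluate {dim!r} yet: {err}") from None
            if expected != size:
                raise ValueError(f"{dim} should be {expected} but got {size}")


# Input first, then output: every expression only uses names already bound.
memo = {}
check_shape("B C H_W", (2, 3, 12), memo)            # binds B, C, H_W
check_shape("B C H (H_W//H)", (2, 3, 4, 3), memo)   # binds H, then checks 12 // 4 == 3

# The other way around fails on the input, since H and W are not known yet.
try:
    check_shape("B C H*W", (2, 3, 12), {})
except ValueError as err:
    print(err)
```

In this toy model the expression has to sit on the side that is checked last, which is why the output annotation is the natural home for it.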
First of all, just on the syntax: I think if we were to support something like this, then I'd probably suggest using the syntax H*W rather than (H W). This is because it's the syntax we already have!
And in fact, we can actually already write this:
def unpack(x: Float[Array, "B C H*W"]) -> Float[Array, "B C H W"]:
    ...  # magic unpacking
    return y
but this would raise an error if you were to do runtime type-checking, as it won't have seen H and W when you first call the function.
So if we were to change anything, I think it would probably be to (a) allow such "incomplete" annotations when checking the arguments, and then to (b) go back and check them all again after the function has finished running and we have its output.
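A rough sketch of what (a) and (b) could look like, written as a toy over nested lists rather than real arrays. `checked`, `first_pass`, `second_pass` and `shape_of` are hypothetical helpers made up for illustration; jaxtyping's internals are not shown here and may well need a different approach.

```python
import re

NAME = re.compile(r"[A-Za-z_]\w*")


def shape_of(nested):
    """Shape of a rectangular nested list, e.g. [[1, 2, 3]] -> (1, 3)."""
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        nested = nested[0] if nested else None
    return tuple(shape)


def first_pass(spec, shape, memo, deferred):
    """(a) Bind plain names now; postpone expressions such as H*W."""
    dims = spec.split()
    if len(dims) != len(shape):
        raise ValueError(f"rank mismatch: {spec!r} vs {shape}")
    for dim, size in zip(dims, shape):
        if NAME.fullmatch(dim):
            if memo.setdefault(dim, size) != size:
                raise ValueError(f"{dim}={memo[dim]} but got {size}")
        else:
            deferred.append((dim, size))


def second_pass(memo, deferred):
    """(b) Re-check the postponed expressions once the output has been seen."""
    for expr, size in deferred:
        try:
            expected = eval(expr, {"__builtins__": {}}, dict(memo))
        except NameError as err:
            raise ValueError(f"{expr!r} still has unbound names: {err}") from None
        if expected != size:
            raise ValueError(f"{expr} should be {expected} but got {size}")


def checked(arg_spec, ret_spec):
    """Decorator: check the first argument and the return value in two passes."""
    def decorator(fn):
        def wrapper(x, *args):
            memo, deferred = {}, []
            first_pass(arg_spec, shape_of(x), memo, deferred)
            y = fn(x, *args)
            first_pass(ret_spec, shape_of(y), memo, deferred)
            second_pass(memo, deferred)
            return y
        return wrapper
    return decorator


@checked("B H*W", "B H W")
def unpack(x, width):
    # Split each flat row into chunks of length `width`.
    return [[row[i:i + width] for i in range(0, len(row), width)] for row in x]


@checked("B H*W", "B H W")
def broken_unpack(x, width):
    # Bug on purpose: keeps only the first chunk of each row.
    return [[row[:width]] for row in x]


x = [[0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11]]   # shape (2, 6)
y = unpack(x, 3)                                  # shape (2, 2, 3)

try:
    broken_unpack(x, 3)
except ValueError as err:
    print(err)
```

The trade-off in this sketch is that a bad input is only reported after the function body has run, since H*W cannot be resolved any earlier.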
WDYT?
Cool! Thanks, it's great that this is already possible! I think this will already be useful enough for me. The second option is perhaps a bit nicer if it is not too difficult to add, i.e., doing H*W -> H W is perhaps a fractional amount better than H_W -> H (H_W//H), but if it increases code complexity too much, maybe it's not worth it?
I'd be happy to add the second notation, I'd just have to ask for a PR on it as I don't have the time to implement this myself :D
If you or anyone else feels strongly about this, then I'd be happy to explain how to tweak the jaxtyping internals to accomplish this.
For posterity I'm also time deficient at the moment due to teaching. (Anybody reading this thread; feel free to take a stab at this!)
Hey @patrick-kidger,
I'm wondering how hard it would be to have einops-like notation for packed axes? For example,
would indicate that the last axis is a flattened version of the height and width axis.
This means that if you have the full signature as:
then jaxtyping would check that
y.shape[2] * y.shape[3] == x.shape[2].
Note that in many cases it would not be able to confirm H and W individually. I think that is okay; it's just free variables. But if it can confirm the individual shapes, then it can do the type check.
What do you think? Does this make sense?
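As a small illustration of the "free variables" point, here is a standard-library sketch (`candidate_hw` is a hypothetical helper, not part of jaxtyping) that lists every (H, W) pair consistent with a packed axis of a given size. A packed axis on its own only pins down the product, so the individual sizes stay ambiguous until some other annotation fixes one of them.

```python
def candidate_hw(packed: int) -> list:
    """All (H, W) pairs of positive integers with H * W == packed."""
    return [(h, packed // h) for h in range(1, packed + 1) if packed % h == 0]


# A packed axis of size 12 is consistent with several unpacked shapes.
pairs = candidate_hw(12)
print(pairs)  # [(1, 12), (2, 6), (3, 4), (4, 3), (6, 2), (12, 1)]

# Once another annotation has bound H (say H == 3), only one pair survives.
narrowed = [(h, w) for h, w in pairs if h == 3]
print(narrowed)  # [(3, 4)]
```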