stanford-crfm / haliax

Named Tensors for Legible Deep Learning in JAX
Apache License 2.0

Support Rearrange Without Trailing Dimensions #52

Closed karan-dalal closed 5 months ago

karan-dalal commented 6 months ago

Currently, hax.rearrange requires specifying every dimension in order to change any of them. For example:

x = [Batch, Sequence, Embedding]
x = hax.rearrange(x, "Batch Seq Embed -> Batch . Chunk Embed")

Instead, it would be great if we didn't have to include dimensions irrelevant to the rearrangement:

x = hax.rearrange(x, "Seq -> . Chunk")

Thanks!

dlwh commented 6 months ago

Probably the way I'd like to support this is with the syntax "{ (Seq: NChunk Chunk) } -> {NChunk Chunk}", which uses the "unordered" syntax to say that we're grabbing only some dims and only partially specifying dims on the RHS. The semantics in general are probably something like: all dims that come before (resp. after) all of the selected dims remain before (resp. after) them. Selected dims are made contiguous, and the other dims (i.e. those that occur after some selected dim and before another) can end up either before or after the selected block.

Kind of a lot, but I think that's the best guarantee
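
To make the proposed ordering guarantee concrete, here is a minimal pure-Python sketch that only computes the resulting axis order from axis names. It is not haliax code: the helper name `partial_axis_order` and its arguments are hypothetical, and placing the "in-between" dims before the selected block is just one of the two placements the rule above allows.

```python
def partial_axis_order(axes, selected, replacement):
    """Hypothetical helper: output axis order for a partial rearrange.

    axes:        axis names of the input, in order
    selected:    names of the axes mentioned on the LHS
    replacement: contiguous axis names that replace the selected axes
    """
    missing = [name for name in selected if name not in axes]
    if missing:
        raise ValueError(f"axes not found in input: {missing}")

    positions = [axes.index(name) for name in selected]
    first, last = min(positions), max(positions)

    # Dims before every selected dim stay before; dims after every
    # selected dim stay after.
    before = list(axes[:first])
    after = list(axes[last + 1:])

    # Unselected dims caught between two selected dims may go on either
    # side; this sketch (an assumption) puts them before the block.
    between = [a for a in axes[first:last + 1] if a not in selected]

    return before + between + list(replacement) + after


# Splitting Seq into NChunk and Chunk leaves Batch and Embed in place.
print(partial_axis_order(["Batch", "Seq", "Embed"], ["Seq"], ["NChunk", "Chunk"]))
```

With non-adjacent selected dims, e.g. selecting H and W out of A, H, B, W, C, the sketch keeps A first and C last, makes the selected dims contiguous, and moves B in front of them.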

dlwh commented 5 months ago

I have syntax for this in dev. Basically, we now allow multiple ellipses in the RHS, and it more or less does what you expect.