stanford-crfm / haliax

Named Tensors for Legible Deep Learning in JAX
Apache License 2.0

Add einops-style `rearrange` #19

Closed: dlwh closed this 9 months ago

dlwh commented 11 months ago

`einops.rearrange` allows things like:

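import einops

# einops.rearrange takes the tensor first, then the pattern
q = einops.rearrange(q, "b h t d -> b t (h d)")
# undo that (splitting needs the size of h):
q = einops.rearrange(q, "b t (h d) -> b h t d", h=num_heads)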

This combines transposition with flattening/unflattening (merging or splitting axes) in a single call.
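For reference, here is a minimal sketch of what those two einops calls do, written with plain `jax.numpy` transposes and reshapes. The function names `merge_heads` and `split_heads` and the concrete sizes are illustrative only. In `(h d)`, `h` is the outer (slower-varying) factor, so element `[b, h_i, t, d_i]` lands at flat position `h_i * d + d_i` of the merged axis.

```python
import jax.numpy as jnp

def merge_heads(q):
    """Same result as einops.rearrange(q, "b h t d -> b t (h d)")."""
    b, h, t, d = q.shape
    # transpose: (b, h, t, d) -> (b, t, h, d) so h and d are adjacent
    q = jnp.transpose(q, (0, 2, 1, 3))
    # flatten: row-major reshape merges (h, d) with h as the outer factor
    return q.reshape(b, t, h * d)

def split_heads(q, num_heads):
    """Same result as einops.rearrange(q, "b t (h d) -> b h t d", h=num_heads)."""
    b, t, hd = q.shape
    assert hd % num_heads == 0, "merged axis must be divisible by num_heads"
    d = hd // num_heads
    # unflatten: split the merged axis back into (h, d)
    q = q.reshape(b, t, num_heads, d)
    # transpose back: (b, t, h, d) -> (b, h, t, d)
    return jnp.transpose(q, (0, 2, 1, 3))

# b=2, h=4, t=3, d=5
q = jnp.arange(2 * 4 * 3 * 5, dtype=jnp.float32).reshape(2, 4, 3, 5)
merged = merge_heads(q)
print(merged.shape)  # (2, 3, 20)
restored = split_heads(merged, num_heads=4)
print(bool(jnp.array_equal(restored, q)))  # True
```

Note that positional code like this has to track axis positions by hand (`(0, 2, 1, 3)`), which is exactly the bookkeeping a named-axis `rearrange` would remove.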

It would be nice to support this in Haliax.

I think we need something like:

q = q.rearrange("... pos embed", {("head", "key"): "embed"})
# or
q = q.rearrange("... pos embed", embed=("head", "key"))
# undo
q = q.rearrange("... head pos key", {"embed": (Head, Key)})
# or
q = q.rearrange("... head pos key", embed=(Head, Key))

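where the first form means "merge `head` and `key` into a single axis named `embed`", and the undo form splits `embed` back into `Head` and `Key`. The capitalized `Head` and `Key` are presumably `Axis` objects, which carry the sizes needed for the split (the role `h=num_heads` plays in einops).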
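To make the proposed semantics concrete, here is a toy sketch of merge/split on named axes. Everything in it (`Axis`, `NamedArray`, `transpose`, `merge`, `split`) is a hypothetical stand-in written for illustration; it is not Haliax's `NamedArray` or the API that eventually landed in #44. It decomposes each proposed `rearrange` call into a transpose by name (with a leading `"..."` meaning "all other axes, in their current order") plus a merge or split.

```python
import math
from dataclasses import dataclass
from typing import Tuple

import jax
import jax.numpy as jnp

@dataclass(frozen=True)
class Axis:  # hypothetical: a named axis with a size
    name: str
    size: int

@dataclass(frozen=True)
class NamedArray:  # hypothetical: an array plus one Axis per dimension
    array: jax.Array
    axes: Tuple[Axis, ...]

    def names(self):
        return [a.name for a in self.axes]

    def transpose(self, order):
        """Reorder axes by name; a leading "..." stands for all unlisted axes."""
        names = self.names()
        if order and order[0] == "...":
            listed = order[1:]
            order = [n for n in names if n not in listed] + list(listed)
        perm = tuple(names.index(n) for n in order)
        return NamedArray(jnp.transpose(self.array, perm), tuple(self.axes[i] for i in perm))

    def merge(self, new_name, parts):
        """Merge adjacent axes `parts` (outer to inner) into one axis `new_name`."""
        names = self.names()
        start = names.index(parts[0])
        if names[start:start + len(parts)] != list(parts):
            raise ValueError("axes to merge must be adjacent and in order; transpose first")
        size = math.prod(a.size for a in self.axes[start:start + len(parts)])
        new_axes = self.axes[:start] + (Axis(new_name, size),) + self.axes[start + len(parts):]
        return NamedArray(self.array.reshape(tuple(a.size for a in new_axes)), new_axes)

    def split(self, name, parts):
        """Split axis `name` into the Axis objects `parts` (outer to inner)."""
        i = self.names().index(name)
        if math.prod(p.size for p in parts) != self.axes[i].size:
            raise ValueError("split sizes must multiply to the original axis size")
        new_axes = self.axes[:i] + tuple(parts) + self.axes[i + 1:]
        return NamedArray(self.array.reshape(tuple(a.size for a in new_axes)), new_axes)

Batch, Head, Pos, Key = Axis("batch", 2), Axis("head", 4), Axis("pos", 3), Axis("key", 5)
data = jnp.arange(2 * 4 * 3 * 5, dtype=jnp.float32).reshape(2, 4, 3, 5)
q = NamedArray(data, (Batch, Head, Pos, Key))

# like q.rearrange("... pos embed", embed=("head", "key")): transpose, then merge
merged = q.transpose(["...", "pos", "head", "key"]).merge("embed", ("head", "key"))
print(merged.names(), merged.array.shape)  # ['batch', 'pos', 'embed'] (2, 3, 20)

# like q.rearrange("... head pos key", embed=(Head, Key)): split, then transpose
restored = merged.split("embed", (Head, Key)).transpose(["...", "head", "pos", "key"])
print(restored.names(), bool(jnp.array_equal(restored.array, q.array)))
# ['batch', 'head', 'pos', 'key'] True
```

The split takes `Axis` objects rather than bare names because the sizes of the new axes cannot be inferred from the merged axis alone, whereas merging only needs names since the sizes are already known.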

dlwh commented 9 months ago

Fixed in #44.