stanford-crfm / haliax

Named Tensors for Legible Deep Learning in JAX
Apache License 2.0

Einsum #63

Closed: dlwh closed this 8 months ago

dlwh commented 8 months ago

Adds a new einsum syntax to Haliax that has three modes: "ordered", "unordered", and "output only".
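To make the "unordered" and "output only" modes concrete, here is a minimal NumPy sketch of their semantics as the examples in this PR suggest them. It is not Haliax's implementation or API. The following are assumptions for illustration only: the named arrays are modeled as hypothetical `(names, ndarray)` pairs, `named_einsum` is a hypothetical helper, axes that unordered mode does not mention are assumed to come first in the output, and the `...` forms are not handled. Each axis name is mapped to one letter, so the named call reduces to an ordinary positional `np.einsum`.

```python
import string

import numpy as np


def named_einsum(spec, *operands):
    """Hypothetical sketch of the "unordered" and "output only" modes.

    Each operand is a (names, ndarray) pair where names[i] labels axis i.
    "{a b} -> b" is unordered mode; "-> x y" is output-only mode.
    The "..." forms are not handled in this sketch.
    """
    lhs, rhs = (part.strip() for part in spec.split("->"))
    out_names = rhs.split()

    # Collect every distinct axis name in order of first appearance.
    all_names = []
    for names, _ in operands:
        for name in names:
            if name not in all_names:
                all_names.append(name)

    if lhs:
        # Unordered mode: braced axes missing from the output are contracted;
        # axes not mentioned at all are kept (assumed to come first).
        mentioned = lhs.strip("{}").split()
        kept = [n for n in all_names if n not in mentioned and n not in out_names]
        out_names = kept + out_names
    # In both modes, every axis that is not in out_names is summed over.

    # Map each name to one einsum letter so NumPy does the positional work.
    letters = {name: string.ascii_letters[i] for i, name in enumerate(all_names)}
    in_spec = ",".join("".join(letters[n] for n in names) for names, _ in operands)
    out_spec = "".join(letters[n] for n in out_names)
    arrays = [array for _, array in operands]
    return tuple(out_names), np.einsum(in_spec + "->" + out_spec, *arrays)


rng = np.random.default_rng(0)
im = (("batch", "h", "w", "c"), rng.standard_normal((2, 3, 4, 5)))
w2 = (("c", "embed"), rng.standard_normal((5, 6)))

names, out = named_einsum("-> batch h w embed", im, w2)
print(names, out.shape)  # ('batch', 'h', 'w', 'embed') (2, 3, 4, 6)
```

Because operands are matched by name rather than by position, `"{c embed} -> embed"` and `"-> batch h w embed"` resolve to the same positional contraction `"bhwc,ce->bhwe"` in this sketch.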

Examples:

| JAX | Haliax |
|-----|--------|
| `jnp.dot(z, x)` | `hax.dot(z, x, axis="batch")` |
| `jnp.matmul(z, x)` | `hax.dot(z, x, axis="batch")` |
| `jnp.dot(w, x.T)` | `hax.dot(w, x, axis="embed")` |
| `jnp.einsum("ij,j -> i", x, w)` | `hax.dot(x, w, axis="embed")` |
| `jnp.einsum("i,ij,ij,j -> i", z, x, y, w)` | `hax.dot(z, x, y, w, axis="embed")` |
| `jnp.einsum("ij,j -> ji", x, w)` | `hax.dot(x, w, axis=(), out_axes=("embed", "batch"))` |
| `jnp.einsum("bhwc,ce -> bhwe", im, w2)` | `hax.einsum("b h w c, c e -> b h w e", im, w2)` |
| `jnp.einsum("...c,ce -> ...e", im, w2)` | `hax.einsum("... c, c e -> ... e", im, w2)` |
| `jnp.einsum("bhwc,ce -> bhwe", im, w2)` | `hax.einsum("{c embed} -> embed", im, w2)` |
| `jnp.einsum("bhwc,ce -> bhwe", im, w2)` | `hax.einsum("-> batch h w embed", im, w2)` |
| `jnp.einsum("bhwc,ce -> bhwce", im, w2)` | `hax.einsum("{...} -> ...", im, w2)` |
| `jnp.einsum("bhwc,ce -> ", im, w2)` | `hax.einsum("{...} -> ", im, w2)` |
| `jnp.einsum("bhwc,ce -> bhwce", im, w2)` | `hax.dot(im, w2, axis=())` |
| `jnp.einsum("bhwc,ce -> ", im, w2)` | `hax.dot(im, w2, axis=None)` |
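As a sanity check on the `hax.dot` rows, here is a hypothetical NumPy sketch of the contraction rule they describe: `axis=None` contracts every axis, `axis=()` contracts none (shared axes are broadcast, as in the `bhwce` row), and a name or tuple of names contracts exactly those axes. `named_dot` and the `(names, ndarray)` representation are illustrative stand-ins, not Haliax's API. Ordering the kept axes by first appearance when `out_axes` is omitted is also an assumption.

```python
import string

import numpy as np


def named_dot(*operands, axis, out_axes=None):
    """Hypothetical sketch of hax.dot-style contraction.

    Each operand is a (names, ndarray) pair. axis=None contracts every axis,
    axis=() contracts none (shared axes are broadcast), and a name or tuple
    of names contracts exactly those axes.
    """
    all_names = []
    for names, _ in operands:
        for name in names:
            if name not in all_names:
                all_names.append(name)

    if axis is None:
        contracted = set(all_names)
    elif isinstance(axis, str):
        contracted = {axis}
    else:
        contracted = set(axis)

    # Remaining axes, in order of first appearance (an assumption).
    out_names = [n for n in all_names if n not in contracted]
    if out_axes is not None:
        # out_axes may only reorder the surviving axes.
        if sorted(out_axes) != sorted(out_names):
            raise ValueError(f"out_axes {out_axes} must permute {out_names}")
        out_names = list(out_axes)

    letters = {name: string.ascii_letters[i] for i, name in enumerate(all_names)}
    in_spec = ",".join("".join(letters[n] for n in names) for names, _ in operands)
    out_spec = "".join(letters[n] for n in out_names)
    arrays = [array for _, array in operands]
    return tuple(out_names), np.einsum(in_spec + "->" + out_spec, *arrays)


rng = np.random.default_rng(0)
z = (("batch",), rng.standard_normal(3))
x = (("batch", "embed"), rng.standard_normal((3, 4)))
w = (("embed",), rng.standard_normal(4))

# Row "jnp.dot(z, x)" vs. "hax.dot(z, x, axis="batch")"
names, out = named_dot(z, x, axis="batch")
assert names == ("embed",) and np.allclose(out, np.dot(z[1], x[1]))
```

Naming the contracted axis removes the need for the transpose in `jnp.dot(w, x.T)`. In this sketch, `named_dot(w, x, axis="embed")` finds `embed` in both operands regardless of its position.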

Fixes #2