Closed dlwh closed 8 months ago
Adds a new einsum syntax to Haliax that has three modes: "ordered", "unordered", and "output only".
einsum
Examples:
jnp.dot(z, x)
hax.dot(z, x, axis="batch")
jnp.matmul(z, x)
hax.dot(z, x, axis="batch")
jnp.dot(w, x.T)
hax.dot(w, x, axis="embed")
jnp.einsum("ij,j -> i", x, w)
hax.dot(x, w, axis="embed")
jnp.einsum("i,ij,ij,j -> i", z, x, y, w)
hax.dot(z, x, y, w, axis="embed")
jnp.einsum("ij,j -> ji", x, w)
hax.dot(x, w, axis=(), out_axes=("embed", "batch"))
jnp.einsum("bhwc,ce -> bhwe", im, w2)
hax.einsum("b h w c, c e -> b h w e", im, w2)
jnp.einsum("...c,ce -> ...e", im, w2)
hax.einsum("... c, c e -> ... e", im, w2)
jnp.einsum("bhwc,ce -> bhwe", im, w2)
hax.einsum("{c embed} -> embed", im, w2)
jnp.einsum("bhwc,ce -> bhwe", im, w2)
hax.einsum("-> batch h w embed", im, w2)
jnp.einsum("bhwc,ce -> bhwce", im, w2)
hax.einsum("{...} -> ...", im, w2)
jnp.einsum("bhwc,ce -> ", im, w2)
hax.einsum("{...} -> ", im, w2)
jnp.einsum("bhwc,ce -> bhwce", im, w2)
hax.dot(im, w2, axis=())
jnp.einsum("bhwc,ce -> ", im, w2)
hax.dot(im, w2, axis=None)
Fixes #2
Adds a new `einsum` syntax to Haliax that has three modes: "ordered", "unordered", and "output only".

Examples:
jnp.dot(z, x)
hax.dot(z, x, axis="batch")
jnp.matmul(z, x)
hax.dot(z, x, axis="batch")
jnp.dot(w, x.T)
hax.dot(w, x, axis="embed")
jnp.einsum("ij,j -> i", x, w)
hax.dot(x, w, axis="embed")
jnp.einsum("i,ij,ij,j -> i", z, x, y, w)
hax.dot(z, x, y, w, axis="embed")
jnp.einsum("ij,j -> ji", x, w)
hax.dot(x, w, axis=(), out_axes=("embed", "batch"))
jnp.einsum("bhwc,ce -> bhwe", im, w2)
hax.einsum("b h w c, c e -> b h w e", im, w2)
jnp.einsum("...c,ce -> ...e", im, w2)
hax.einsum("... c, c e -> ... e", im, w2)
jnp.einsum("bhwc,ce -> bhwe", im, w2)
hax.einsum("{c embed} -> embed", im, w2)
jnp.einsum("bhwc,ce -> bhwe", im, w2)
hax.einsum("-> batch h w embed", im, w2)
jnp.einsum("bhwc,ce -> bhwce", im, w2)
hax.einsum("{...} -> ...", im, w2)
jnp.einsum("bhwc,ce -> ", im, w2)
hax.einsum("{...} -> ", im, w2)
jnp.einsum("bhwc,ce -> bhwce", im, w2)
hax.dot(im, w2, axis=())
jnp.einsum("bhwc,ce -> ", im, w2)
hax.dot(im, w2, axis=None)
Fixes #2