stanford-crfm / haliax

Named Tensors for Legible Deep Learning in JAX
Apache License 2.0

haliax.dot: allow defining output order à la einsum #2

Status: Open. dlwh opened this issue 1 year ago.

dlwh commented 1 year ago

Sidi asked for something like this, and I think it's a good idea. (He asked for proper einsum support, which is also a good idea.)

cf. stanford-crfm/levanter#107

I'd like to support something like:

haliax.dot(Embed, key, query, out_axes=(..., Head, KeySeqLen, SeqLen))

which would force the output axes into that order.
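A minimal sketch of how an out_axes spec containing `...` could be resolved against the axes a contraction leaves behind. It assumes axes are identified by plain string names and that the contraction already has some default output order; `resolve_out_axes` is a hypothetical helper for illustration, not Haliax's implementation.

```python
from typing import Sequence, Tuple


# Hypothetical helper, not part of Haliax. Axis names are plain strings here.
def resolve_out_axes(
    default_axes: Sequence[str], out_axes: Sequence[object]
) -> Tuple[str, ...]:
    """Reorder the axes left after a contraction according to out_axes.

    out_axes may contain at most one Ellipsis, which stands for every
    result axis not named explicitly, kept in its default relative order.
    """
    explicit = [a for a in out_axes if a is not Ellipsis]
    if len(explicit) != len(set(explicit)):
        raise ValueError(f"duplicate axes in out_axes: {explicit}")
    unknown = [a for a in explicit if a not in default_axes]
    if unknown:
        raise ValueError(f"out_axes mentions axes not in the result: {unknown}")
    n_ellipsis = sum(1 for a in out_axes if a is Ellipsis)
    if n_ellipsis > 1:
        raise ValueError("at most one Ellipsis is allowed in out_axes")

    # Axes not named explicitly, in their default order.
    rest = [a for a in default_axes if a not in explicit]
    if n_ellipsis == 0:
        if rest:
            raise ValueError(f"out_axes must mention every result axis; missing {rest}")
        return tuple(explicit)

    result = []
    for a in out_axes:
        if a is Ellipsis:
            result.extend(rest)  # the Ellipsis absorbs all unnamed axes
        else:
            result.append(a)
    return tuple(result)
```

For example, if contracting Embed leaves `("Batch", "KeySeqLen", "Head", "SeqLen")` in that (assumed) default order, then `resolve_out_axes(default, (..., "Head", "KeySeqLen", "SeqLen"))` returns `("Batch", "Head", "KeySeqLen", "SeqLen")`. Without an Ellipsis, every result axis has to be listed.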

reachtarunhere commented 6 months ago

I really like Haliax so far. I do think that we should have proper einsum support: to me, the string often doubles as documentation.

dlwh commented 6 months ago

yeah I was thinking that too. I added out_axes in the secret-ish dev branch. WDYT the syntax should look like?

dlwh commented 6 months ago

(and thanks!)

reachtarunhere commented 6 months ago


Ideally I would prefer syntax similar to einops, as shown here.

I very much prefer the explicit einsum("i j, j k -> i k", mat1, mat2) over the dot syntax currently in the library.

We can enforce that the left-hand side uses the real axis names we have in Haliax instead of arbitrary letters like i j k.

This does have one disadvantage compared to the dot method in terms of axis privacy: for something like bmm, a spec like "i j, j k -> i k" will break because it doesn't mention the batch axis, while the dot code will work just fine. Perhaps this could be countered by batching over the axes that aren't mentioned?
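A rough sketch of both ideas together: requiring real axis names in the spec, and treating any axis an operand has but the spec omits as a batch axis. Everything here is an assumption for illustration: `parse_named_einsum` is a hypothetical helper (not Haliax code), axis names are plain strings rather than Haliax Axis objects, and the policy of prepending batch axes to the output is just one possible choice.

```python
from typing import List, Sequence, Tuple


# Hypothetical helper, not part of Haliax.
def parse_named_einsum(
    spec: str, operand_axes: Sequence[Sequence[str]]
) -> Tuple[List[List[str]], List[str]]:
    """Parse e.g. "seq embed, embed kseq -> seq kseq" against the operands' axes.

    Every name in an operand's term must be a real axis of that operand.
    Axes an operand has but the spec omits are treated as batch axes: they
    are appended to that operand's term and prepended to the output.
    """
    if spec.count("->") != 1:
        raise ValueError("spec must contain exactly one '->'")
    lhs, rhs = spec.split("->")
    terms = [term.split() for term in lhs.split(",")]
    output = rhs.split()
    if len(terms) != len(operand_axes):
        raise ValueError(f"spec has {len(terms)} operands, got {len(operand_axes)}")

    mentioned = {name for term in terms for name in term} | set(output)
    batch: List[str] = []
    full_terms: List[List[str]] = []
    for i, (term, axes) in enumerate(zip(terms, operand_axes)):
        # Enforce real axis names instead of arbitrary letters like i j k.
        unknown = [name for name in term if name not in axes]
        if unknown:
            raise ValueError(f"operand {i} has no axes named {unknown}")
        omitted = [a for a in axes if a not in term]
        for a in omitted:
            # Named elsewhere in the spec but left out here: ambiguous, refuse.
            if a in mentioned:
                raise ValueError(f"axis {a!r} is named in the spec but omitted for operand {i}")
            if a not in batch:
                batch.append(a)
        full_terms.append(term + omitted)

    inputs_seen = {name for term in full_terms for name in term}
    missing_out = [name for name in output if name not in inputs_seen]
    if missing_out:
        raise ValueError(f"output mentions axes no operand has: {missing_out}")
    return full_terms, batch + output
```

With operands whose axes are (batch, seq, embed) and (batch, embed, kseq), the matmul-style spec "seq embed, embed kseq -> seq kseq" resolves to the terms [seq, embed, batch] and [embed, kseq, batch] with output [batch, seq, kseq], so it behaves like bmm, while "i j, j k -> i k" is rejected because no operand has axes named i, j or k.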

dlwh commented 6 months ago

Yeah, I am coming around to this point of view. WDYT about the syntax for the new einops-style rearrange, particularly the dev version?

We could support this syntax for dot with something like:

reachtarunhere commented 6 months ago

Looks great to me. Pretty much what I am looking for (maybe except the last one, haha).

We could also have a hax.einsum that calls hax.dot after resolving all this, instead of expanding hax.dot itself.
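A rough sketch of the resolution step such a hax.einsum might perform before delegating: given a spec already split into per-operand axis lists, work out which axes get summed over and what output order to force (the out_axes argument discussed above). The helper name `einsum_to_dot_args` and its return shape are assumptions for illustration; the sketch deliberately stops short of calling Haliax, since this thread doesn't pin down the final hax.dot signature.

```python
from typing import List, Sequence, Tuple


# Hypothetical helper, not part of Haliax: turn a parsed einsum spec into
# (axes to contract, output order) for a dot-style call.
def einsum_to_dot_args(
    inputs: Sequence[Sequence[str]], output: Sequence[str]
) -> Tuple[Tuple[str, ...], Tuple[str, ...]]:
    if len(set(output)) != len(output):
        raise ValueError(f"output repeats an axis: {list(output)}")

    # Collect input axis names in first-seen order.
    seen: List[str] = []
    for term in inputs:
        for name in term:
            if name not in seen:
                seen.append(name)

    missing = [name for name in output if name not in seen]
    if missing:
        raise ValueError(f"output mentions axes no input has: {missing}")

    # Axes on the left-hand side but not the right-hand side are summed over.
    contracted = tuple(name for name in seen if name not in output)
    return contracted, tuple(output)
```

For "head seq embed, head kseq embed -> head kseq seq" this yields `(("embed",), ("head", "kseq", "seq"))`: contract over embed, then force the output order, which is exactly the information a dot with out_axes would need.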

dlwh commented 5 months ago

@reachtarunhere any chance I could get you to look at https://github.com/stanford-crfm/haliax/pull/63/files#diff-b1aa00624eecf36f969b62aaee977cfac454841fa3dc40f480759e68bda5473bR57 and lmk what you think? Just asking you to glance at the docs, but if you want to look deeper that would be lovely too :-)

reachtarunhere commented 5 months ago

@dlwh just back from vacation. Happy to take a look later today :)