stanford-crfm / haliax

Named Tensors for Legible Deep Learning in JAX
Apache License 2.0

Maybe warn if `dot` reduces a dim that isn't in common between args? #20

Open dlwh opened 11 months ago

dlwh commented 11 months ago

Someone did:

attn_output = hax.dot("key_pos", attn, v)

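but forgot that `v` had a `pos` dimension, not a `key_pos` dimension (it should have had `key_pos`, but they forgot to rename it). `attn` had both, so this produced a result with the right shape, but it didn't do what they wanted: `key_pos` appeared only in `attn`, so it was summed out of `attn` alone instead of being contracted against `v`, and the shared `pos` axis was kept and multiplied elementwise.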
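To make the failure concrete, here is a minimal NumPy sketch that reproduces it with `np.einsum` instead of haliax. The toy sizes, the random data, and the variable names `v_pos` / `v_key` are assumptions for illustration. `attn` is laid out as `(key_pos, pos)` so that the letters match the `einsum("ij,je->je", ...)` spec mentioned in the issue. The intended and accidental contractions give outputs of the same shape, but the values differ.

```python
import numpy as np

# Toy sizes (illustrative assumption): pos = key_pos = 3, embed = 2.
rng = np.random.default_rng(0)
attn = rng.random((3, 3))   # axes: (key_pos, pos)
v_pos = rng.random((3, 2))  # axes: (pos, embed)  <- forgot to rename pos -> key_pos
v_key = v_pos.copy()        # same data, axes: (key_pos, embed) after renaming

# Intended: contract key_pos, which both arguments carry.
intended = np.einsum("ij,ie->je", attn, v_key)

# Accidental: key_pos (i) only exists in attn, so it is summed out of attn,
# while pos (j) is shared and kept, i.e. multiplied elementwise.
actual = np.einsum("ij,je->je", attn, v_pos)

# Same thing written as "sum attn over key_pos, then multiply by v".
alt = attn.sum(axis=0)[:, None] * v_pos

print(intended.shape, actual.shape)  # (3, 2) (3, 2)
print(np.allclose(actual, alt))      # True
print(np.allclose(intended, actual)) # generally False: same shape, different values
```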

This is a valid operation, and it's not inconceivable that someone could want it: it's equivalent to `einsum("ij,je->je", attn, v)` or `hax.sum(attn, "key_pos") * v`. However, it seems uncommon. Should we issue a warning, or add a "paranoid debug" mode that logs warnings for things like this?
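One possible shape for such a check is sketched below. It is a toy, not the haliax API. `named_dot`, the `(axis_names, data)` pair used in place of a real named array, and the `paranoid` flag are all hypothetical. The sketch counts how many arguments actually carry the contracted axis. If only one does, the "dot" degenerates into a plain sum over that axis, so it emits a warning and then performs the operation anyway.

```python
import string
import warnings

import numpy as np


def named_dot(axis, *args, paranoid=True):
    """Hypothetical named dot: contract `axis`, keep all other axes.

    Each arg is an (axis_names, ndarray) pair. Axes shared by several args
    but not contracted are kept and multiplied elementwise (batched).
    """
    # How many arguments actually carry the axis being contracted?
    holders = sum(axis in names for names, _ in args)
    if holders == 0:
        raise ValueError(f"axis {axis!r} not found in any argument")
    if paranoid and holders < 2:
        warnings.warn(
            f"dot over {axis!r}: axis appears in only {holders} argument; "
            "it will be summed out rather than contracted between arguments",
            stacklevel=2,
        )
    # Give each distinct axis name an einsum letter, in first-seen order.
    letters = {}
    for names, _ in args:
        for name in names:
            letters.setdefault(name, string.ascii_letters[len(letters)])
    out_names = tuple(name for name in letters if name != axis)
    spec = ",".join("".join(letters[n] for n in names) for names, _ in args)
    spec += "->" + "".join(letters[n] for n in out_names)
    return out_names, np.einsum(spec, *(data for _, data in args))


rng = np.random.default_rng(0)
attn = (("key_pos", "pos"), rng.random((3, 3)))
v = (("pos", "embed"), rng.random((3, 2)))  # forgot to rename pos -> key_pos
v_fixed = (("key_pos", "embed"), v[1])      # after the rename

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    names, out = named_dot("key_pos", attn, v)                # warns
    fixed_names, fixed = named_dot("key_pos", attn, v_fixed)  # no warning
print(names, out.shape, len(caught))  # ('pos', 'embed') (3, 2) 1
```

Because `paranoid` defaults to `True` here, the check is always on. An opt-in debug mode would instead default it to `False` or read it from a global setting, which keeps intentional uses of the operation quiet.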