Closed: NightMachinery closed this issue 2 months ago.
You could think of the operation more verbosely as:

`einx.get_at("[v] c, b t [] -> b t c []", ...)`

The elementary operation applied here has the signature `(v), () -> ()` and is vectorized over all other axes. The empty brackets are not part of einx notation, though: they are ambiguous (for example, `b t []`, `b [] t`, and `[] b t` would all be valid) and they don't add relevant information. So einx instead assumes that tensors whose einx expressions contain no brackets are passed as scalars to the elementary operation. Does this answer your question?
You can also find more examples here.
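For concreteness, here is a small sketch (in plain NumPy, not einx itself) of what an expression like `einx.get_at("[v] c, b t -> b t c", ...)` computes; the shapes and values below are made up for illustration:

```python
import numpy as np

# Hypothetical shapes: an "embedding table" with v rows of c features,
# and a batch of integer indices of shape (b, t).
v, c, b, t = 10, 4, 2, 3
table = np.arange(v * c).reshape(v, c)   # shape (v, c)
idx = np.array([[0, 3, 9],
                [1, 1, 5]])              # shape (b, t), values in [0, v)

# The elementary operation (v), () -> () picks one row index along the
# bracketed v-axis; it is vectorized over b and t (from idx) and c (from
# table). In NumPy this is just advanced indexing:
out = table[idx]                         # shape (b, t, c)

assert out.shape == (b, t, c)
assert (out[0, 1] == table[3]).all()     # row 3 of the table, as requested by idx[0, 1]
```

The second input (`idx`) has no brackets in the einx expression, which is why each of its entries enters the elementary operation as a scalar.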
Thanks, yes, I think I get it. I'd still appreciate more examples of `get_at` in the docs. The operation is rather rich, and concrete usage examples help one form intuitions.
The examples of creating neural networks with einx could also use more PyTorch. JAX uses a tracing compiler, which makes it very different from PyTorch, so the JAX examples aren't particularly helpful for gaining intuition about using einx with PyTorch.
It's a bit confusing: I get the bracket around `v`, but shouldn't there be another bracket for retrieving the index from `x`?