fferflo / einx

Universal Tensor Operations in Einstein-Inspired Notation for Python.
https://einx.readthedocs.io/en/stable/
MIT License

Add more examples for get_at #11

Closed NightMachinery closed 2 months ago

NightMachinery commented 2 months ago

It’s a bit confusing.

x = einx.get_at("[v] c, b t -> b t c", einn.param(name="vocab_embed"), x, v=50257, c=1024)

I get the bracket around v, but shouldn’t there be another bracket for retrieving the index from x?

fferflo commented 2 months ago

You could think of the operation more verbosely as:

einx.get_at("[v] c, b t [] -> b t c []", ...)

The elementary operation that is applied here has the signature (v), () -> (), and is vectorized over all other axes. The empty brackets are not included in einx notation, though: they're ambiguous (for example, all of b t [], b [] t and [] b t would be valid), and they also don't add relevant information. So einx instead assumes that tensors whose einx expressions do not contain any brackets are used as scalars in the elementary operation. Does this answer your question?
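To make the "(v), () -> ()" reading concrete, here is a minimal NumPy-only sketch of what the pattern "[v] c, b t -> b t c" describes. It does not call einx. The helper name get_at_reference and the small sizes (v=7, c=3, b=2, t=4) are made up for illustration. The loops spell out the vectorization over b, t and c, and the result is compared against plain NumPy integer indexing.

```python
import numpy as np


def get_at_reference(table, idx):
    """Loop-based reading of "[v] c, b t -> b t c" (illustrative helper, not einx API)."""
    v, c = table.shape
    b, t = idx.shape
    out = np.empty((b, t, c), dtype=table.dtype)
    for bi in range(b):
        for ti in range(t):
            for ci in range(c):
                vector = table[:, ci]   # the bracketed axis: a vector of shape (v,)
                scalar = idx[bi, ti]    # no brackets: a scalar index
                # Elementary operation (v), () -> (): pick one element of the vector.
                out[bi, ti, ci] = vector[scalar]
    return out


rng = np.random.default_rng(0)
table = rng.normal(size=(7, 3))         # plays the role of the "[v] c" tensor
idx = rng.integers(0, 7, size=(2, 4))   # plays the role of the "b t" tensor

out = get_at_reference(table, idx)
fancy = table[idx]                      # the same lookup via NumPy integer indexing

print(out.shape)
print(np.array_equal(out, fancy))
```

In this reading, the index tensor needs no brackets because each of its entries is consumed as a single scalar; only the axis being indexed into (v) is marked.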

You can also find more examples here.

NightMachinery commented 2 months ago

Thanks, yes, I think I get it. I'd still appreciate more examples of get_at in the docs. The operation is rather rich, and concrete usage examples can help one form intuitions.

The examples of creating neural networks using einx could also use more PyTorch examples. JAX uses a tracing compiler, which makes it very different from PyTorch, so its examples are not particularly helpful for gaining intuition about using einx with PyTorch.