willow-ahrens / Finch.jl

Sparse tensors in Julia and more! A data-structure-driven array programming language.
http://willowahrens.io/Finch.jl/
MIT License

@einsum macro #428

Closed: willow-ahrens closed this issue 4 months ago

willow-ahrens commented 6 months ago

Let's make a Finch macro in the style of https://github.com/Jutho/TensorOperations.jl. We can use the design pattern of https://github.com/willow-ahrens/Finch.jl/blob/main/src/FinchNotation/syntax.jl and build a call tree for the right-hand side of each assignment. It may be useful to keep https://docs.julialang.org/en/v1/manual/metaprogramming/ in mind. It would be nice if we used Finch's incrementing syntax: @einsum A[i, j, k] <<max>>= B[i, j] * C[j, k] + D[i]
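
For concreteness, here is a sketch of the proposed surface syntax (the inputs are ordinary Finch tensors; @einsum is the macro being proposed, so treat this as illustrative rather than final):

```julia
using Finch

B = Tensor(Dense(SparseList(Element(0.0))), fsprand(4, 5, 0.5))
C = Tensor(Dense(Dense(Element(0.0))), rand(5, 6))
D = Tensor(Dense(Element(0.0)), rand(4))

# `<<max>>=` accumulates results with `max` rather than the default `+`,
# mirroring Finch's incrementing syntax.
@einsum A[i, j, k] <<max>>= B[i, j] * C[j, k] + D[i]
```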

There are some complicated semantics around creating the tensor on the left-hand side. I think we should always initialize the left-hand-side tensor, so we need a syntax for specifying its initial value (perhaps just a keyword, init=0). A good default for the initial value is the initial_value function: https://github.com/willow-ahrens/Finch.jl/blob/25afb3fc99733f7bbcaa601975f5b40717b5be18/src/interface/mapreduce.jl#L104
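
For example, the default could come from initial_value applied to the reduction operator and element type, with the keyword as an override (the init keyword spelling below is just the placeholder proposed above):

```julia
using Finch

# The linked initial_value function maps a reduction operator and element
# type to its identity, which makes a sensible default fill value.
Finch.initial_value(max, Float64)  # -Inf
Finch.initial_value(+, Float64)    # 0.0

# Hypothetical spelling with the keyword override:
#   @einsum init=-Inf A[i, j] <<max>>= B[i, j, k]
```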

As a side issue, we also need to detect when any of the arguments are lazy and, if so, produce a lazy tensor instead of computing immediately. One really nice way to do this is with Broadcast styles: https://docs.julialang.org/en/v1/manual/interfaces/#man-interfaces-broadcasting.
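
A minimal sketch of what style-based detection could look like, by analogy with Base.Broadcast.BroadcastStyle (every name here is hypothetical, and LogicTensor is a stand-in for the lazy wrapper discussed later in this thread):

```julia
struct LogicTensor end  # stand-in for the lazy tensor wrapper discussed below

abstract type EinsumStyle end
struct EagerStyle <: EinsumStyle end
struct LazyStyle <: EinsumStyle end

einsum_style(tns) = EagerStyle()              # default: plain tensors are eager
einsum_style(tns::LogicTensor) = LazyStyle()  # logic tensors opt into laziness

# One lazy argument promotes the whole expression, just as Broadcast
# promotes to the "stronger" BroadcastStyle.
combine_style(::EagerStyle, ::EagerStyle) = EagerStyle()
combine_style(::EinsumStyle, ::EinsumStyle) = LazyStyle()
result_style(args...) = reduce(combine_style, map(einsum_style, args))
```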

willow-ahrens commented 6 months ago

A quick followup: here's an example of passing options to a macro: https://github.com/willow-ahrens/Finch.jl/blob/25afb3fc99733f7bbcaa601975f5b40717b5be18/src/execute.jl#L163-L176. That macro just forwards the options to the underlying function, but you could also parse each option expression with something like @capture opt :(=)(~name, ~val)
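
For instance, a macro could peel off leading key=val expressions before the main expression. A hedged sketch (myeinsum_impl is a placeholder for whatever function the macro lowers to):

```julia
# Sketch: accept `@myeinsum opt1=val1 opt2=val2 ex`, forwarding the options
# as keyword arguments, mirroring the pattern in the linked execute.jl.
macro myeinsum(args...)
    ex = args[end]
    opts = map(args[1:end-1]) do opt
        if opt isa Expr && opt.head == :(=) && opt.args[1] isa Symbol
            Expr(:kw, opt.args[1], esc(opt.args[2]))
        else
            throw(ArgumentError("expected an option of the form key=val, got $opt"))
        end
    end
    :(myeinsum_impl($(QuoteNode(ex)); $(opts...)))
end
```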

willow-ahrens commented 6 months ago

We're using https://github.com/willow-ahrens/RewriteTools.jl to do our rewriting, so we can consult those docs for more on how the @capture and @rule macros work. Note that those macros work on FinchLogic nodes, FinchNode (@finch_program) nodes, and Julia Exprs.
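
As a tiny illustration on a Julia Expr (the ~name slot syntax follows the snippet above; treat the exact pattern grammar as an assumption and check the RewriteTools docs for the authoritative form):

```julia
using RewriteTools

ex = :(init = 0)
# On a successful match, the slots bind `name` and `val` in scope.
if @capture ex :(=)(~name, ~val)
    println("option ", name, " = ", val)
end
```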

willow-ahrens commented 6 months ago

@kylebd99 You mentioned wanting a function to do this, and on further reflection, I think that would be a good idea. First, we need a way to switch between eager and lazy modes (which would be accomplished at the function level). Second, we'll want something on the Python side for this. Perhaps we could discuss this at the Finch meeting next week? Hopefully we can come up with ergonomic function-based approaches that allow arbitrary pointwise expressions.

kylebd99 commented 6 months ago

That makes sense. We should definitely chat about the eager/lazy distinction. In general, should I be thinking of it as "FinchLogic=lazy" and "FinchNotation=eager"?

willow-ahrens commented 6 months ago

I think it's more that:

```julia
A = Tensor()
B = Tensor()
C = LogicTensor()
D = LogicTensor()

# eager
einsum(A, B)

# lazy
einsum(C, D)
einsum(A, C)
```

willow-ahrens commented 6 months ago

So we've decided that in Finch we need two einsum layers: an underlying functional interface, and a high-level interface that mirrors existing Finch syntax. The lower level would probably use styles (e.g. https://docs.julialang.org/en/v1/manual/interfaces/#man-interfaces-broadcasting or https://github.com/willow-ahrens/Finch.jl/blob/main/src/style.jl) to resolve which executor/Finch query interpreter to use based on argument preferences. For example, LogicTensors declare lazy style and regular Tensors declare eager style.
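
A rough sketch of how the two layers could relate, reusing the style sketch from earlier in this thread (einsum_plan is a hypothetical plan builder, and compute stands for running Finch's query interpreter):

```julia
# Lower layer: always build a FinchLogic plan, as if every argument were lazy.
# Upper layer: consult the combined style to decide whether to compute now.
function einsum(op, args...; init = nothing)
    plan = einsum_plan(op, args...; init = init)
    if result_style(args...) isa LazyStyle
        plan             # stay lazy: the caller can compose further queries
    else
        compute(plan)    # eager: run the query interpreter immediately
    end
end
```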