Closed mtsokol closed 6 months ago
Yeah, perhaps we need to define einsum better in Finch and express both `matmul` and `tensordot` in terms of that.
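To make the einsum idea concrete, here is a minimal sketch using NumPy's `np.einsum` as a stand-in. It does not assume anything about Finch's own einsum API, and the shapes and index letters are chosen only for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.random((2, 3, 4))
b = rng.random((2, 4, 5))

# matmul: contract j and share the leading batch index b between both inputs.
mm = np.einsum("bij,bjk->bik", a, b)

# tensordot over (last axis of a, second-to-last axis of b): contract j,
# but keep a separate leading index for each input (b for a, c for b).
td = np.einsum("bij,cjk->bick", a, b)

# Both formulations agree with NumPy's dedicated functions.
assert np.allclose(mm, np.matmul(a, b))
assert np.allclose(td, np.tensordot(a, b, axes=((-1,), (-2,))))
```

The only difference between the two subscript strings is whether the leading index is shared (`b` on both inputs) or kept separate (`b` and `c`), which is what would let one einsum primitive cover both operations.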
Side note: I think we should implement `mul!` in Julia for Finch Tensor and call that here.
@willow-ahrens The PR is ready from my side! (I also want to benchmark these functions)
For now I constrained `matmul` input to 2D tensors. Once we have a `mul` that can perform matmul on a stack of matrices, I will come back to this.
Tracking issue #21
Hi @willow-ahrens @hameerabbasi,
This WIP PR introduces `finch.tensordot(x1, x2, axes)` and `finch.matmul(x1, x2)` / `x1 @ x2`. I still need to complete exhaustive testing.
@willow-ahrens I've got one question regarding `matmul`. I just noticed that there's a slight difference between `tensordot` and `matmul` in the Array API for >2D input: `tensordot` aggregates the non-contracted dims from one input and then the other, whereas `matmul` multiplies over the two innermost dims and treats the remaining dims as stack/batch dimensions.
Here's NumPy code showing it (described in the notes in the docs):
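A minimal sketch of that difference, with shapes `(2, 3, 4)` and `(2, 4, 5)` chosen only for illustration:

```python
import numpy as np

a = np.ones((2, 3, 4))
b = np.ones((2, 4, 5))

# tensordot contracts the requested axes and concatenates the remaining
# dimensions of both inputs: (2, 3) from `a` followed by (2, 5) from `b`.
td_shape = np.tensordot(a, b, axes=((-1,), (-2,))).shape

# matmul multiplies over the two innermost dimensions and treats the
# leading dimension as a batch dimension shared by both inputs.
mm_shape = np.matmul(a, b).shape

print(td_shape)  # (2, 3, 2, 5)
print(mm_shape)  # (2, 3, 5)
```

For 2D inputs the two calls give the same result, so the mismatch only shows up once batch dimensions are present.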
So I'm not sure I can implement `matmul` as `self.tensordot(other, axes=((-1,), (-2,)))`. WDYT?