finch-tensor / finch-tensor-python

Sparse and Structured Tensor Programming in Python
MIT License

API: Implement `tensordot` and `matmul` #22

Closed: mtsokol closed this 6 months ago

mtsokol commented 7 months ago

Tracking issue #21

Hi @willow-ahrens @hameerabbasi,

This WIP PR introduces finch.tensordot(x1, x2, axes) and finch.matmul(x1, x2), which is also available as x1 @ x2. I still need to do exhaustive testing.
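In Python, x1 @ x2 dispatches to x1.__matmul__(x2), so supporting the operator amounts to defining that method and forwarding to matmul. A minimal sketch, assuming a hypothetical wrapper class ToyTensor backed by a NumPy array; this is not finch's actual Tensor class, and the module-level matmul/tensordot here are stand-ins that simply delegate to NumPy:

```python
import numpy as np


def matmul(x1, x2):
    # Hypothetical stand-in for finch.matmul: delegate to NumPy on the wrapped data.
    return ToyTensor(np.matmul(x1.data, x2.data))


def tensordot(x1, x2, axes=2):
    # Hypothetical stand-in for finch.tensordot, using the Array API `axes` argument.
    return ToyTensor(np.tensordot(x1.data, x2.data, axes=axes))


class ToyTensor:
    """Hypothetical minimal tensor wrapper used only to illustrate operator dispatch."""

    def __init__(self, data):
        self.data = np.asarray(data)

    @property
    def shape(self):
        return self.data.shape

    def __matmul__(self, other):
        # `self @ other` ends up here, so the operator and the function share one code path.
        return matmul(self, other)
```

Routing __matmul__ through the module-level function keeps x1 @ x2 and finch.matmul(x1, x2) from drifting apart as the implementation evolves.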


@willow-ahrens I have one question about matmul. I just noticed that, for inputs with more than two dimensions, the Array API defines tensordot and matmul slightly differently: tensordot concatenates the non-contracted dimensions of the first input with the non-contracted dimensions of the second, whereas matmul multiplies the two innermost dimensions and treats all remaining dimensions as stack (batch) dimensions.
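The two shape rules can be written down directly. A minimal sketch under stated assumptions: tensordot_shape and matmul_shape are hypothetical helpers, axes is given as a pair of sequences, and the matmul rule broadcasts the stack dimensions as the Array API does but ignores the 1D special cases:

```python
def tensordot_shape(shape1, shape2, axes):
    """Result shape of tensordot: free dims of x1 followed by free dims of x2."""
    n1, n2 = len(shape1), len(shape2)
    # Normalize negative axes so they can be compared with enumerate() indices.
    ax1 = [a % n1 for a in axes[0]]
    ax2 = [a % n2 for a in axes[1]]
    for a, b in zip(ax1, ax2):
        if shape1[a] != shape2[b]:
            raise ValueError("contracted dimensions must have equal size")
    free1 = [d for i, d in enumerate(shape1) if i not in ax1]
    free2 = [d for i, d in enumerate(shape2) if i not in ax2]
    return tuple(free1 + free2)


def matmul_shape(shape1, shape2):
    """Result shape of matmul for ndim >= 2: broadcast stack dims + (n, m)."""
    *batch1, n, k1 = shape1
    *batch2, k2, m = shape2
    if k1 != k2:
        raise ValueError("inner dimensions must match")
    # Right-align the stack dimensions by padding the shorter one with 1s.
    length = max(len(batch1), len(batch2))
    batch1 = [1] * (length - len(batch1)) + batch1
    batch2 = [1] * (length - len(batch2)) + batch2
    batch = []
    for b1, b2 in zip(batch1, batch2):
        if b1 != b2 and 1 not in (b1, b2):
            raise ValueError("stack dimensions are not broadcastable")
        # A size-1 dimension stretches to match the other operand.
        batch.append(b2 if b1 == 1 else b1)
    return tuple(batch) + (n, m)
```

For shapes (9, 5, 7, 4) and (9, 5, 4, 3), contracting axis -1 with axis -2, these return (9, 5, 7, 9, 5, 3) and (9, 5, 7, 3) respectively.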

Here's some NumPy code showing the difference (it is also described in the notes in the docs):

  >>> import numpy as np
  >>> a = np.ones([9, 5, 7, 4])
  >>> c = np.ones([9, 5, 4, 3])
  >>> np.tensordot(a, c, axes=((-1,), (-2,))).shape
  (9, 5, 7, 9, 5, 3)
  >>> np.matmul(a, c).shape
  (9, 5, 7, 3)
  >>> # n is 7, k is 4, m is 3

So I'm not sure I can implement matmul as self.tensordot(other, axes=((-1,), (-2,))). WDYT?
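One way to see the relationship: tensordot pairs every stack index of a with every stack index of c, while matmul keeps only the entries where those indices coincide, i.e. a diagonal of the tensordot result. A NumPy sketch with random data of the same shapes as the example above (NumPy is used here only for illustration, not Finch):

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((9, 5, 7, 4))
c = rng.standard_normal((9, 5, 4, 3))

# Contract the last axis of `a` with the second-to-last axis of `c`.
td = np.tensordot(a, c, axes=((-1,), (-2,)))  # shape (9, 5, 7, 9, 5, 3)
mm = np.matmul(a, c)                          # shape (9, 5, 7, 3)

# matmul pairs stack index (i, j) of `a` only with the same (i, j) of `c`,
# so it equals the diagonal of the tensordot result over the stack axes.
# Repeated labels within one einsum operand select that diagonal.
diag = np.einsum("ijnijm->ijnm", td)
```

So matmul could in principle be recovered from tensordot, but only by materializing 9 * 5 = 45 times as many entries here and then discarding almost all of them.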

willow-ahrens commented 7 months ago

Yeah, perhaps we need to define einsum better in Finch and express both matmul and tensordot in terms of that.
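As a NumPy sketch of that idea (not Finch's einsum, whose interface the thread leaves open), both operations become single einsum calls that differ only in whether the stack labels are shared between the operands. The helper names tensordot_via_einsum and matmul_via_einsum are hypothetical; the tensordot version assumes axes is a pair of sequences and uses NumPy's integer-sublist form of einsum:

```python
import numpy as np


def tensordot_via_einsum(x1, x2, axes):
    """tensordot for axes given as a pair of sequences, expressed as one einsum."""
    ax1 = [a % x1.ndim for a in axes[0]]
    ax2 = [a % x2.ndim for a in axes[1]]
    # Give every axis its own integer label; x2 reuses x1's labels on contracted axes.
    labels1 = list(range(x1.ndim))
    labels2 = list(range(x1.ndim, x1.ndim + x2.ndim))
    for a, b in zip(ax1, ax2):
        labels2[b] = labels1[a]
    # Output keeps all free axes of x1, then all free axes of x2: nothing is shared.
    out = [lab for i, lab in enumerate(labels1) if i not in ax1]
    out += [lab for i, lab in enumerate(labels2) if i not in ax2]
    return np.einsum(x1, labels1, x2, labels2, out)


def matmul_via_einsum(x1, x2):
    """matmul for ndim >= 2: the ellipsis makes the stack dims shared (and broadcast)."""
    return np.einsum("...nk,...km->...nm", x1, x2)
```

The only difference between the two specifications is whether the leading labels are distinct (tensordot) or shared (matmul), which is the semantic gap raised above.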

Side note:

I think we should implement mul! in Julia for Finch's Tensor type and call that from here.

mtsokol commented 6 months ago

@willow-ahrens The PR is ready from my side! (I also want to benchmark these functions)

For now I've restricted matmul inputs to 2D tensors. Once we have a mul! that can multiply stacks of matrices, I'll come back to this.
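A minimal sketch of such a restriction, using NumPy arrays as a stand-in; matmul_2d is a hypothetical name, and the Array API itself also allows stacked and 1D operands, so this is deliberately narrower than the standard:

```python
import numpy as np


def matmul_2d(x1, x2):
    """Matrix product restricted to 2D operands; stacked inputs are rejected for now."""
    if x1.ndim != 2 or x2.ndim != 2:
        raise ValueError(
            f"matmul currently supports only 2D inputs, got {x1.ndim}D and {x2.ndim}D"
        )
    if x1.shape[1] != x2.shape[0]:
        raise ValueError(f"inner dimensions differ: {x1.shape[1]} != {x2.shape[0]}")
    # For 2D inputs, contracting axis 1 of x1 with axis 0 of x2 is exactly matmul.
    return np.tensordot(x1, x2, axes=((1,), (0,)))
```

In this sketch the 2D case can reuse tensordot because the two operations agree when there are no stack dimensions; the batched case is where they diverge.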