finch-tensor / Finch-mlir

Rewriting Finch in mlir

Discussion: `Finch-mlir` API #3

Open mtsokol opened 2 months ago

mtsokol commented 2 months ago

Hi @nullplay,

I wanted to start a discussion on the API between Finch-MLIR and MLIR tensors. In https://github.com/pydata/sparse/tree/main/sparse/mlir_backend we have an initial Tensor class that provides constructors for MLIR sparse and dense tensors:

tensor< dim0 x dim1 x ... x data_type[, sparse_format]>

Here are examples of MLIR tensors that can be created and will be passed to the API exposed by Finch-mlir:

// Dense
tensor<3x4xf64>
tensor<10xi64>
tensor<1x5x10xi32>

// CSR
#CSR = #sparse_tensor.encoding<{
    map = (i, j) -> (i : dense, j : compressed), posWidth = 64, crdWidth = 64
}>
tensor<4x4xf64, #CSR>
tensor<10x8xi64, #CSR>

// CSC
#CSC = #sparse_tensor.encoding<{
    map = (i, j) -> (j : dense, i : compressed), posWidth = 64, crdWidth = 64
}>
tensor<4x4xf32, #CSC>

// COO
#COO = #sparse_tensor.encoding<{
    map = (i, j) -> (i : compressed(nonunique), j : singleton), posWidth = 64, crdWidth = 64
}>
tensor<4x4xi64, #COO>
tensor<10x100xf32, #COO>

// CSF
#CSF = #sparse_tensor.encoding<{
    map = (i, j, k) -> (i : dense, j : compressed, k : compressed), posWidth = 64, crdWidth = 64
}>
tensor<4x4x10xf64, #CSF>
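To make the level types above concrete, here is a minimal pure-Python sketch of the buffers that a `dense` outer level plus a `compressed` inner level (the #CSR map) describes: a positions array, a coordinates array, and a values array. The helper name `dense_to_csr` is hypothetical and not part of pydata/sparse or Finch-mlir; it only illustrates the layout.

```python
from typing import List, Tuple


def dense_to_csr(matrix: List[List[float]]) -> Tuple[List[int], List[int], List[float]]:
    """Hypothetical helper: build CSR buffers from a row-major dense matrix.

    The dense row level stores no buffers of its own. The compressed column
    level stores positions (row boundaries) and coordinates (column indices).
    """
    positions = [0]
    coordinates: List[int] = []
    values: List[float] = []
    for row in matrix:
        for j, v in enumerate(row):
            if v != 0:
                coordinates.append(j)
                values.append(v)
        # positions[i + 1] marks where row i ends in coordinates/values
        positions.append(len(coordinates))
    return positions, coordinates, values


dense = [
    [1.0, 0.0, 0.0, 2.0],
    [0.0, 0.0, 0.0, 0.0],
    [0.0, 3.0, 0.0, 0.0],
    [4.0, 0.0, 5.0, 0.0],
]
pos, crd, vals = dense_to_csr(dense)
print(pos)   # [0, 2, 2, 3, 5]
print(crd)   # [0, 3, 1, 0, 2]
print(vals)  # [1.0, 2.0, 3.0, 4.0, 5.0]
```

With `posWidth = 64` and `crdWidth = 64`, the positions and coordinates would be stored as 64-bit integers.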

One of the first questions is: what would the API look like for calling basic operations, say unary and binary elementwise operations, matmul (tensordot?), and reductions? For example:

func.func @add(%arg0: tensor<...>, %arg1: tensor<...>) -> tensor<...>
func.func @sub(%arg0: tensor<...>, %arg1: tensor<...>) -> tensor<...>
...
func.func @sum(%arg: tensor<...>, %axis: array<index>) -> tensor<...>
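As a reference for what an `@add` entry point would have to compute on two CSR inputs, here is a minimal pure-Python sketch that merges the sorted column coordinates of each row. It is not Finch, looplets, or MLIR code. The function name `csr_add` and the `(positions, coordinates, values)` triple are assumptions made for illustration only.

```python
from typing import List, Tuple

CSR = Tuple[List[int], List[int], List[float]]  # (positions, coordinates, values)


def csr_add(n_rows: int, a: CSR, b: CSR) -> CSR:
    """Hypothetical reference: elementwise a + b for CSR inputs of equal shape.

    Assumes the coordinates within each row are sorted and unique.
    """
    a_pos, a_crd, a_vals = a
    b_pos, b_crd, b_vals = b
    out_pos, out_crd, out_vals = [0], [], []
    for i in range(n_rows):
        p, p_end = a_pos[i], a_pos[i + 1]
        q, q_end = b_pos[i], b_pos[i + 1]
        # Two-pointer merge over the stored entries of row i
        while p < p_end or q < q_end:
            ja = a_crd[p] if p < p_end else None
            jb = b_crd[q] if q < q_end else None
            if jb is None or (ja is not None and ja < jb):
                out_crd.append(ja)
                out_vals.append(a_vals[p])
                p += 1
            elif ja is None or jb < ja:
                out_crd.append(jb)
                out_vals.append(b_vals[q])
                q += 1
            else:  # same column stored in both inputs
                out_crd.append(ja)
                out_vals.append(a_vals[p] + b_vals[q])
                p += 1
                q += 1
        out_pos.append(len(out_crd))
    return out_pos, out_crd, out_vals


# a = [[1, 0, 2], [0, 0, 3]], b = [[0, 4, 5], [6, 0, 0]]
a = ([0, 2, 3], [0, 2, 2], [1.0, 2.0, 3.0])
b = ([0, 2, 3], [1, 2, 0], [4.0, 5.0, 6.0])
result = csr_add(2, a, b)
print(result)  # ([0, 3, 5], [0, 1, 2, 0, 2], [1.0, 4.0, 7.0, 6.0, 3.0])
```

This sketch keeps explicit zeros when two stored values cancel. Whether the real API should drop them is an open design choice.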

hameerabbasi commented 2 months ago

I would say all of this needs to be in one function, because optimising across function boundaries is difficult.
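A small Python sketch of the concern, using plain lists as a stand-in for tensors: calling separate `add` and `sum` routines materialises an intermediate result, while a single fused routine does the same work in one pass. Both function names are hypothetical and only illustrate the trade-off.

```python
from typing import List


def add_then_sum_unfused(a: List[float], b: List[float]) -> float:
    """Two separate steps: the full intermediate a + b is materialised first."""
    tmp = [x + y for x, y in zip(a, b)]  # temporary tensor
    return sum(tmp)


def add_then_sum_fused(a: List[float], b: List[float]) -> float:
    """One loop: each sum x + y is consumed immediately, no temporary tensor."""
    total = 0.0
    for x, y in zip(a, b):
        total += x + y
    return total


a = [1.0, 2.0, 3.0]
b = [0.5, 0.5, 1.0]
print(add_then_sum_unfused(a, b))  # 8.0
print(add_then_sum_fused(a, b))    # 8.0
```

When the whole expression lives in one function, a compiler can see both operations at once and produce the fused form. Across separate function calls it generally has to keep the intermediate.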

mtsokol commented 1 month ago

@nullplay so maybe let's first establish the MLIR code that would perform addition, subtraction, reductions, etc. using looplets, given an input of the form:

tensor< dim0 x dim1 x ... x data_type[, sparse_format]>

Where access to underlying memrefs is: