tensil-ai / tensil

Open source machine learning accelerators
https://www.tensil.ai

Support for GEMM transpose #46

Closed trulyspinach closed 2 years ago

trulyspinach commented 2 years ago

It seems that emitLayerGemm in the ONNX frontend does not honor the transpose attributes (transA, transB). This can sometimes cause a confusing error at a later stage about multiplying matrices with mismatched dimensions, or produce completely wrong output without any warning.

For my own purposes I was able to work around this by performing a compile-time transpose of the weights, adding the following code to emitLayerGemm:

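    // True if the node carries a transA attribute.
    // Note: this checks presence only, not whether the value is 1.
    val transA =
      matMulProto.attribute.exists(a => a.name.get contains "transA")

    // Same presence-only check for transB.
    val transB =
      matMulProto.attribute.exists(a => a.name.get contains "transB")

    val weights = getTensor(tensorProtos(matMulProto.input(1)))

    // Transpose the constant weights at compile time when only transB is set.
    val weightsTensor =
      if (transB && !transA) {
        println("Engaging transpose!")
        weights.transpose(Array(1, 0))
      } else weights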

Please let me know if there is a better way to implement this; I'd be glad to open a pull request.
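
For reference, the ONNX Gemm operator is specified as Y = alpha * A' * B' + beta * C, where A' is A transposed when transA is non-zero and B' is B transposed when transB is non-zero (both default to 0). Below is a minimal standalone sketch of that flag handling, ignoring alpha, beta and C. It does not use Tensil's types; `GemmTransposeSketch`, `Matrix`, `transpose2d`, `matMul` and `gemm` are hypothetical names chosen for illustration only.

```scala
// Hypothetical helper names for illustration; not part of Tensil's frontend.
object GemmTransposeSketch {
  type Matrix = Vector[Vector[Float]]

  // Swap rows and columns of a rectangular matrix.
  def transpose2d(m: Matrix): Matrix = m.transpose

  // Plain matrix product; fails early on mismatched inner dimensions.
  def matMul(a: Matrix, b: Matrix): Matrix = {
    require(
      a.head.length == b.length,
      s"inner dimensions differ: ${a.head.length} vs ${b.length}"
    )
    val bt = b.transpose
    a.map(row => bt.map(col => row.zip(col).map { case (x, y) => x * y }.sum))
  }

  // ONNX Gemm without the alpha/beta scaling and the C term:
  // each operand is transposed only when its flag value is non-zero.
  def gemm(a: Matrix, b: Matrix, transA: Long = 0, transB: Long = 0): Matrix = {
    val aEff = if (transA != 0) transpose2d(a) else a
    val bEff = if (transB != 0) transpose2d(b) else b
    matMul(aEff, bEff)
  }
}
```

With a 1x3 input `Vector(Vector(1f, 2f, 3f))` and 2x3 weights `Vector(Vector(1f, 2f, 3f), Vector(4f, 5f, 6f))`, calling `gemm` with `transB = 1` gives `Vector(Vector(14.0, 32.0))`, while leaving `transB` at 0 fails the `require` on the inner dimensions. That is the kind of mismatch described above, except that here it is reported at the point where the flag was ignored.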

tdb-alcorn commented 2 years ago

Closed by #45 thanks to @trulyspinach