Kotlin / multik

Multidimensional array library for Kotlin
https://kotlin.github.io/multik/
Apache License 2.0

3-dimensional arrays don't compute the dot product of sub-matrices correctly #212

Open · mihbor opened this issue 2 days ago

mihbor commented 2 days ago

With a 3-dimensional array, the dot product of its 2-D sub-matrices is computed incorrectly unless the array has only one top-level element. Reproducer:

import org.jetbrains.kotlinx.multik.api.linalg.dot
import org.jetbrains.kotlinx.multik.api.mk
import org.jetbrains.kotlinx.multik.api.ndarray
import org.jetbrains.kotlinx.multik.ndarray.data.get

fun main() {
  // a has shape (1, 2, 1): a single 2x1 matrix.
  val a = mk.ndarray(
    mk[
      mk[
        mk[1.0],
        mk[0.0],
      ],
    ],
  )
  // b has shape (2, 2, 1): two copies of the same 2x1 matrix as a[0].
  val b = mk.ndarray(
    mk[
      mk[
        mk[1.0],
        mk[0.0],
      ],
      mk[
        mk[1.0],
        mk[0.0],
      ],
    ],
  )
  println("a[0] dot a[0].T: (CORRECT)")
  println(a[0] dot a[0].transpose())

  println("\nb[0] dot b[0].T: (WRONG)")
  println(b[0] dot b[0].transpose())

  println("\nb[1] dot b[1].T: (WRONG)")
  println(b[1] dot b[1].transpose())
}

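This prints the following. All three products should equal [[1.0, 0.0], [0.0, 0.0]], since a[0], b[0] and b[1] are the same 2x1 matrix: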

a[0] dot a[0].T: (CORRECT)
[[1.0, 0.0],
[0.0, 0.0]]

b[0] dot b[0].T: (WRONG)
[[0.0, 0.0],
[0.0, 0.0]]

b[1] dot b[1].T: (WRONG)
[[0.0, 0.0],
[0.0, 0.0]]
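For comparison, here is a minimal dependency-free sketch of what each product should be. It assumes matrices are stored row-major as List<List<Double>>. The helpers refTranspose and refMatMul are hypothetical and are not part of multik. The sketch multiplies each 2x1 slice of the same (2, 2, 1) data by its transpose using plain nested loops.

```kotlin
// Hypothetical reference helpers (not multik API); matrices are row-major List<List<Double>>.
fun refTranspose(m: List<List<Double>>): List<List<Double>> =
    List(m[0].size) { j -> List(m.size) { i -> m[i][j] } }

fun refMatMul(x: List<List<Double>>, y: List<List<Double>>): List<List<Double>> {
    require(x[0].size == y.size) { "inner dimensions must match" }
    return List(x.size) { i ->
        List(y[0].size) { j ->
            // Sum over the shared inner dimension k.
            var sum = 0.0
            for (k in y.indices) sum += x[i][k] * y[k][j]
            sum
        }
    }
}

fun main() {
    // Same data as `b` in the reproducer: shape (2, 2, 1).
    val b = listOf(
        listOf(listOf(1.0), listOf(0.0)),
        listOf(listOf(1.0), listOf(0.0)),
    )
    for ((i, slice) in b.withIndex()) {
        println("b[$i] dot b[$i].T = ${refMatMul(slice, refTranspose(slice))}")
    }
}
```

Both lines print [[1.0, 0.0], [0.0, 0.0]], which matches the a[0] result in the issue and not the all-zero b[0] and b[1] results.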

mihbor commented 9 hours ago

A couple of observations: the problem goes away if I set the engine type with mk.setEngine(KEEngineType). With the default engine, the dot product also sometimes gives wrong results for D2 x D1 shapes.
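To check D2 x D1 results from the default engine, a plain-Kotlin matrix-vector reference can be compared element-wise against the output of dot. This is a minimal sketch that assumes a row-major List<List<Double>> matrix. The helper name refMatVec is hypothetical and is not part of multik.

```kotlin
// Hypothetical reference helper (not multik API): y = M · v for a row-major matrix M.
fun refMatVec(m: List<List<Double>>, v: List<Double>): List<Double> {
    require(m.all { it.size == v.size }) { "each row must have v.size columns" }
    // Each output element is the dot product of one row with v.
    return m.map { row -> row.indices.sumOf { k -> row[k] * v[k] } }
}

fun main() {
    val m = listOf(listOf(1.0, 2.0), listOf(3.0, 4.0))
    val v = listOf(1.0, 1.0)
    println(refMatVec(m, v)) // prints [3.0, 7.0]
}
```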