kyonifer / koma

A scientific computing library for Kotlin. https://kyonifer.github.io/koma

Overhead of accessing array elements #93

Closed (peastman closed this issue 5 years ago)

peastman commented 5 years ago

I've been looking into improving the performance of array element access. Suppose that array is an NDArray<Double> and you look up an element by index, say array[0,1]. It turns out there's a huge overhead involved. Let me walk through what happens.

The get() function is implemented as

operator fun NDArray<Double>.get(vararg indices: Int) = getDouble(*indices)

which is simple enough. It just calls through to getDouble(), which is implemented as

fun getDouble(vararg indices: Int) = getDouble(safeNIdxToLinear(indices))

Ok, let's look up safeNIdxToLinear():

fun <T> NDArray<T>.safeNIdxToLinear(indices: IntArray) = nIdxToLinear(checkIndices(indices))

So first it calls checkIndices():

fun <T> NDArray<T>.checkIndices(indices: IntArray) = indices.also {
    val shape = shape()
    if (indices.size != shape.size)
        throw IllegalArgumentException("Cannot index an array with shape ${shape.toList()} with " +
                "anything other than ${shape.size} indices (${indices.size} given)")
    indices.forEachIndexed{ i, idx ->
        if (idx >= shape[i])
            throw IllegalArgumentException("Cannot index an array with shape ${shape.toList()} at " +
                    "${indices.toList()} (out of bounds)")
    }
}

That mostly looks reasonable, but there are a couple of hidden costs. Here's how shape() is implemented:

override fun shape(): List<Int> = shape.toList()

Notice that every call to shape() constructs a new list. That's important, because we'll be making multiple calls to it. Not only that, but it converts the primitive integers in shape into boxed Int objects, which then have to be unboxed when we access them. This is a lot of unnecessary overhead.
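To make the allocation concrete, here is a minimal sketch. ShapeHolder is a hypothetical stand-in, not a koma class; it stores its shape in an IntArray and returns it through the same shape.toList() pattern quoted above.

```kotlin
// Hypothetical stand-in for an NDArray backend; not part of koma.
class ShapeHolder(private val shapeArr: IntArray) {
    // Same pattern as the quoted shape(): IntArray.toList() copies the
    // primitive ints into a new List of boxed Integers on every call.
    fun shape(): List<Int> = shapeArr.toList()
}
```

For any non-empty shape, two calls return lists that are equal but are distinct objects, so every lookup that calls shape() pays for a fresh allocation plus boxing.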

But we've hardly begun. Now let's look at nIdxToLinear():

fun <T> NDArray<T>.nIdxToLinear(indices: IntArray): Int {
    var out = 0
    val widthOfDims = widthOfDims()

    indices.forEachIndexed { i, idxArr ->
        out += idxArr * widthOfDims[i]
    }
    return out
}

That mostly looks efficient, except that widthOfDims is another List<Int>, so there's a cost to unboxing each element. But the call to widthOfDims() is where things really get complicated.
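Before digging in, it helps to know what the result should be. nIdxToLinear() computes the usual row-major offset, and widthOfDims() supplies the per-dimension widths. Writing s_i for the shape, n_i for the indices, w_i for the widths and d for the number of dimensions (the symbols are introduced here for illustration):

```latex
w_i = \prod_{j=i+1}^{d-1} s_j \quad (\text{so } w_{d-1} = 1), \qquad
\text{linear} = \sum_{i=0}^{d-1} n_i \, w_i

% Worked example: shape s = (2, 3, 4), indices n = (1, 2, 3)
w = (3 \cdot 4,\; 4,\; 1) = (12,\; 4,\; 1)
\text{linear} = 1 \cdot 12 + 2 \cdot 4 + 3 \cdot 1 = 23
```

These widths depend only on the shape, so they are the same for every lookup into a given array.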

fun <T> NDArray<T>.widthOfDims() = shape()
        .toList()
        .accumulateRight { left, right -> left * right }
        .apply {
            add(1)
            removeAt(0)
        }

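So it first calls shape(), which as we saw constructs a new List of boxed Ints. Then it immediately calls toList() on that, which copies it into a second List. Next, accumulateRight() builds a third list, and the apply{} block appends a 1 to it and removes its first element (removeAt(0) on an ArrayList shifts every remaining element down). Here is accumulateRight(), which builds its result by repeatedly inserting at the front of an ArrayList: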

fun <T> List<T>.accumulateRight(f: (T, T) -> T)
        = this.foldRight(ArrayList<T>()) { ele, accum ->
    if (accum.isEmpty())
        accum.add(ele)
    else
        accum.add(0, f(ele, accum.first()))
    accum
}
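For shape [2, 3, 4], all of that list building yields the widths [12, 4, 1]. As a minimal sketch (computeWidths is a hypothetical name, not a koma function), the same values can come from one backward pass over an IntArray, with no intermediate lists and no boxing:

```kotlin
// Hypothetical helper: row-major widths (strides) for a shape, the same
// values widthOfDims() returns, but as a primitive IntArray.
fun computeWidths(shape: IntArray): IntArray {
    val widths = IntArray(shape.size)
    var acc = 1
    // Walk from the last dimension to the first; each width is the
    // product of all dimension sizes to its right.
    for (i in shape.size - 1 downTo 0) {
        widths[i] = acc
        acc *= shape[i]
    }
    return widths
}
```

Since the result only depends on the shape, it could be computed once per array rather than once per element access.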

Ok, after all that we finally have the linear index of the element to look up. That gets passed to getDouble():

    override fun getDouble(i: Int): Double {
        val ele = storage[checkLinearIndex(i)]
        return ele.toDouble()
    }

That's mostly inexpensive (storage is a DoubleArray, so no boxing is required). But we do pass it through checkLinearIndex():

fun <T> NDArray<T>.checkLinearIndex(index: Int) = index.also {
    if (index < 0)
        throw IllegalArgumentException("Negative indices are not supported")
    else size.let { n ->
        if (index >= n) {
            val an = when("$n"[0]) {
                '1','8' -> "an"
                else    -> "a"
            }
            throw IllegalArgumentException("Cannot index $an $n-element array with shape " +
                                           "${shape().toList()} at linear position $index " +
                                           "(out of bounds)")
        }
    }
}

Remember, we already went through checkIndices(), which is supposed to make sure the indices are legal. Once it also rejects negative indices (as quoted, it only tests idx >= shape[i]), this check is completely redundant.
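The lower bound matters: on a 2×3 array, the indices [1, -1] pass the quoted checkIndices() and map to linear position 1·3 + (−1)·1 = 2, a legal position that actually holds element [0, 2], so no error is raised at all. A minimal sketch of a per-dimension check that tests both bounds (checkIndicesStrict is a hypothetical name; it takes the shape as an IntArray, which is an assumption):

```kotlin
// Hypothetical replacement for checkIndices(): validates every index
// against both bounds, so no separate linear-index check is needed.
fun checkIndicesStrict(shape: IntArray, indices: IntArray): IntArray {
    // require() throws IllegalArgumentException, like the original, and
    // only builds the message string when the check fails.
    require(indices.size == shape.size) {
        "Cannot index an array with shape ${shape.toList()} with " +
            "anything other than ${shape.size} indices (${indices.size} given)"
    }
    for (i in indices.indices) {
        val idx = indices[i]
        // Reject negatives as well as indices past the end of the dimension.
        require(idx >= 0 && idx < shape[i]) {
            "Cannot index an array with shape ${shape.toList()} at " +
                "${indices.toList()} (out of bounds)"
        }
    }
    return indices
}
```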

Ok, what can we do to improve this? First, there's no need to recompute widthOfDims() every time. The content of that list never changes. We could just compute it in the constructor and return it directly. Same with shape(). There's no need to create a new List<Int> every time it's called.
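A minimal sketch of that caching idea, assuming the shape is available as an IntArray at construction time. CachedShapeArray is hypothetical and only shows the two accessors, not a full NDArray.

```kotlin
// Hypothetical sketch, not koma's actual class: shape and widths are
// derived once in the constructor because an array's shape never changes.
class CachedShapeArray(shape: IntArray) {
    private val shapeList: List<Int> = shape.toList()
    private val widthList: List<Int> = run {
        val w = IntArray(shape.size)
        var acc = 1
        for (i in shape.size - 1 downTo 0) {
            w[i] = acc
            acc *= shape[i]
        }
        w.toList()
    }

    // Both return the same list object on every call: no allocation, no copying.
    fun shape(): List<Int> = shapeList
    fun widthOfDims(): List<Int> = widthList
}
```

Callers still see a read-only List<Int>, so the public API is unchanged; only the per-call allocations disappear. Reading elements still unboxes them.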

Another possibility we could consider is returning the shape and widths as IntArray rather than List<Int>. That gets a bit more dangerous, since arrays are mutable. It would have to be internal only, and even then might introduce more risk of coding errors than we want. But it would improve performance by eliminating the need for unboxing.

We also should try to eliminate the duplicate range check. If we've already verified the n-dimensional indices are valid we don't need to also check the corresponding linear index.
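A minimal sketch of a lookup that checks each index once and accumulates the offset in the same loop, then reads storage directly without a second linear check. The function name and its parameters (storage, shape and widths as primitive arrays) are assumptions for illustration, not koma's API.

```kotlin
// Hypothetical fused lookup: validate each index once and accumulate the
// row-major offset in the same loop, then read storage directly.
fun getDoubleFused(
    storage: DoubleArray,
    shape: IntArray,
    widths: IntArray,
    vararg indices: Int
): Double {
    require(indices.size == shape.size) {
        "Expected ${shape.size} indices, got ${indices.size}"
    }
    var linear = 0
    for (i in indices.indices) {
        val idx = indices[i]
        // Both bounds are checked here. Assuming storage.size is the product
        // of shape and widths are row-major strides, the resulting linear
        // index is always in range, so it needs no re-check.
        require(idx >= 0 && idx < shape[i]) {
            "Index $idx out of bounds for dimension $i of size ${shape[i]}"
        }
        linear += idx * widths[i]
    }
    return storage[linear]
}
```

For a 2×3 array stored as [0.0, 1.0, ..., 5.0] with widths [3, 1], getDoubleFused(storage, shape, widths, 1, 2) reads storage[5].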

peastman commented 5 years ago

Another option would be to change nIdxToLinear() and linearToNIdx() into methods of NDArray. Then they could access internal fields directly.
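A minimal sketch of that idea. StridedArray and its members are hypothetical, not koma's code, and linearToNIdx() here is simply the inverse of the row-major offset. Because the shape and widths stay private IntArrays, and the constructor copies the caller's array, this also gets unboxed access without exposing mutable arrays.

```kotlin
// Hypothetical sketch: index conversions as members of the array class,
// reading private primitive fields directly instead of building lists.
class StridedArray(dims: IntArray) {
    // Defensive copy: callers cannot mutate the shape after construction.
    private val shape: IntArray = dims.copyOf()
    private val widths: IntArray = IntArray(shape.size).also { w ->
        var acc = 1
        for (i in shape.size - 1 downTo 0) {
            w[i] = acc
            acc *= shape[i]
        }
    }
    val size: Int = shape.fold(1) { acc, s -> acc * s }

    // No bounds checking here; callers are expected to validate first.
    fun nIdxToLinear(indices: IntArray): Int {
        var out = 0
        for (i in indices.indices) out += indices[i] * widths[i]
        return out
    }

    // Inverse of nIdxToLinear: peel off one dimension at a time.
    fun linearToNIdx(linear: Int): IntArray {
        val out = IntArray(shape.size)
        var rem = linear
        for (i in shape.indices) {
            out[i] = rem / widths[i]
            rem %= widths[i]
        }
        return out
    }
}
```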

kyonifer commented 5 years ago

Yeah, what's there is definitely a first cut and full of potential optimizations. My long-term plan is to have a benchmark suite we run against all the backends on all 3 platforms and have it run in CI so we can detect any regressions that affect a single platform. But that's currently stalled by koma-tests not being trivial to port, which is blocking #56, which is in turn blocked by #77, which is blocked by https://youtrack.jetbrains.net/issue/KT-27849.

In the meantime though, doing some initial passes that take out the obviously sub-optimal code is good. I'll take a look at #94 soon.