dlangBugzillaToGithub / migration_test

0 stars 0 forks source link

std.numerics.dotProduct for fixed-size arrays #418

Open dlangBugzillaToGithub opened 13 years ago

dlangBugzillaToGithub commented 13 years ago

bearophile_hugs reported this on 2011-04-24T11:21:34Z

Transfered from https://issues.dlang.org/show_bug.cgi?id=5880

CC List

Description

A third overload for fixed-sized arrays offers:
- compile-time errors for the length mismatch instead of run-time ones (not in release build);
- allows compilers to optimize the code better because the lengths are known at compile-time;
- the fixed-size argument arrays are by reference, avoiding the copy in this case too.

/*pure*/ CommonType!(ElementType!(Range1), ElementType!(Range2))
dotProduct(Range1, Range2)(Range1 a, Range2 b)
    if (isInputRange!(Range1) && isInputRange!(Range2) &&
            !(isArray!(Range1) && isArray!(Range2)))
{
    // can't be pure yet because of length property and enforce
    enum bool haveLen = hasLength!(Range1) && hasLength!(Range2);
    static if (haveLen) enforce(a.length == b.length);
    typeof(return) result = 0;
    for (; !a.empty; a.popFront, b.popFront)
    {
        result += a.front * b.front;
    }
    static if (!haveLen) enforce(b.empty);
    return result;
}

/// Ditto
pure Unqual!(CommonType!(F1, F2))
dotProduct(F1, F2)(in F1[] avector, in F2[] bvector)
if (!isStaticArray!F1 || !isStaticArray!F2)
{
    immutable n = avector.length;
    assert(n == bvector.length);
    auto avec = avector.ptr, bvec = bvector.ptr;
    typeof(return) sum0 = 0, sum1 = 0;

    const all_endp = avec + n;
    const smallblock_endp = avec + (n & ~3);
    const bigblock_endp = avec + (n & ~15);

    for (; avec != bigblock_endp; avec += 16, bvec += 16)
    {
        sum0 += avec[0] * bvec[0];
        sum1 += avec[1] * bvec[1];
        sum0 += avec[2] * bvec[2];
        sum1 += avec[3] * bvec[3];
        sum0 += avec[4] * bvec[4];
        sum1 += avec[5] * bvec[5];
        sum0 += avec[6] * bvec[6];
        sum1 += avec[7] * bvec[7];
        sum0 += avec[8] * bvec[8];
        sum1 += avec[9] * bvec[9];
        sum0 += avec[10] * bvec[10];
        sum1 += avec[11] * bvec[11];
        sum0 += avec[12] * bvec[12];
        sum1 += avec[13] * bvec[13];
        sum0 += avec[14] * bvec[14];
        sum1 += avec[15] * bvec[15];
    }

    for (; avec != smallblock_endp; avec += 4, bvec += 4) {
        sum0 += avec[0] * bvec[0];
        sum1 += avec[1] * bvec[1];
        sum0 += avec[2] * bvec[2];
        sum1 += avec[3] * bvec[3];
    }

    sum0 += sum1;

    /* Do trailing portion in naive loop. */
    while (avec != all_endp)
        sum0 += (*avec++) * (*bvec++);

    return sum0;
}

/// Ditto
pure Unqual!(CommonType!(F1, F2))
dotProduct(F1, F2, size_t n, size_t n2)(ref const F1[n] avector, ref const F2[n2] bvector)
if (isStaticArray!(typeof(avector)) && isStaticArray!(typeof(bvector)))
{
    static assert(n == n2); // do not move this to the template constraints
    auto avec = avector.ptr, bvec = bvector.ptr;
    typeof(return) sum0 = 0, sum1 = 0;

    const all_endp = avec + n;
    const smallblock_endp = avec + (n & ~3);
    const bigblock_endp = avec + (n & ~15);

    for (; avec != bigblock_endp; avec += 16, bvec += 16)
    {
        sum0 += avec[0] * bvec[0];
        sum1 += avec[1] * bvec[1];
        sum0 += avec[2] * bvec[2];
        sum1 += avec[3] * bvec[3];
        sum0 += avec[4] * bvec[4];
        sum1 += avec[5] * bvec[5];
        sum0 += avec[6] * bvec[6];
        sum1 += avec[7] * bvec[7];
        sum0 += avec[8] * bvec[8];
        sum1 += avec[9] * bvec[9];
        sum0 += avec[10] * bvec[10];
        sum1 += avec[11] * bvec[11];
        sum0 += avec[12] * bvec[12];
        sum1 += avec[13] * bvec[13];
        sum0 += avec[14] * bvec[14];
        sum1 += avec[15] * bvec[15];
    }

    for (; avec != smallblock_endp; avec += 4, bvec += 4) {
        sum0 += avec[0] * bvec[0];
        sum1 += avec[1] * bvec[1];
        sum0 += avec[2] * bvec[2];
        sum1 += avec[3] * bvec[3];
    }

    sum0 += sum1;

    /* Do trailing portion in naive loop. */
    while (avec != all_endp)
        sum0 += (*avec++) * (*bvec++);

    return sum0;
}

unittest
{
    minidot()

    double[] a0, b0;
    assert(dotProduct(a0, b0) == 0);

    assert(dotProduct([1.0, 2.0], [4.0, 6.0]) == 16.0);
    assert(dotProduct([1, 3, -5], [4, -2, -1]) == 3);
    assert(dotProduct(iota(1, 5), iota(10, 41, 10)) == 300);

    int[4] a1 = [1, 2, 3, 4];
    int[4] b1 = [10, 20, 30, 40];
    assert(dotProduct(a1, b1) == 300);

    int[] c1 = [10, 20, 30, 40];
    assert(dotProduct(a1, c1) == 300);

    int[5] c2 = [10, 20, 30, 40, 0];
    assert(!__traits(compiles, { dotProduct(a1, c2); } )); // can't compile

    // more unittests needed
}
dlangBugzillaToGithub commented 13 years ago

bearophile_hugs commented on 2011-04-24T12:39:28Z

This doesn't work yet:

int[4] a1 = [1, 2, 3, 4];
assert(dotProduct(a1, iota(10, 41, 10)) == 300);