google-deepmind / alphadev

Apache License 2.0
674 stars 67 forks source link

Benchmark #2

Open 99991 opened 1 year ago

99991 commented 1 year ago

I was curious how well the proposed sorting functions for three, four or five integers would perform on my hardware, so I wrote a quick benchmark. It seems like there is almost no difference between AlphaDev and a simple sorting network, but perhaps someone finds the results interesting anyway, so I thought I'd share.

Benchmark results

Sort5AlphaDev: 17.367800 cycles on average, checksum -1224191569
Sort5Naive: 17.378000 cycles on average, checksum -1224191569
AlphaDev is 0.058729 % faster than a naive sorting network

Sort5AlphaDev: 17.384000 cycles on average, checksum -1224191569
Sort5Naive: 17.381600 cycles on average, checksum -1224191569
AlphaDev is -0.013806 % faster than a naive sorting network

Sort5AlphaDev: 17.505800 cycles on average, checksum -1224191569
Sort5Naive: 17.371000 cycles on average, checksum -1224191569
AlphaDev is -0.770031 % faster than a naive sorting network

Benchmark code

// clang -O3 main.c -o main && ./main
#include <stdio.h>
#include <stdint.h>
#include <assert.h>
#include <stdlib.h>

// Number of ints per array being sorted. It is pasted into the function names
// Sort<N_SORT>AlphaDev and Sort<N_SORT>Naive by the CONCAT macro, so it must be
// 3, 4 or 5 (the only sizes implemented in this file).
#define N_SORT 5

/*
 * Sorts the three ints buffer[0], buffer[1], buffer[2] in place, ascending,
 * using only loads, cmp/cmov and stores (no branches).
 *
 * buffer: pointer to three consecutive ints that are read and overwritten.
 *
 * Clobbers eax, ecx, edx and r8d; the "memory" clobber tells the compiler
 * that the asm reads and writes memory through buffer.
 */
void Sort3AlphaDev(int* buffer) {
  asm volatile(
      "mov 0x4(%0), %%eax            \n"  // eax = buffer[1]
      "mov 0x8(%0), %%ecx            \n"  // ecx = buffer[2]
      "cmp %%eax, %%ecx              \n"  // flags from buffer[2] - buffer[1]
      "mov %%eax, %%edx              \n"
      "cmovl %%ecx, %%edx            \n"  // edx = min(buffer[1], buffer[2])
      "mov (%0), %%r8d               \n"  // r8d = buffer[0]; mov leaves the flags intact
      "cmovg %%ecx, %%eax            \n"  // eax = max(buffer[1], buffer[2])
      "cmp %%r8d, %%eax              \n"
      "mov %%r8d, %%ecx              \n"
      "cmovl %%eax, %%ecx            \n"  // ecx = min(buffer[0], eax)
      "cmovle %%r8d, %%eax           \n"  // eax = largest of the three
      "mov %%eax, 0x8(%0)            \n"  // buffer[2] = largest
      "cmp %%ecx, %%edx              \n"
      "cmovle %%edx, %%r8d           \n"  // r8d = smallest of the three
      "mov %%r8d, (%0)               \n"  // buffer[0] = smallest
      "cmovg %%edx, %%ecx            \n"  // ecx = middle value
      "mov %%ecx, 0x4(%0)            \n"  // buffer[1] = middle
      : "+r"(buffer)
      :
      : "eax", "ecx", "edx", "r8d", "memory");
}

/*
 * Sorts the four ints buffer[0..3] in place, ascending, without branches.
 * The cmp/cmov groups below act as the comparators (0,2), (1,3), then
 * min-vs-min, max-vs-max, and finally the two remaining middle candidates.
 *
 * buffer: pointer to four consecutive ints that are read and overwritten.
 *
 * Clobbers eax, ecx, edx, r8d and r9d; the "memory" clobber tells the
 * compiler that the asm reads and writes memory through buffer.
 */
void Sort4AlphaDev(int* buffer) {
  asm volatile(
      "mov 0x8(%0), %%eax            \n"  // eax = buffer[2]
      "mov (%0), %%ecx               \n"  // ecx = buffer[0]
      "mov 0x4(%0), %%edx            \n"  // edx = buffer[1]
      "cmp %%eax, %%ecx              \n"  // flags from buffer[0] - buffer[2]
      "mov %%eax, %%r8d              \n"
      "cmovl %%ecx, %%r8d            \n"  // r8d = min(buffer[0], buffer[2])
      "cmovl %%eax, %%ecx            \n"  // ecx = max(buffer[0], buffer[2])
      "mov 0xc(%0), %%r9d            \n"  // r9d = buffer[3]
      "cmp %%r9d, %%edx              \n"  // flags from buffer[1] - buffer[3]
      "mov %%r9d, %%eax              \n"
      "cmovl %%edx, %%eax            \n"  // eax = min(buffer[1], buffer[3])
      "cmovl %%r9d, %%edx            \n"  // edx = max(buffer[1], buffer[3])
      "cmp %%eax, %%r8d              \n"  // compare the two minima
      "mov %%eax, %%r9d              \n"
      "cmovl %%r8d, %%r9d            \n"  // r9d = smallest of all four
      "cmovge %%r8d, %%eax           \n"  // eax = larger of the two minima
      "mov %%r9d, (%0)               \n"  // buffer[0] = smallest
      "cmp %%edx, %%ecx              \n"  // compare the two maxima
      "mov %%edx, %%r8d              \n"
      "cmovl %%ecx, %%r8d            \n"  // r8d = smaller of the two maxima
      "cmovge %%ecx, %%edx           \n"  // edx = largest of all four
      "mov %%edx, 0xc(%0)            \n"  // buffer[3] = largest
      "cmp %%r8d, %%eax              \n"  // order the two middle candidates
      "mov %%r8d, %%ecx              \n"
      "cmovl %%eax, %%ecx            \n"  // ecx = min(eax, r8d)
      "cmovge %%eax, %%r8d           \n"  // r8d = max(eax, r8d)
      "mov %%r8d, 0x8(%0)            \n"  // buffer[2] = upper middle
      "mov %%ecx, 0x4(%0)            \n"  // buffer[1] = lower middle
      : "+r"(buffer)
      :
      : "eax", "ecx", "edx", "r8d", "r9d", "memory");
}

/*
 * Sorts the five ints buffer[0..4] in place without branches; the benchmark
 * in this file asserts ascending order after every call.
 *
 * Outline: order the pair buffer[0], buffer[1]; fully order the triple
 * buffer[2..4]; then merge the two groups with further cmp/cmov steps,
 * storing each output as soon as it is final.
 *
 * buffer: pointer to five consecutive ints that are read and overwritten.
 *
 * Clobbers eax, ecx, edx, r8d, r9d and r10d; the "memory" clobber tells the
 * compiler that the asm reads and writes memory through buffer.
 */
void Sort5AlphaDev(int* buffer) {
  asm volatile(
      "mov (%0), %%eax               \n"  // eax = buffer[0]
      "mov 0x4(%0), %%ecx            \n"  // ecx = buffer[1]
      "cmp %%eax, %%ecx              \n"
      "mov %%eax, %%edx              \n"
      "cmovl %%ecx, %%edx            \n"  // edx = min(buffer[0], buffer[1])
      "cmovg %%ecx, %%eax            \n"  // eax = max(buffer[0], buffer[1])
      "mov 0xc(%0), %%r8d            \n"  // r8d = buffer[3]
      "mov 0x10(%0), %%ecx           \n"  // ecx = buffer[4]
      "cmp %%r8d, %%ecx              \n"
      "mov %%r8d, %%r9d              \n"
      "cmovl %%ecx, %%r9d            \n"  // r9d = min(buffer[3], buffer[4])
      "cmovg %%ecx, %%r8d            \n"  // r8d = max(buffer[3], buffer[4])
      "mov 0x8(%0), %%r10d           \n"  // r10d = buffer[2]
      "cmp %%r10d, %%r8d             \n"
      "mov %%r10d, %%ecx             \n"
      "cmovl %%r8d, %%ecx            \n"  // ecx = min(buffer[2], r8d)
      "cmovle %%r10d, %%r8d          \n"  // r8d = largest of buffer[2..4]
      "cmp %%ecx, %%r9d              \n"
      "cmovle %%r9d, %%r10d          \n"  // r10d = smallest of buffer[2..4]
      "cmovg %%r9d, %%ecx            \n"  // ecx = middle of buffer[2..4]
      "cmp %%eax, %%r8d              \n"  // compare the two group maxima
      "mov %%eax, %%r9d              \n"
      "cmovl %%r8d, %%r9d            \n"  // r9d = smaller of the two group maxima
      "cmovle %%eax, %%r8d           \n"  // r8d = largest of all five
      "cmp %%edx, %%ecx              \n"
      "mov %%edx, %%eax              \n"
      "cmovl %%ecx, %%eax            \n"  // eax = min(edx, ecx)
      "cmovle %%edx, %%ecx           \n"  // ecx = max(edx, ecx)
      "mov %%r8d, 0x10(%0)           \n"  // buffer[4] = largest
      "cmp %%eax, %%r10d             \n"
      "cmovle %%r10d, %%edx          \n"  // edx = smallest of all five
      "mov %%edx, (%0)               \n"  // buffer[0] = smallest
      "cmovg %%r10d, %%eax           \n"  // eax = max(eax, r10d)
      "cmp %%r9d, %%ecx              \n"
      "mov %%r9d, %%r8d              \n"
      "cmovl %%ecx, %%r8d            \n"  // r8d = min(ecx, r9d)
      "cmovle %%r9d, %%ecx           \n"  // ecx = max(ecx, r9d)
      "mov %%ecx, 0xc(%0)            \n"  // store buffer[3]
      "cmp %%r8d, %%eax              \n"
      "cmovle %%eax, %%r9d           \n"
      "mov %%r9d, 0x4(%0)            \n"  // store buffer[1]
      "cmovg %%eax, %%r8d            \n"
      "mov %%r8d, 0x8(%0)            \n"  // store buffer[2]
      : "+r"(buffer)
      :
      : "eax", "ecx", "edx", "r8d", "r9d", "r10d", "memory");
}

/*
 * Compare-exchange: after the call *a <= *b.  The two values are swapped
 * only when they are out of order.
 */
void sort2(int *a, int *b){
    const int left = *a;
    const int right = *b;

    if (left <= right){
        return;
    }

    *a = right;
    *b = left;
}

/* Sorts a[0..2] ascending in place with three compare-exchange steps. */
void Sort3Naive(int *a){
    int *const first = &a[0];
    int *const second = &a[1];
    int *const third = &a[2];

    sort2(first, second);
    sort2(first, third);   /* first now holds the minimum */
    sort2(second, third);
}

/* Sorts a[0..3] ascending in place with a five-comparator sorting network. */
void Sort4Naive(int *a){
    /* order each adjacent pair */
    sort2(&a[0], &a[1]);
    sort2(&a[2], &a[3]);

    /* minimum to a[0], maximum to a[3] */
    sort2(&a[0], &a[2]);
    sort2(&a[1], &a[3]);

    /* order the two middle elements */
    sort2(&a[1], &a[2]);
}

/* Sorts a[0..4] ascending in place with a nine-comparator sorting network. */
void Sort5Naive(int *a){
    sort2(&a[0], &a[3]);
    sort2(&a[1], &a[4]);

    sort2(&a[0], &a[2]);
    sort2(&a[1], &a[3]);

    sort2(&a[0], &a[1]);
    sort2(&a[2], &a[4]);

    sort2(&a[1], &a[2]);
    sort2(&a[3], &a[4]);

    sort2(&a[2], &a[3]);
}

// Number of arrays generated, sorted and timed per benchmark function call.
#define NUM_TESTS 10000
// Token pasting in two levels: CONCAT first macro-expands its arguments
// (so N_SORT becomes 5) and only then CONCAT2 glues the three tokens together,
// e.g. CONCAT(Sort, N_SORT, Naive) -> Sort5Naive.
#define CONCAT2(a, b, c) a##b##c
#define CONCAT(a, b, c) CONCAT2(a, b, c)

/*
 * Benchmarks Sort<N_SORT>AlphaDev on NUM_TESTS arrays filled from rand().
 *
 * Each iteration generates one array, times a single sort call with the
 * cycle counter, asserts that the result is ascending and folds the first
 * three elements into a checksum.  The average cycle count per call and the
 * checksum are printed.
 *
 * Returns the sum of the measured cycle counts over all NUM_TESTS calls.
 */
uint64_t test_AlphaDev(void){
    uint64_t total = 0;
    /* Unsigned on purpose: summing rand() values overflows, and signed
       overflow is undefined behaviour while unsigned arithmetic wraps. */
    unsigned int checksum = 0;

    int a[N_SORT];
    for (int k = 0; k < NUM_TESTS; k++){
        for (int i = 0; i < N_SORT; i++){
            a[i] = rand();
        }

        uint64_t start = __builtin_readcyclecounter();

        CONCAT(Sort, N_SORT, AlphaDev)(a);

        uint64_t end = __builtin_readcyclecounter();

        // Check correctness
        for (int i = 1; i < N_SORT; i++){
            assert(a[i - 1] <= a[i]);
        }

        // Consume the sorted values so the sort cannot be optimized away
        checksum += (unsigned int)a[0] + (unsigned int)a[1] + (unsigned int)a[2];

        total += end - start;
    }

    // The cast keeps the printed checksum in the same signed form as before.
    printf("Sort%dAlphaDev: %f cycles on average, checksum %d\n", N_SORT, (double)total / NUM_TESTS, (int)checksum);

    return total;
}

/*
 * Benchmarks Sort<N_SORT>Naive on NUM_TESTS arrays filled from rand().
 *
 * Each iteration generates one array, times a single sort call with the
 * cycle counter, asserts that the result is ascending and folds the first
 * three elements into a checksum.  The average cycle count per call and the
 * checksum are printed.
 *
 * Returns the sum of the measured cycle counts over all NUM_TESTS calls.
 */
uint64_t test_Naive(void){
    uint64_t total = 0;
    /* Unsigned on purpose: summing rand() values overflows, and signed
       overflow is undefined behaviour while unsigned arithmetic wraps. */
    unsigned int checksum = 0;
    int a[N_SORT];
    for (int k = 0; k < NUM_TESTS; k++){
        for (int i = 0; i < N_SORT; i++){
            a[i] = rand();
        }

        uint64_t start = __builtin_readcyclecounter();

        CONCAT(Sort, N_SORT, Naive)(a);

        uint64_t end = __builtin_readcyclecounter();

        // Check correctness
        for (int i = 1; i < N_SORT; i++){
            assert(a[i - 1] <= a[i]);
        }

        // Consume the sorted values so the sort cannot be optimized away
        checksum += (unsigned int)a[0] + (unsigned int)a[1] + (unsigned int)a[2];

        total += end - start;
    }

    // The cast keeps the printed checksum in the same signed form as before.
    printf("Sort%dNaive: %f cycles on average, checksum %d\n", N_SORT, (double)total / NUM_TESTS, (int)checksum);

    return total;
}

/*
 * Runs the AlphaDev and the naive benchmark 100 times back to back.  Both are
 * seeded with srand(0) so they sort the same input sequence, and after each
 * round the relative difference of their total cycle counts is printed.
 */
int main(){
    int round = 0;

    while (round < 100){
        srand(0);
        const uint64_t alphadev_total = test_AlphaDev();

        srand(0);
        const uint64_t naive_total = test_Naive();

        const double naive_vs_alphadev = naive_total * 100.0 / alphadev_total;

        printf("AlphaDev is %f %% faster than a naive sorting network\n\n", naive_vs_alphadev - 100.0);

        round++;
    }

    return 0;
}
SuperSodaSea commented 1 year ago

I changed your code to pre-generate a large array and then sort all the data at once, which gave a completely different result (AlphaDev is now significantly faster). I think it should be related to caching.

New Result

Sort5AlphaDev: 9.321600 cycles on average, checksum -1224191569
Sort5Naive: 68.231400 cycles on average, checksum -1224191569
AlphaDev is 631.970906 % faster than a naive sorting network

Sort5AlphaDev: 9.330000 cycles on average, checksum -1224191569
Sort5Naive: 68.377600 cycles on average, checksum -1224191569
AlphaDev is 632.878885 % faster than a naive sorting network

Sort5AlphaDev: 9.325800 cycles on average, checksum -1224191569
Sort5Naive: 70.886600 cycles on average, checksum -1224191569
AlphaDev is 660.112805 % faster than a naive sorting network

New Code

/*
 * Benchmarks Sort<N_SORT>AlphaDev on NUM_TESTS pre-generated arrays.
 *
 * All inputs are filled from rand() first; then a single timed region sorts
 * every array, so only two cycle-counter reads surround the whole batch.
 * Afterwards each array is asserted to be ascending and its first three
 * elements are folded into a checksum.  The average cycle count per array
 * and the checksum are printed.
 *
 * Returns the cycle count of the timed region.
 */
uint64_t test_AlphaDev(void){
    /* Unsigned on purpose: summing rand() values overflows, and signed
       overflow is undefined behaviour while unsigned arithmetic wraps. */
    unsigned int checksum = 0;

    int a[NUM_TESTS][N_SORT];
    for (int k = 0; k < NUM_TESTS; k++){
        for (int i = 0; i < N_SORT; i++){
            a[k][i] = rand();
        }
    }

    uint64_t start = __builtin_readcyclecounter();

    for (int k = 0; k < NUM_TESTS; k++){
        CONCAT(Sort, N_SORT, AlphaDev)(a[k]);
    }

    uint64_t end = __builtin_readcyclecounter();

    for (int k = 0; k < NUM_TESTS; k++){
        // Check correctness
        for (int i = 1; i < N_SORT; i++){
            assert(a[k][i - 1] <= a[k][i]);
        }

        // Consume the sorted values so the sort cannot be optimized away
        checksum += (unsigned int)a[k][0] + (unsigned int)a[k][1] + (unsigned int)a[k][2];
    }

    const uint64_t total = end - start;

    // The cast keeps the printed checksum in the same signed form as before.
    printf("Sort%dAlphaDev: %f cycles on average, checksum %d\n", N_SORT, (double)total / NUM_TESTS, (int)checksum);

    return total;
}

// Same change in test_Naive, omitted
cclauss commented 1 year ago

Perhaps test_AlphaDev() --> test_AlphaDev(number_of_sorts, number_of_tests) to make tests easier to tweak.

Does the upper limit on number_of_sorts need to be five? Could it be 5,000 instead?

99991 commented 1 year ago

Does the upper limit on number_of_sorts need to be five? Could it be 5,000 instead?

It could be up to 8.

https://github.com/deepmind/alphadev/blob/1a6eac1544c1075ef27814f09f9223cb84537c59/sort_functions_test.cc#L280

I guess higher numbers would be more difficult due to limited number of registers.

99991 commented 1 year ago

I changed your code to pre-generate a large array and then sort all the data at once, which gave a completely different result (AlphaDev is now significantly faster). I think it should be related to caching.

I can reproduce this, but I think the reason is that Clang generates bad code with many conditional jumps for this configuration. Generated assembly code at https://gcc.godbolt.org/z/Tq1M53WY9

branchy

Fortunately, I was able to coax Clang into generating cmov instructions by loading the values from the array into local variables first. This seems to generate much better code.

cmov

/* Compare-exchange on two ints: swaps them if needed so that *a <= *b. */
static inline void sort2(int *a, int *b){
    const int first = *a;
    const int second = *b;
    const int out_of_order = first > second;

    if (out_of_order){
        *a = second;
        *b = first;
    }
}

/*
 * Sorts a[0..2] ascending in place.  The values are copied into locals
 * first, compare-exchanged there, and written back at the end.
 */
void Sort3NaiveC(int *a){
    int lo = a[0], mid = a[1], hi = a[2];

    sort2(&lo, &mid);
    sort2(&mid, &hi);   /* hi now holds the maximum */
    sort2(&lo, &mid);

    a[0] = lo;
    a[1] = mid;
    a[2] = hi;
}

Perhaps test_AlphaDev() --> test_AlphaDev(number_of_sorts, number_of_tests) to make tests easier to tweak.

The number of values to sort should be constant, so that the compiler can do its best. But I can make the other parameters configurable.

I also implemented the sorting network in assembly, but it was no good because I always used the same register for swapping two values, so the CPU cannot pipeline as deeply. Clang, on the other hand, can generate code which is quite good. I reran the experiments a few times to compute more accurate performance numbers.

Benchmark results for sorting many arrays of length 3/4/5

N = 3
AlphaDev: 4.39 ± 0.27 cycles
NaiveAsm: 5.03 ± 0.06 cycles
NaiveC  : 4.01 ± 0.12 cycles
AlphaDev is 14.58 % faster than a naive sorting network written in assembly.
AlphaDev is -8.67 % faster than a naive sorting network written in C.

N = 4
AlphaDev: 4.42 ± 0.28 cycles
NaiveAsm: 5.08 ± 0.42 cycles
NaiveC  : 4.04 ± 0.16 cycles
AlphaDev is 15.02 % faster than a naive sorting network written in assembly.
AlphaDev is -8.64 % faster than a naive sorting network written in C.

N = 5
AlphaDev: 4.44 ± 0.30 cycles
NaiveAsm: 5.02 ± 0.15 cycles
NaiveC  : 4.03 ± 0.15 cycles
AlphaDev is 13.21 % faster than a naive sorting network written in assembly.
AlphaDev is -9.24 % faster than a naive sorting network written in C.

Benchmark code

// clang -O3 main.c -o main -lm && ./main
#include <stdio.h>
#include <stdint.h>
#include <assert.h>
#include <stdlib.h>
#include <math.h>

// Number of 3/4/5-element arrays sorted inside one timed region
// (passed by main as the num_sorts argument of test_3/test_4/test_5).
const int num_sorts = 100;
// Number of times each timed measurement is repeated; mean and standard
// deviation are computed over these repeats.
const int num_repeats = 100;

/*
 * Sorts the three ints buffer[0], buffer[1], buffer[2] in place, ascending,
 * using only loads, cmp/cmov and stores (no branches).
 *
 * buffer: pointer to three consecutive ints that are read and overwritten.
 *
 * Clobbers eax, ecx, edx and r8d; the "memory" clobber tells the compiler
 * that the asm reads and writes memory through buffer.
 */
void Sort3AlphaDev(int* buffer) {
  asm volatile(
      "mov 0x4(%0), %%eax            \n"  // eax = buffer[1]
      "mov 0x8(%0), %%ecx            \n"  // ecx = buffer[2]
      "cmp %%eax, %%ecx              \n"  // flags from buffer[2] - buffer[1]
      "mov %%eax, %%edx              \n"
      "cmovl %%ecx, %%edx            \n"  // edx = min(buffer[1], buffer[2])
      "mov (%0), %%r8d               \n"  // r8d = buffer[0]; mov leaves the flags intact
      "cmovg %%ecx, %%eax            \n"  // eax = max(buffer[1], buffer[2])
      "cmp %%r8d, %%eax              \n"
      "mov %%r8d, %%ecx              \n"
      "cmovl %%eax, %%ecx            \n"  // ecx = min(buffer[0], eax)
      "cmovle %%r8d, %%eax           \n"  // eax = largest of the three
      "mov %%eax, 0x8(%0)            \n"  // buffer[2] = largest
      "cmp %%ecx, %%edx              \n"
      "cmovle %%edx, %%r8d           \n"  // r8d = smallest of the three
      "mov %%r8d, (%0)               \n"  // buffer[0] = smallest
      "cmovg %%edx, %%ecx            \n"  // ecx = middle value
      "mov %%ecx, 0x4(%0)            \n"  // buffer[1] = middle
      : "+r"(buffer)
      :
      : "eax", "ecx", "edx", "r8d", "memory");
}

/*
 * Sorts the four ints buffer[0..3] in place, ascending, without branches.
 * The cmp/cmov groups below act as the comparators (0,2), (1,3), then
 * min-vs-min, max-vs-max, and finally the two remaining middle candidates.
 *
 * buffer: pointer to four consecutive ints that are read and overwritten.
 *
 * Clobbers eax, ecx, edx, r8d and r9d; the "memory" clobber tells the
 * compiler that the asm reads and writes memory through buffer.
 */
void Sort4AlphaDev(int* buffer) {
  asm volatile(
      "mov 0x8(%0), %%eax            \n"  // eax = buffer[2]
      "mov (%0), %%ecx               \n"  // ecx = buffer[0]
      "mov 0x4(%0), %%edx            \n"  // edx = buffer[1]
      "cmp %%eax, %%ecx              \n"  // flags from buffer[0] - buffer[2]
      "mov %%eax, %%r8d              \n"
      "cmovl %%ecx, %%r8d            \n"  // r8d = min(buffer[0], buffer[2])
      "cmovl %%eax, %%ecx            \n"  // ecx = max(buffer[0], buffer[2])
      "mov 0xc(%0), %%r9d            \n"  // r9d = buffer[3]
      "cmp %%r9d, %%edx              \n"  // flags from buffer[1] - buffer[3]
      "mov %%r9d, %%eax              \n"
      "cmovl %%edx, %%eax            \n"  // eax = min(buffer[1], buffer[3])
      "cmovl %%r9d, %%edx            \n"  // edx = max(buffer[1], buffer[3])
      "cmp %%eax, %%r8d              \n"  // compare the two minima
      "mov %%eax, %%r9d              \n"
      "cmovl %%r8d, %%r9d            \n"  // r9d = smallest of all four
      "cmovge %%r8d, %%eax           \n"  // eax = larger of the two minima
      "mov %%r9d, (%0)               \n"  // buffer[0] = smallest
      "cmp %%edx, %%ecx              \n"  // compare the two maxima
      "mov %%edx, %%r8d              \n"
      "cmovl %%ecx, %%r8d            \n"  // r8d = smaller of the two maxima
      "cmovge %%ecx, %%edx           \n"  // edx = largest of all four
      "mov %%edx, 0xc(%0)            \n"  // buffer[3] = largest
      "cmp %%r8d, %%eax              \n"  // order the two middle candidates
      "mov %%r8d, %%ecx              \n"
      "cmovl %%eax, %%ecx            \n"  // ecx = min(eax, r8d)
      "cmovge %%eax, %%r8d           \n"  // r8d = max(eax, r8d)
      "mov %%r8d, 0x8(%0)            \n"  // buffer[2] = upper middle
      "mov %%ecx, 0x4(%0)            \n"  // buffer[1] = lower middle
      : "+r"(buffer)
      :
      : "eax", "ecx", "edx", "r8d", "r9d", "memory");
}

/*
 * Sorts the five ints buffer[0..4] in place without branches; the benchmark
 * in this file asserts ascending order after every call.
 *
 * Outline: order the pair buffer[0], buffer[1]; fully order the triple
 * buffer[2..4]; then merge the two groups with further cmp/cmov steps,
 * storing each output as soon as it is final.
 *
 * buffer: pointer to five consecutive ints that are read and overwritten.
 *
 * Clobbers eax, ecx, edx, r8d, r9d and r10d; the "memory" clobber tells the
 * compiler that the asm reads and writes memory through buffer.
 */
void Sort5AlphaDev(int* buffer) {
  asm volatile(
      "mov (%0), %%eax               \n"  // eax = buffer[0]
      "mov 0x4(%0), %%ecx            \n"  // ecx = buffer[1]
      "cmp %%eax, %%ecx              \n"
      "mov %%eax, %%edx              \n"
      "cmovl %%ecx, %%edx            \n"  // edx = min(buffer[0], buffer[1])
      "cmovg %%ecx, %%eax            \n"  // eax = max(buffer[0], buffer[1])
      "mov 0xc(%0), %%r8d            \n"  // r8d = buffer[3]
      "mov 0x10(%0), %%ecx           \n"  // ecx = buffer[4]
      "cmp %%r8d, %%ecx              \n"
      "mov %%r8d, %%r9d              \n"
      "cmovl %%ecx, %%r9d            \n"  // r9d = min(buffer[3], buffer[4])
      "cmovg %%ecx, %%r8d            \n"  // r8d = max(buffer[3], buffer[4])
      "mov 0x8(%0), %%r10d           \n"  // r10d = buffer[2]
      "cmp %%r10d, %%r8d             \n"
      "mov %%r10d, %%ecx             \n"
      "cmovl %%r8d, %%ecx            \n"  // ecx = min(buffer[2], r8d)
      "cmovle %%r10d, %%r8d          \n"  // r8d = largest of buffer[2..4]
      "cmp %%ecx, %%r9d              \n"
      "cmovle %%r9d, %%r10d          \n"  // r10d = smallest of buffer[2..4]
      "cmovg %%r9d, %%ecx            \n"  // ecx = middle of buffer[2..4]
      "cmp %%eax, %%r8d              \n"  // compare the two group maxima
      "mov %%eax, %%r9d              \n"
      "cmovl %%r8d, %%r9d            \n"  // r9d = smaller of the two group maxima
      "cmovle %%eax, %%r8d           \n"  // r8d = largest of all five
      "cmp %%edx, %%ecx              \n"
      "mov %%edx, %%eax              \n"
      "cmovl %%ecx, %%eax            \n"  // eax = min(edx, ecx)
      "cmovle %%edx, %%ecx           \n"  // ecx = max(edx, ecx)
      "mov %%r8d, 0x10(%0)           \n"  // buffer[4] = largest
      "cmp %%eax, %%r10d             \n"
      "cmovle %%r10d, %%edx          \n"  // edx = smallest of all five
      "mov %%edx, (%0)               \n"  // buffer[0] = smallest
      "cmovg %%r10d, %%eax           \n"  // eax = max(eax, r10d)
      "cmp %%r9d, %%ecx              \n"
      "mov %%r9d, %%r8d              \n"
      "cmovl %%ecx, %%r8d            \n"  // r8d = min(ecx, r9d)
      "cmovle %%r9d, %%ecx           \n"  // ecx = max(ecx, r9d)
      "mov %%ecx, 0xc(%0)            \n"  // store buffer[3]
      "cmp %%r8d, %%eax              \n"
      "cmovle %%eax, %%r9d           \n"
      "mov %%r9d, 0x4(%0)            \n"  // store buffer[1]
      "cmovg %%eax, %%r8d            \n"
      "mov %%r8d, 0x8(%0)            \n"  // store buffer[2]
      : "+r"(buffer)
      :
      : "eax", "ecx", "edx", "r8d", "r9d", "r10d", "memory");
}

/*
 * Compare-exchange: leaves the smaller value in *a and the larger in *b.
 * Nothing is written when the pair is already in order.
 */
static inline void sort2(int *a, int *b){
    const int left = *a;
    const int right = *b;

    if (left <= right){
        return;
    }

    *a = right;
    *b = left;
}

/*
 * Sorts a[0..2] ascending in place.  The elements are copied into locals,
 * ordered there with three compare-exchange steps, and stored back.
 */
void Sort3NaiveC(int *a){
    int v0 = a[0], v1 = a[1], v2 = a[2];

    sort2(&v0, &v1);
    sort2(&v1, &v2);   /* v2 now holds the maximum */
    sort2(&v0, &v1);

    a[0] = v0;
    a[1] = v1;
    a[2] = v2;
}

/*
 * Sorts a[0..3] ascending in place.  The elements are copied into locals,
 * ordered there with a five-comparator network, and stored back.
 */
void Sort4NaiveC(int *a){
    int v0 = a[0], v1 = a[1], v2 = a[2], v3 = a[3];

    /* order each adjacent pair */
    sort2(&v0, &v1);
    sort2(&v2, &v3);

    /* minimum to v0, maximum to v3 */
    sort2(&v0, &v2);
    sort2(&v1, &v3);

    /* order the two middle values */
    sort2(&v1, &v2);

    a[0] = v0;
    a[1] = v1;
    a[2] = v2;
    a[3] = v3;
}

/*
 * Sorts a[0..4] ascending in place.  The elements are copied into locals,
 * ordered there with a nine-comparator network, and stored back.
 */
void Sort5NaiveC(int *a){
    int v0 = a[0], v1 = a[1], v2 = a[2], v3 = a[3], v4 = a[4];

    sort2(&v0, &v3);
    sort2(&v1, &v4);

    sort2(&v0, &v2);
    sort2(&v1, &v3);

    sort2(&v0, &v1);
    sort2(&v2, &v4);

    sort2(&v1, &v2);
    sort2(&v3, &v4);

    sort2(&v2, &v3);

    a[0] = v0;
    a[1] = v1;
    a[2] = v2;
    a[3] = v3;
    a[4] = v4;
}

/*
 * Hand-written three-element sorting network: loads a[0..2] into registers,
 * applies the comparators (0,1), (0,2), (1,2) and stores the registers back.
 *
 * Each comparator is one cmp followed by three cmovl that all test the same
 * flags, so together they either swap the two registers through the scratch
 * register r8d or leave everything unchanged.
 *
 * a: pointer to three consecutive ints that are read and overwritten.
 */
void Sort3NaiveAsm(int *a){
    asm volatile(
        // move a[0], a[1], a[2] to registers eax, ecx, edx
        "mov (%0), %%eax               \n"
        "mov 0x4(%0), %%ecx            \n"
        "mov 0x8(%0), %%edx            \n"

        // sort eax and ecx (a[0] and a[1])
        "cmp %%eax, %%ecx              \n"
        "cmovl %%ecx, %%r8d            \n"
        "cmovl %%eax, %%ecx            \n"
        "cmovl %%r8d, %%eax            \n"

        // sort eax and edx (a[0] and a[2])
        "cmp %%eax, %%edx              \n"
        "cmovl %%edx, %%r8d            \n"
        "cmovl %%eax, %%edx            \n"
        "cmovl %%r8d, %%eax            \n"

        // sort ecx and edx (a[1] and a[2])
        "cmp %%ecx, %%edx              \n"
        "cmovl %%edx, %%r8d            \n"
        "cmovl %%ecx, %%edx            \n"
        "cmovl %%r8d, %%ecx            \n"

        // write back to memory
        "mov %%eax, (%0)               \n"
        "mov %%ecx, 0x4(%0)            \n"
        "mov %%edx, 0x8(%0)            \n"
        : "+r"(a)
        :
        : "eax", "ecx", "edx", "r8d", "memory");
}

/*
 * Hand-written four-element sorting network: loads a[0..3] into registers,
 * applies the comparators (0,1), (2,3), (0,2), (1,3), (1,2) and stores the
 * registers back.
 *
 * Each comparator is one cmp followed by three cmovl that all test the same
 * flags, so together they either swap the two registers through the scratch
 * register r9d or leave everything unchanged.  Every comparator uses the
 * same scratch register.
 *
 * a: pointer to four consecutive ints that are read and overwritten.
 */
void Sort4NaiveAsm(int *a){
    asm volatile(
        // move a[0], a[1], a[2], a[3] to registers eax, ecx, edx, r8d
        "mov (%0), %%eax               \n"
        "mov 0x4(%0), %%ecx            \n"
        "mov 0x8(%0), %%edx            \n"
        "mov 0xc(%0), %%r8d            \n"

        // sort eax and ecx (a[0] and a[1])
        "cmp %%eax, %%ecx              \n"
        "cmovl %%ecx, %%r9d            \n"
        "cmovl %%eax, %%ecx            \n"
        "cmovl %%r9d, %%eax            \n"

        // sort edx and r8d (a[2] and a[3])
        "cmp %%edx, %%r8d              \n"
        "cmovl %%r8d, %%r9d           \n"
        "cmovl %%edx, %%r8d            \n"
        "cmovl %%r9d, %%edx           \n"

        // sort eax and edx (a[0] and a[2])
        "cmp %%eax, %%edx              \n"
        "cmovl %%edx, %%r9d            \n"
        "cmovl %%eax, %%edx            \n"
        "cmovl %%r9d, %%eax            \n"

        // sort ecx and r8d (a[1] and a[3])
        "cmp %%ecx, %%r8d              \n"
        "cmovl %%r8d, %%r9d           \n"
        "cmovl %%ecx, %%r8d            \n"
        "cmovl %%r9d, %%ecx           \n"

        // sort ecx and edx (a[1] and a[2])
        "cmp %%ecx, %%edx              \n"
        "cmovl %%edx, %%r9d            \n"
        "cmovl %%ecx, %%edx            \n"
        "cmovl %%r9d, %%ecx            \n"

        // write back to memory
        "mov %%eax, (%0)               \n"
        "mov %%ecx, 0x4(%0)            \n"
        "mov %%edx, 0x8(%0)            \n"
        "mov %%r8d, 0xc(%0)            \n"

        : "+r"(a)
        :
        : "eax", "ecx", "edx", "r8d", "r9d", "memory"
    );
}

/*
 * Hand-written five-element sorting network: loads a[0..4] into registers,
 * applies the nine comparators (0,3), (1,4), (0,2), (1,3), (0,1), (2,4),
 * (1,2), (3,4), (2,3) and stores the registers back.
 *
 * Each comparator is one cmp followed by three cmovl that all test the same
 * flags, so together they either swap the two registers through the scratch
 * register r10d or leave everything unchanged.  Every comparator uses the
 * same scratch register.
 *
 * a: pointer to five consecutive ints that are read and overwritten.
 */
void Sort5NaiveAsm(int *a){
    asm volatile(
        // move a[0], a[1], a[2], a[3], a[4] to registers eax, ecx, edx, r8d, r9d
        "mov (%0), %%eax               \n"
        "mov 0x4(%0), %%ecx            \n"
        "mov 0x8(%0), %%edx            \n"
        "mov 0xc(%0), %%r8d            \n"
        "mov 0x10(%0), %%r9d           \n"

        // sort eax and r8d (a[0] and a[3])
        "cmp %%eax, %%r8d              \n"
        "cmovl %%r8d, %%r10d           \n"
        "cmovl %%eax, %%r8d            \n"
        "cmovl %%r10d, %%eax           \n"

        // sort ecx and r9d (a[1] and a[4])
        "cmp %%ecx, %%r9d              \n"
        "cmovl %%r9d, %%r10d           \n"
        "cmovl %%ecx, %%r9d            \n"
        "cmovl %%r10d, %%ecx           \n"

        // sort eax and edx (a[0] and a[2])
        "cmp %%eax, %%edx              \n"
        "cmovl %%edx, %%r10d           \n"
        "cmovl %%eax, %%edx            \n"
        "cmovl %%r10d, %%eax           \n"

        // sort ecx and r8d (a[1] and a[3])
        "cmp %%ecx, %%r8d              \n"
        "cmovl %%r8d, %%r10d           \n"
        "cmovl %%ecx, %%r8d            \n"
        "cmovl %%r10d, %%ecx           \n"

        // sort eax and ecx (a[0] and a[1])
        "cmp %%eax, %%ecx              \n"
        "cmovl %%ecx, %%r10d           \n"
        "cmovl %%eax, %%ecx            \n"
        "cmovl %%r10d, %%eax           \n"

        // sort edx and r9d (a[2] and a[4])
        "cmp %%edx, %%r9d              \n"
        "cmovl %%r9d, %%r10d           \n"
        "cmovl %%edx, %%r9d            \n"
        "cmovl %%r10d, %%edx           \n"

        // sort ecx and edx (a[1] and a[2])
        "cmp %%ecx, %%edx              \n"
        "cmovl %%edx, %%r10d           \n"
        "cmovl %%ecx, %%edx            \n"
        "cmovl %%r10d, %%ecx           \n"

        // sort r8d and r9d (a[3] and a[4])
        "cmp %%r8d, %%r9d              \n"
        "cmovl %%r9d, %%r10d           \n"
        "cmovl %%r8d, %%r9d            \n"
        "cmovl %%r10d, %%r8d           \n"

        // sort edx and r8d (a[2] and a[3])
        "cmp %%edx, %%r8d              \n"
        "cmovl %%r8d, %%r10d           \n"
        "cmovl %%edx, %%r8d            \n"
        "cmovl %%r10d, %%edx           \n"

        // write back to memory
        "mov %%eax, (%0)               \n"
        "mov %%ecx, 0x4(%0)            \n"
        "mov %%edx, 0x8(%0)            \n"
        "mov %%r8d, 0xc(%0)            \n"
        "mov %%r9d, 0x10(%0)           \n"

        : "+r"(a)
        :
        : "eax", "ecx", "edx", "r8d", "r9d", "r10d", "memory"
    );
}

/*
 * TEST(METHOD, N) defines
 *
 *     uint64_t test_<METHOD><N>(int num_sorts);
 *
 * The generated function allocates num_sorts arrays of N ints, fills them
 * from rand() after srand(0) (so every method sees the same inputs), sorts
 * all of them with Sort<N><METHOD> inside one timed region, asserts that
 * every array is ascending, and returns the cycle count of the timed region.
 *
 * The process exits with an error message if the input buffer cannot be
 * allocated.
 */
#define TEST(METHOD, N) \
uint64_t test_##METHOD##N(int num_sorts){\
    /* Unsigned: summing rand() values would overflow a signed int (UB). */\
    unsigned int checksum = 0;\
\
    /* Size computed in size_t so the product cannot overflow int. */\
    int *a = malloc((size_t)num_sorts * N * sizeof *a);\
    if (a == NULL && num_sorts > 0){\
        fprintf(stderr, "test_" #METHOD #N ": out of memory\n");\
        exit(EXIT_FAILURE);\
    }\
\
    srand(0);\
    for (int k = 0; k < num_sorts; k++){\
        for (int i = 0; i < N; i++){\
            a[k * N + i] = rand();\
        }\
    }\
\
    uint64_t start = __builtin_readcyclecounter();\
\
    for (int k = 0; k < num_sorts; k++){\
        Sort##N##METHOD(&a[k * N]);\
    }\
\
    uint64_t end = __builtin_readcyclecounter();\
\
    for (int k = 0; k < num_sorts; k++){\
        /* Check correctness */\
        for (int i = 1; i < N; i++){\
            assert(a[k * N + (i - 1)] <= a[k * N + i]);\
        }\
\
        /* Fold every sorted value into the checksum */\
        for (int i = 0; i < N; i++){\
            checksum += (unsigned int)a[k * N + i];\
        }\
    }\
\
    /* The volatile store makes the checksum observable; without it the */\
    /* checksum is dead code and cannot keep the sorts from being removed. */\
    volatile unsigned int checksum_sink = checksum;\
    (void)checksum_sink;\
\
    free(a);\
    return end - start;\
}

TEST(AlphaDev, 3)
TEST(AlphaDev, 4)
TEST(AlphaDev, 5)
TEST(NaiveC, 3)
TEST(NaiveC, 4)
TEST(NaiveC, 5)
TEST(NaiveAsm, 3)
TEST(NaiveAsm, 4)
TEST(NaiveAsm, 5)

/*
 * Arithmetic mean of the first n entries of values.
 * The entries are added front to back and the total is divided by n.
 */
double mean(const double *values, int n){
    double total = 0;
    int remaining = n;

    while (remaining-- > 0){
        total += *values++;
    }

    return total / n;
}

/*
 * Standard deviation of the first n entries of values, computed as the
 * square root of the mean squared deviation from mean(values, n)
 * (the divisor is n, not n - 1).
 */
double std(const double *values, int n){
    const double center = mean(values, n);
    double squared_total = 0;

    for (int idx = 0; idx < n; idx++){
        squared_total += (values[idx] - center) * (values[idx] - center);
    }

    const double variance = squared_total / n;
    return sqrt(variance);
}

/*
 * TEST_N(N_SORT) defines
 *
 *     void test_<N_SORT>(int num_sorts, int num_repeats);
 *
 * The generated function runs test_AlphaDev<N_SORT>, test_NaiveAsm<N_SORT>
 * and test_NaiveC<N_SORT> num_repeats times each, converts every measurement
 * to cycles per sorted array, and prints mean +- standard deviation for each
 * method together with the relative speed of AlphaDev against both naive
 * variants.
 *
 * The benchmark function names are built from N_SORT, so each size measures
 * its own sort functions.
 */
#define TEST_N(N_SORT) \
void test_##N_SORT(int num_sorts, int num_repeats){\
    double percent;\
    double cycles_alphadev[num_repeats];\
    double cycles_naiveasm[num_repeats];\
    double cycles_naive[num_repeats];\
\
    for (int k = 0; k < num_repeats; k++){\
        cycles_alphadev[k] = test_AlphaDev##N_SORT(num_sorts) / (double)num_sorts;\
        cycles_naiveasm[k] = test_NaiveAsm##N_SORT(num_sorts) / (double)num_sorts;\
        cycles_naive[k] = test_NaiveC##N_SORT(num_sorts) / (double)num_sorts;\
    }\
\
    /* Each mean is needed several times below; compute it once. */\
    const double mean_alphadev = mean(cycles_alphadev, num_repeats);\
    const double mean_naiveasm = mean(cycles_naiveasm, num_repeats);\
    const double mean_naive = mean(cycles_naive, num_repeats);\
\
    printf("N = %d\n", N_SORT);\
    printf("AlphaDev: %.2f ± %.2f cycles\n", mean_alphadev, std(cycles_alphadev, num_repeats));\
    printf("NaiveAsm: %.2f ± %.2f cycles\n", mean_naiveasm, std(cycles_naiveasm, num_repeats));\
    printf("NaiveC  : %.2f ± %.2f cycles\n", mean_naive, std(cycles_naive, num_repeats));\
\
    percent = mean_naiveasm * 100.0 / mean_alphadev - 100.0;\
    printf("AlphaDev is %.2f %% faster than a naive sorting network written in assembly.\n", percent);\
    percent = mean_naive * 100.0 / mean_alphadev - 100.0;\
    printf("AlphaDev is %.2f %% faster than a naive sorting network written in C.\n\n", percent);\
}

TEST_N(3)
TEST_N(4)
TEST_N(5)

int main(){
    test_3(num_sorts, num_repeats);
    test_4(num_sorts, num_repeats);
    test_5(num_sorts, num_repeats);

    return 0;
}