flintlib / flint

FLINT (Fast Library for Number Theory)
http://www.flintlib.org
GNU Lesser General Public License v3.0

Speed up basecase matrix multiplication, polynomial multiplication, etc. #1508

Open fredrik-johansson opened 1 year ago

fredrik-johansson commented 1 year ago

The basecases for functions like nmod_mat_mul, fmpz_mat_mul, nmod_poly_mul and fmpz_poly_mul can be sped up significantly for half-word-size entries (and maybe bigger entries with more effort) by writing vectorization-friendly code and compiling with -march=native -O3.

Some naive code just for illustration:

#include "nmod.h"
#include "nmod_mat.h"
#include "profiler.h"
#include "cblas.h"

#define BITS 20
#define N 64

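/* res = a * b mod mod.n for N x N matrices stored as flat arrays.
   b is expected in transposed layout (b[N * j + k] holds entry (k, j)),
   so both operands are read with unit stride in the inner loop.
   The four independent accumulators are only reduced after the loop. */
void
nmod_mmul_ulong(mp_ptr res, mp_srcptr a, mp_srcptr b, nmod_t mod)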
{
    slong i, j, k;
    ulong s0, s1, s2, s3;

    for (i = 0; i < N; i++)
    {
        for (j = 0; j < N; j++)
        {
            s0 = s1 = s2 = s3 = 0;

            for (k = 0; k < N; k += 4)
            {
                s0 += a[N * i + k + 0] * b[N * j + k + 0];
                s1 += a[N * i + k + 1] * b[N * j + k + 1];
                s2 += a[N * i + k + 2] * b[N * j + k + 2];
                s3 += a[N * i + k + 3] * b[N * j + k + 3];
            }

#if 0
            s0 += s1;
            s2 += s3;
            s0 += s2;

            res[N * i + j] = nmod_set_ui(s0, mod);
#else
            /* if we want more bits */

            add_ssaaaa(s1, s0, 0, s0, 0, s1);
            add_ssaaaa(s3, s2, 0, s2, 0, s3);
            add_ssaaaa(s1, s0, s1, s0, s3, s2);

            NMOD2_RED2(s0, s1, s0, mod);
            res[N * i + j] = s0;
#endif
        }
    }
}

void
nmod_mmul_double(mp_ptr res, const double * a, const double * b, nmod_t mod)
{
    slong i, j, k;
    double s0, s1, s2, s3;

    for (i = 0; i < N; i++)
    {
        for (j = 0; j < N; j++)
        {
            s0 = s1 = s2 = s3 = 0;

            for (k = 0; k < N; k += 4)
            {
                s0 += a[N * i + k + 0] * b[N * j + k + 0];
                s1 += a[N * i + k + 1] * b[N * j + k + 1];
                s2 += a[N * i + k + 2] * b[N * j + k + 2];
                s3 += a[N * i + k + 3] * b[N * j + k + 3];
            }

            s0 += s1;
            s2 += s3;
            s0 += s2;

            res[N * i + j] = nmod_set_ui((ulong) s0, mod);
        }
    }
}

int main()
{
    mp_ptr A, B, C;
    double * D, *E, *F;
    ulong p = (UWORD(1) << BITS) - 1;

    nmod_t mod;
    nmod_init(&mod, p);

    nmod_mat_t X, Y, Z;

    /* %wd expects an slong, so cast the int-valued macros */
    flint_printf("bits = %wd, N = %wd\n\n", (slong) BITS, (slong) N);

    nmod_mat_init(X, N, N, mod.n);
    nmod_mat_init(Y, N, N, mod.n);
    nmod_mat_init(Z, N, N, mod.n);

    A = flint_malloc(sizeof(mp_limb_t) * N * N);
    B = flint_malloc(sizeof(mp_limb_t) * N * N);
    C = flint_malloc(sizeof(mp_limb_t) * N * N);
    /* D, E, F hold doubles, so size them with sizeof(double) */
    D = flint_malloc(sizeof(double) * N * N);
    E = flint_malloc(sizeof(double) * N * N);
    F = flint_malloc(sizeof(double) * N * N);

    slong i, j;

    flint_rand_t state;
    flint_randinit(state);

    nmod_mat_randfull(X, state);
    nmod_mat_randfull(Y, state);
    nmod_mat_randfull(Z, state);

    for (i = 0; i < N; i++)
        for (j = 0; j < N; j++)
        {
            A[i * N + j] = X->rows[i][j];
            B[j * N + i] = Y->rows[i][j];

            D[i * N + j] = X->rows[i][j];
            E[j * N + i] = Y->rows[i][j];
        }

    flint_printf("\nnmod_mat_mul:\n");
    TIMEIT_START
    nmod_mat_mul(Z, X, Y);
    TIMEIT_STOP

    flint_printf("\nulong:\n");
    TIMEIT_START
    nmod_mmul_ulong(C, A, B, mod);
    TIMEIT_STOP

    for (i = 0; i < N; i++)
        for (j = 0; j < N; j++)
            if (Z->rows[i][j] != C[i * N + j])
                { flint_printf("WRONG RESULT\n"); goto cl0; }
    cl0:

    flint_printf("\ndouble:\n");
    TIMEIT_START
    nmod_mmul_double(C, D, E, mod);
    TIMEIT_STOP

    for (i = 0; i < N; i++)
        for (j = 0; j < N; j++)
            if (Z->rows[i][j] != C[i * N + j])
                { flint_printf("WRONG RESULT\n"); goto cl1; }
    cl1:

    flint_printf("\nBLAS:\n");
    TIMEIT_START
    cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans, N, N, N, 1.0, D, N, E, N, 0.0, F, N);
    for (i = 0; i < N; i++)
        for (j = 0; j < N; j++)
            C[N * i + j] = nmod_set_ui((ulong) F[i * N + j], mod);
    TIMEIT_STOP

    for (i = 0; i < N; i++)
        for (j = 0; j < N; j++)
            if (Z->rows[i][j] != C[i * N + j])
                { flint_printf("WRONG RESULT\n"); goto cl2; }
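    cl2:

    /* release everything allocated above */
    flint_free(A);
    flint_free(B);
    flint_free(C);
    flint_free(D);
    flint_free(E);
    flint_free(F);
    nmod_mat_clear(X);
    nmod_mat_clear(Y);
    nmod_mat_clear(Z);
    flint_randclear(state);

    return 0;
}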

Sample results on my machine: at least in some cases, we can get close to a factor of 2 speedup with vectorized ulong math, and a factor of 4 speedup with double. This ignores conversion and transposition costs, which may not be negligible.

bits = 20, N = 64

nmod_mat_mul:
cpu/wall(s): 0.000104 0.000105

ulong:
cpu/wall(s): 7.28e-05 7.29e-05

double:
cpu/wall(s): 2.48e-05 2.47e-05

BLAS:
cpu/wall(s): 1.89e-05 1.88e-05

bits = 30, N = 64

nmod_mat_mul:
cpu/wall(s): 0.000129 0.000129

ulong:
cpu/wall(s): 7.23e-05 7.23e-05

edgarcosta commented 1 year ago

I would expect the difference to be minimal without unrolling the for loop on k, as I would expect the compiler to already vectorize that without a problem. Or am I missing something?
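For reference, the non-unrolled loop in question could look roughly like the sketch below. It is written with plain C types so it stands alone; `mmul_plain` and its parameters are hypothetical names, not FLINT functions, and nothing here has been timed. One point that may matter for the comparison: unsigned integer addition can be reordered freely by the compiler, whereas a floating-point sum generally cannot be reassociated without options such as `-ffast-math`, so a single-accumulator version of the double kernel may vectorize less readily than the ulong one.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* Plain basecase with a single accumulator and no manual unrolling:
   res = a * b mod p for n x n matrices stored as flat arrays.
   bt holds b transposed, i.e. bt[n * j + k] is entry (k, j) of b.
   Hypothetical helper for illustration only, not part of FLINT.
   Precondition: all entries are reduced mod p and n * (p - 1)^2 < 2^64,
   so the accumulator cannot overflow before the final reduction. */
void
mmul_plain(uint64_t * res, const uint64_t * a, const uint64_t * bt,
           size_t n, uint64_t p)
{
    size_t i, j, k;
    uint64_t m;

    assert(p >= 2 && p <= UINT32_MAX);
    m = (p - 1) * (p - 1);              /* largest possible product */
    assert(n == 0 || m <= UINT64_MAX / n);

    for (i = 0; i < n; i++)
    {
        for (j = 0; j < n; j++)
        {
            uint64_t s = 0;

            /* dot product of row i of a with column j of b */
            for (k = 0; k < n; k++)
                s += a[n * i + k] * bt[n * j + k];

            res[n * i + j] = s % p;     /* one reduction per entry */
        }
    }
}
```

Timing this against the four-accumulator version with the same compiler flags would show how much of the gain comes from the manual unrolling and how much from auto-vectorization.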

fredrik-johansson commented 1 year ago

This is just an example; there are lots of permutations of manual unrolling and blocking worth trying. It's not at all clear a priori what works best, and I doubt that the compiler knows either.

Unrolling into separate sums is useful for moduli approaching 32 bits (ulong) or 26 bits (double) because we can support larger N without requiring modular reduction or spilling over to a double word. Actually my original motivation for this test was to investigate various strategies for moduli around 30 bits.
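To put numbers on this, here is a small sketch that computes how many products of `bits`-bit entries fit in one accumulator before an intermediate reduction is needed. The helper name `max_terms` is hypothetical, and the model is an assumption: a 64-bit ulong is taken to hold integers below 2^64, and a double is treated conservatively as exact for integers below 2^53.

```c
#include <assert.h>
#include <stdint.h>

/* Largest t such that t products a * b, with 0 <= a, b < 2^bits, can be
   summed without exceeding 2^acc_bits - 1.
   Use acc_bits = 64 for a 64-bit ulong accumulator and acc_bits = 53 for
   a double accumulator (integers below 2^53 are represented exactly).
   Hypothetical helper for illustration only, not part of FLINT. */
uint64_t
max_terms(unsigned bits, unsigned acc_bits)
{
    uint64_t m, limit;

    /* a single product must already fit in the accumulator */
    assert(bits >= 1 && acc_bits <= 64 && 2 * bits <= acc_bits);

    m = (UINT64_C(1) << bits) - 1;      /* largest entry */
    limit = (acc_bits == 64) ? UINT64_MAX
                             : (UINT64_C(1) << acc_bits) - 1;

    return limit / (m * m);             /* m * m fits since bits <= 32 */
}
```

Under these assumptions, `max_terms(30, 64)` is 16, so four separate ulong accumulators cover an inner dimension of 4 * 16 = 64, which is exactly the N used in the example above. `max_terms(32, 64)` is 1 and `max_terms(26, 53)` is 2, which is why moduli near those sizes leave almost no room for delayed reduction in a single accumulator.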

vneiger commented 1 year ago

Here are benchmarks for matrix multiplication and Gaussian elimination on word-size prime fields, using FLINT (with / without BLAS), using NTL, and using FFLAS-FFPACK (several implementations depending on the size of the field; however, I did not call the "multi-precision" variants, so this is limited to about 30-bit primes).

For matrix multiplication I included square and rectangular cases, including matrix-vector and vector-matrix products. For Gaussian elimination I also used several dimension profiles, as well as a varying rank to see how good the implementations are in terms of rank sensitivity.

All of this was on a computer with AVX-512; see the beginning of the files for more details on the machine. I'll do the same measurements on a laptop without AVX-512 soon.

This is raw data, which is not so readable, but we can already see a few trends. I'll try to extract some plots out of this, to get a better view of the main points where the different libraries' performances diverge. Examples of visible trends:

Warning: there is a small typo in the matmul results: the columns should be labelled rdim idim cdim (row, inner, column dimensions) instead of rank rdim cdim.

matmul_lu_withavx512.zip

vneiger commented 1 year ago

Addition: similar benchmarks, but on a processor without AVX-512. At first sight, some of the observations above seem to remain valid. Surprisingly, this seems to indicate that (on this machine) compiling FLINT with BLAS does not provide much advantage for matrix multiplication when the field prime is beyond 20 bits.

matmul_lu_noavx512.zip