flintlib / flint

FLINT (Fast Library for Number Theory)
http://www.flintlib.org
GNU Lesser General Public License v3.0

Slow modular arithmetic #1010

Open fredrik-johansson opened 2 years ago

fredrik-johansson commented 2 years ago

Here is a simple function, bernsum_powg, taken from David Harvey's bernmm library (which uses NTL), followed by two FLINT versions using the n_precomp and n_preinvert interfaces.

#include <NTL/ZZ.h>
#include "flint/ulong_extras.h"
#include "flint/profiler.h"

NTL_CLIENT;

long PowerMod(long a, long ee, long n, mulmod_t ninv)
{
   long x, y;

   unsigned long e;

   if (ee < 0)
      e = - ((unsigned long) ee);
   else
      e = ee;

   x = 1;
   y = a;
   while (e) {
      if (e & 1) x = MulMod(x, y, n, ninv);
      y = MulMod(y, y, n, ninv);
      e = e >> 1;
   }

   if (ee < 0) x = InvMod(x, n);

   return x;
}

long bernsum_powg(long p, mulmod_t pinv, long k, long g)
{
   long half_gm1 = (g + ((g & 1) ? 0 : p) - 1) / 2;    // (g-1)/2 mod p
   long g_to_jm1 = 1;
   long g_to_km1 = PowerMod(g, k-1, p, pinv);
   long g_to_km1_to_j = g_to_km1;
   long sum = 0;
   muldivrem_t g_pinv = PrepMulDivRem(g, p);
   mulmod_precon_t g_to_km1_pinv = PrepMulModPrecon(g_to_km1, p, pinv);

   for (long j = 1; j <= (p-1)/2; j++)
   {
      // at this point,
      //    g_to_jm1 holds g^(j-1) mod p
      //    g_to_km1_to_j holds (g^(k-1))^j mod p

      // update g_to_jm1 and compute q = (g*(g^(j-1) mod p) - (g^j mod p)) / p
      long q;
      g_to_jm1 = MulDivRem(q, g_to_jm1, g, p, g_pinv);

      // compute h = -h_g(g^j) = q - (g-1)/2
      long h = SubMod(q, half_gm1, p);

      // add h_g(g^j) * (g^(k-1))^j to running total
      sum = SubMod(sum, MulMod(h, g_to_km1_to_j, p, pinv), p);

      // update g_to_km1_to_j
      g_to_km1_to_j = MulModPrecon(g_to_km1_to_j, g_to_km1, p, g_to_km1_pinv);
   }

   return sum;
}

long bernsum_powg_flint2(ulong p, double pinv, ulong k, ulong g)
{
   ulong half_gm1 = (g + ((g & 1) ? 0 : p) - 1) / 2;    // (g-1)/2 mod p
   ulong g_to_jm1 = 1;
   ulong g_to_km1 = n_powmod_precomp(g, k-1, p, pinv);
   ulong g_to_km1_to_j = g_to_km1;
   ulong sum = 0;
   ulong g_to_km1_pinv = n_mulmod_precomp_shoup(g_to_km1, p);

   for (ulong j = 1; j <= (p-1)/2; j++)
   {
      ulong q;

      // shortcut: g_to_jm1 * g overflows unless g * p fits in a ulong
      g_to_jm1 = n_divrem2_precomp(&q, g_to_jm1 * g, p, pinv);
      ulong h = n_submod(q, half_gm1, p);
      sum = n_submod(sum, n_mulmod_precomp(h, g_to_km1_to_j, p, pinv), p);
      g_to_km1_to_j = n_mulmod_shoup(g_to_km1, g_to_km1_to_j, g_to_km1_pinv, p);
   }

   return sum;
}

long bernsum_powg_flint2b(ulong p, ulong pinv, ulong k, ulong g)
{
   ulong half_gm1 = (g + ((g & 1) ? 0 : p) - 1) / 2;    // (g-1)/2 mod p
   ulong g_to_jm1 = 1;
   ulong g_to_km1 = n_powmod2_preinv(g, k-1, p, pinv);
   ulong g_to_km1_to_j = g_to_km1;
   ulong sum = 0;

   for (ulong j = 1; j <= (p-1)/2; j++)
   {
      ulong q;

      // shortcut: g_to_jm1 * g overflows unless g * p fits in a ulong
      g_to_jm1 = n_divrem2_preinv(&q, g_to_jm1 * g, p, pinv);
      ulong h = n_submod(q, half_gm1, p);
      sum = n_submod(sum, n_mulmod2_preinv(h, g_to_km1_to_j, p, pinv), p);
      g_to_km1_to_j = n_mulmod2_preinv(g_to_km1_to_j, g_to_km1, p, pinv);
   }

   return sum;
}

int main()
{
    long s;

    s = 0;
    TIMEIT_START
    s |= bernsum_powg(10007, PrepMulMod(10007), 9406, 5);
    TIMEIT_STOP
    printf("%ld\n", s);

    s = 0;
    TIMEIT_START
    s |= bernsum_powg_flint2(10007, n_precompute_inverse(10007), 9406, 5);
    TIMEIT_STOP
    printf("%ld\n", s);

    s = 0;
    TIMEIT_START
    s |= bernsum_powg_flint2b(10007, n_preinvert_limb(10007), 9406, 5);
    TIMEIT_STOP
    printf("%ld\n", s);
}

Timings on my machine:

cpu/wall(s): 4.19e-05 4.19e-05
5444
cpu/wall(s): 6.7e-05 6.7e-05
5444
cpu/wall(s): 0.000131 0.000131
5444

So our n_precomp arithmetic is 1.6x slower than NTL, and our n_preinvert arithmetic is 3x slower. I even cheated here: NTL has a MulDivRem function, which we lack, so I used a plain multiplication (g_to_jm1 * g), which of course overflows if p is large.
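For reference, one way to get an overflow-free muldivrem for a fixed multiplier is Shoup's precomputed-quotient trick, which yields both the quotient and the remainder. The sketch below is a minimal illustration assuming a 64-bit word, p <= 2^63, a fixed multiplier w < p, and compiler support for unsigned __int128 (a GCC/Clang extension). The names muldivrem_precomp and muldivrem_shoup are hypothetical, not FLINT or NTL API, and we have not checked whether NTL's MulDivRem is implemented this way.

```c
#include <assert.h>
#include <stdint.h>

typedef unsigned __int128 u128;  /* GCC/Clang extension, assumed available */

/* Hypothetical helper: precompute w' = floor(w * 2^64 / p) for a fixed
   multiplier w < p. */
static uint64_t muldivrem_precomp(uint64_t w, uint64_t p)
{
    assert(w < p && p <= (UINT64_C(1) << 63));
    return (uint64_t) (((u128) w << 64) / p);
}

/* Hypothetical helper: return a*w mod p and store floor(a*w / p) in *q,
   for any 64-bit a and the precomputed w_pre of a fixed w < p. */
static uint64_t muldivrem_shoup(uint64_t *q, uint64_t a, uint64_t w,
                                uint64_t p, uint64_t w_pre)
{
    /* Estimated quotient: equal to the true quotient or one too small. */
    uint64_t qe = (uint64_t) (((u128) a * w_pre) >> 64);

    /* The true value of a*w - qe*p lies in [0, 2p), so wrapping 64-bit
       arithmetic computes it exactly when p <= 2^63. */
    uint64_t r = a * w - qe * p;

    if (r >= p)
    {
        r -= p;
        qe++;
    }
    *q = qe;
    return r;
}
```

In bernsum_powg_flint2 one could hoist g_pre = muldivrem_precomp(g, p) out of the loop and replace the n_divrem2_precomp call with g_to_jm1 = muldivrem_shoup(&q, g_to_jm1, g, p, g_pre), which removes the overflow for large p. The precomputed value is the same floor(w * 2^FLINT_BITS / p) that n_mulmod_precomp_shoup is documented to return; whether this closes the gap to NTL has not been measured here.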

wbhart commented 2 years ago

Very surprising. I can't imagine for the life of me what we could have overlooked there. So much care went into making that efficient. Is NTL using doubles?

wbhart commented 2 years ago

We could do the following, where they make sense (conversion costs/representation need to be considered):

fredrik-johansson commented 2 years ago

If n is odd and small enough and we use a balanced representation -n/2...n/2, this does a correctly reduced multiplication:

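double dmod_mul(double a, double b, double n, double ninv)
{
    double magic = 6755399441055744.0;  // 1.5 * 2^52
    double r = a * b;                    // exact while |a*b| < 2^53

    // adding and subtracting magic rounds r * ninv to the nearest integer
    return r - ((r * ninv + magic) - magic) * n;
}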

This should be good for doing lots of multiplications in parallel with SIMD. How quickly can we add in this representation?
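A quick empirical check of dmod_mul, plus one possible answer to the addition question, is sketched below. A rough argument for when dmod_mul is exact: since n is odd, r/n is never exactly a half-integer and stays at least 1/(2n) away from one, while the rounding error in r * ninv is about (n/4) * 2^-52, so the rounding should pick the right quotient roughly when n^2 < 2^53. The sketch assumes IEEE-754 doubles with round-to-nearest, no x87 excess precision, and no -ffast-math (which could fold (x + magic) - magic into x). dmod_add, balanced_mul_ref and dmod_mul_errors are hypothetical names, and the grid check only samples inputs; it proves nothing.

```c
#include <assert.h>

/* dmod_mul as posted above, with comments. */
static double dmod_mul(double a, double b, double n, double ninv)
{
    double magic = 6755399441055744.0;  /* 1.5 * 2^52 */
    double r = a * b;                    /* exact while |a*b| < 2^53 */

    /* (x + magic) - magic rounds x to the nearest integer for |x| < 2^51 */
    return r - ((r * ninv + magic) - magic) * n;
}

/* Hypothetical balanced addition for odd n: a, b in [-h, h] with
   h = (n-1)/2, so a + b is exact and at most one correction is needed. */
static double dmod_add(double a, double b, double n)
{
    double h = (n - 1) / 2;
    double s = a + b;

    assert(a >= -h && a <= h && b >= -h && b <= h);
    s = (s > h) ? s - n : s;
    s = (s < -h) ? s + n : s;
    return s;
}

/* Exact reference: a*b reduced into [-(n-1)/2, (n-1)/2] for odd n. */
static long long balanced_mul_ref(long long a, long long b, long long n)
{
    long long r = (a * b) % n;  /* C99: the remainder has the sign of a*b */

    if (r > n / 2) r -= n;
    if (r < -(n / 2)) r += n;
    return r;
}

/* Count disagreements between dmod_mul and the reference on a grid of
   balanced residues spaced `step` apart. */
static long long dmod_mul_errors(long long n, long long step)
{
    double dn = (double) n, ninv = 1.0 / dn;
    long long errors = 0;

    assert(n % 2 == 1);
    for (long long a = -(n / 2); a <= n / 2; a += step)
        for (long long b = -(n / 2); b <= n / 2; b += step)
            if ((long long) dmod_mul((double) a, (double) b, dn, ninv)
                    != balanced_mul_ref(a, b, n))
                errors++;
    return errors;
}
```

In this form, addition costs one add, two compares and two selects, which should map to SIMD compare-and-blend instructions; whether that beats the unsigned representation overall is untested here.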

fredrik-johansson commented 2 years ago

Another idea: in multimodular algorithms, we generally use primes of the form 2^n + c or 2^n - c where c is small. For multi-word moduli, this can certainly be exploited, but what about nmods?
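For a single-word modulus p = 2^k - c, the identity 2^k ≡ c (mod p) lets a double-word product be reduced with shifts, masks and small multiplications instead of a division or a precomputed inverse. The sketch below assumes k <= 63, a small c with c*c + c < p, reduced inputs, and unsigned __int128 support; the function name is hypothetical, and we have not benchmarked it against nmod_mul.

```c
#include <assert.h>
#include <stdint.h>

typedef unsigned __int128 u128;  /* GCC/Clang extension, assumed available */

/* Hypothetical sketch: a*b mod p for p = 2^k - c and a, b < p.
   Assumes 2 <= k <= 63 and a small c >= 1 with c*c + c < p. */
static uint64_t mulmod_pseudo_mersenne(uint64_t a, uint64_t b,
                                       unsigned k, uint64_t c)
{
    uint64_t p = (UINT64_C(1) << k) - c;
    uint64_t mask = (UINT64_C(1) << k) - 1;
    u128 x;

    assert(a < p && b < p);
    x = (u128) a * b;                      /* x < p^2 < 2^(2k) */

    /* Fold once: x = hi*2^k + lo == hi*c + lo, now < (c + 1) * 2^k. */
    x = (x >> k) * c + (x & mask);

    /* Fold again: the high part is at most c, so x < 2^k + c*c. */
    x = (x >> k) * c + (x & mask);

    /* 2^k + c*c = p + c + c*c < 2p, so one subtraction finishes. */
    if (x >= p)
        x -= p;
    return (uint64_t) x;
}
```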

fredrik-johansson commented 2 years ago

I have at least part of the answer: n_mulmod_preinv requires the inputs to be reduced and is fast, while n_mulmod2_preinv does not require reduced inputs and is slow.

Our nmod_mul is stupidly doing the same thing as n_mulmod2_preinv. We should basically just change it to do an n_mulmod_preinv instead; this is 2x faster on my machine.

We can also replace many other uses of n_mulmod2_preinv with n_mulmod_preinv throughout Flint.

Some shifts can be avoided when the modulus has exactly FLINT_BITS bits; maybe this is worth optimizing for in various places.
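To illustrate why reduced inputs allow a fast path, here is a sketch of mulmod via the Möller-Granlund 2-by-1 division with a precomputed inverse of the normalized modulus, with the normalization shift cached in a context struct. The division requires the high word of the product to be below the normalized modulus, which a, b < n guarantees; unreduced inputs would need an extra reduction first. We believe FLINT's preinv routines rest on this kind of division, but this is not FLINT's code: mod_ctx, divrem_2by1 and mulmod_preinv_sketch are hypothetical names, and the sketch assumes a 64-bit word, unsigned __int128 and __builtin_clzll (GCC/Clang).

```c
#include <assert.h>
#include <stdint.h>

typedef unsigned __int128 u128;  /* GCC/Clang extension, assumed available */

/* Hypothetical context in the spirit of nmod_t: modulus, normalization
   shift and inverse of the normalized modulus, computed once. */
typedef struct { uint64_t n, ninv; unsigned norm; } mod_ctx;

static mod_ctx mod_ctx_init(uint64_t n)
{
    mod_ctx ctx;

    assert(n != 0);
    ctx.n = n;
    ctx.norm = (unsigned) __builtin_clzll(n);
    /* v = floor((2^128 - 1) / d) - 2^64 for d = n << norm; the quotient
       lies in [2^64, 2^65), so truncating to 64 bits subtracts 2^64. */
    ctx.ninv = (uint64_t) (~(u128) 0 / (n << ctx.norm));
    return ctx;
}

/* Moller-Granlund 2-by-1 division of u1*2^64 + u0 by a normalized d
   (top bit set) with u1 < d: stores the quotient in *q, returns the
   remainder. */
static uint64_t divrem_2by1(uint64_t *q, uint64_t u1, uint64_t u0,
                            uint64_t d, uint64_t v)
{
    u128 t;
    uint64_t q1, q0, r;

    assert((d >> 63) && u1 < d);
    t = (u128) v * u1 + (((u128) u1 << 64) | u0);
    q1 = (uint64_t) (t >> 64) + 1;
    q0 = (uint64_t) t;
    r = u0 - q1 * d;                   /* wraps mod 2^64 by design */
    if (r > q0) { q1--; r += d; }
    if (r >= d) { q1++; r -= d; }      /* rarely taken */
    *q = q1;
    return r;
}

/* a*b mod n for reduced inputs a, b < n. Shifting b left by norm scales
   both the product and the modulus by 2^norm, keeping the high word
   below the normalized modulus. */
static uint64_t mulmod_preinv_sketch(uint64_t a, uint64_t b, mod_ctx ctx)
{
    uint64_t q;
    u128 t;

    assert(a < ctx.n && b < ctx.n);
    t = (u128) a * (b << ctx.norm);
    return divrem_2by1(&q, (uint64_t) (t >> 64), (uint64_t) t,
                       ctx.n << ctx.norm, ctx.ninv) >> ctx.norm;
}
```

When n has exactly 64 bits, norm is 0 and every shift is a no-op, which is the special case mentioned in the comment.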

fredrik-johansson commented 2 years ago

Ditto for nmod_addmul / NMOD_ADDMUL.

tthsqe12 commented 2 years ago

> If n is odd and small enough and we use a balanced representation -n/2...n/2, this does a correctly reduced multiplication

How small is small enough? Is this assuming no fmadd/fmsub? What about with fmadd/fmsub?

fredrik-johansson commented 2 years ago

> If n is odd and small enough and we use a balanced representation -n/2...n/2, this does a correctly reduced multiplication

> How small is small enough? Is this assuming no fmadd/fmsub? What about with fmadd/fmsub?

Up to sqrt(2^53) I guess, but I did not check or prove this. You might want to design an entirely different algorithm around fma.

fredrik-johansson commented 9 months ago

FWIW, flint does better if one uses nmod_mul and n_mulmod_shoup.

After that, the remaining difference seems to be due to n_divrem2_preinv being much slower than MulDivRem.

fredrik-johansson commented 4 months ago

Several functions for modular arithmetic, such as n_mulmod2_preinv, n_mod2_preinv, n_ll_mod_preinv and n_lll_mod_preinv, don't take a norm as input and therefore need a flint_clz operation, which is redundant in situations where we already have an nmod_t containing this data. There are probably also many places (though not all) where these operations should be inlined. Replacing them with nmod_mul, NMOD2_RED2, etc. would be an improvement.

We should think about ways to redesign these interfaces so that they are more obvious.

albinahlback commented 4 months ago

It would be nice to separate the functions that need normalization from those that do not. I'm not sure how to make that user-friendly, though.

edgarcosta commented 4 months ago

Another point of comparison would be https://math.mit.edu/~drew/ffpoly.html. I think @andrewvsutherland (the author) did a comparison at some point.

AndrewVSutherland commented 4 months ago

Better/simpler to test against b32.h (fast implementation of Barrett modular arithmetic for 32-bit integers) and m64.h (fast implementation of 64-bit Montgomery arithmetic); the latter is what ffpoly uses.
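For comparison with the Montgomery approach mentioned above, here is a minimal textbook sketch of 64-bit Montgomery multiplication (REDC) with R = 2^64. It is not taken from m64.h; mont_ctx, mont_init, redc, to_mont, mont_mul and from_mont are hypothetical names, and the sketch assumes an odd modulus and unsigned __int128 support.

```c
#include <assert.h>
#include <stdint.h>

typedef unsigned __int128 u128;  /* GCC/Clang extension, assumed available */

/* Hypothetical Montgomery context for an odd modulus n, with R = 2^64. */
typedef struct { uint64_t n, ninv, r2; } mont_ctx;

static mont_ctx mont_init(uint64_t n)
{
    mont_ctx ctx;
    uint64_t x = n, r1;

    assert(n & 1);
    /* Newton iteration for n^-1 mod 2^64: x = n is already correct to
       3 bits for odd n, and each step doubles that (3 -> 96 >= 64). */
    for (int i = 0; i < 5; i++)
        x *= 2 - n * x;
    ctx.n = n;
    ctx.ninv = x;
    r1 = (UINT64_C(0) - n) % n;                 /* 2^64 mod n  */
    ctx.r2 = (uint64_t) ((u128) r1 * r1 % n);   /* 2^128 mod n */
    return ctx;
}

/* REDC: for T < n * 2^64, return T * 2^-64 mod n in [0, n). */
static uint64_t redc(u128 T, mont_ctx ctx)
{
    uint64_t m = (uint64_t) T * ctx.ninv;       /* m*n == T (mod 2^64) */
    uint64_t mn_hi = (uint64_t) (((u128) m * ctx.n) >> 64);
    uint64_t t_hi = (uint64_t) (T >> 64);

    /* T - m*n is an exact multiple of 2^64; its quotient by 2^64 is
       t_hi - mn_hi, which lies in (-n, n), so add n back if negative. */
    return t_hi >= mn_hi ? t_hi - mn_hi : t_hi - mn_hi + ctx.n;
}

static uint64_t to_mont(uint64_t a, mont_ctx ctx)    /* a < n -> a*R mod n */
{
    return redc((u128) a * ctx.r2, ctx);
}

static uint64_t mont_mul(uint64_t x, uint64_t y, mont_ctx ctx)  /* x*y/R */
{
    return redc((u128) x * y, ctx);
}

static uint64_t from_mont(uint64_t x, mont_ctx ctx)  /* x*R^-1 mod n */
{
    return redc(x, ctx);
}
```

Values can stay in Montgomery form across a loop such as the one in bernsum_powg, so to_mont and from_mont are paid once rather than per multiplication.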