flintlib / flint

FLINT (Fast Library for Number Theory)
http://www.flintlib.org
GNU Lesser General Public License v3.0

Faster NMOD_RED using different precomputation? #2061

Open vneiger opened 2 months ago

vneiger commented 2 months ago

NMOD_RED reduces a single limb modulo some given n. It currently calls NMOD_RED2, which reduces a two-limb integer (a_hi, a_lo) mod n, where a_hi must be < n (when called from NMOD_RED, a_hi is zero).

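Another way to reduce a single limb a mod n is to use the approach of n_mulmod_shoup to compute 1*a mod n, that is, with the fixed operand 1: compute one_precomp = n_mulmod_precomp_shoup(1, n) once, then call n_mulmod_shoup(1, a, one_precomp, n).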

The precomputation is... a precomputation. It only depends on n, not on a.

The call to n_mulmod_shoup simply does one umul_ppmm (a times the precomputed value) to get the high word p_hi, which is either floor(a/n) or floor(a/n)-1; then one single-word multiplication and subtraction, res = a - p_hi * n; and finally a single correction step (if res >= n, return res - n, else return res).
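To make these three steps concrete, here is a minimal standalone sketch that does not depend on FLINT. Assumptions: umul_ppmm and n_mulmod_precomp_shoup are replaced by unsigned __int128 arithmetic (a GCC/Clang extension on 64-bit targets), the names red1_precomp and red1_shoup are hypothetical, and n >= 2 (for n = 1 the value floor(2**64 / n) does not fit in a limb).

```c
#include <stdint.h>

typedef unsigned __int128 u128;  /* GCC/Clang extension, 64-bit targets */

/* one_precomp = floor(2^64 / n); depends on n only. Assumes n >= 2. */
static uint64_t red1_precomp(uint64_t n)
{
    return (uint64_t) (((u128) 1 << 64) / n);
}

/* a mod n, computed as the Shoup product 1 * a (fixed operand 1). */
static uint64_t red1_shoup(uint64_t a, uint64_t one_precomp, uint64_t n)
{
    /* high word of one_precomp * a (what umul_ppmm would give):
       either floor(a/n) or floor(a/n) - 1 */
    uint64_t p_hi = (uint64_t) (((u128) one_precomp * a) >> 64);
    /* the exact value lies in [0, 2n) and never exceeds a: no wraparound */
    uint64_t res = a - p_hi * n;
    /* single correction step */
    return (res >= n) ? res - n : res;
}
```

For example, red1_shoup(100, red1_precomp(7), 7) returns 2.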

In this particular case, this works even when n has 64 bits: there is no restriction at all on n or a (in general, n_mulmod_shoup requires n to have at most 63 bits).
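A short justification of why the fixed operand 1 lifts the 63-bit restriction, written out assuming n >= 2 (so that q fits in one limb), with r the remainder of 2**64 by n:

```latex
\begin{aligned}
&q = \left\lfloor 2^{64}/n \right\rfloor, \qquad r = 2^{64} - qn, \qquad 0 \le r < n, \\
&\frac{qa}{2^{64}} = \frac{a}{n} - \frac{ra}{n \cdot 2^{64}}, \qquad
  0 \le \frac{ra}{n \cdot 2^{64}} < 1 \quad (0 \le a < 2^{64}), \\
&p_{hi} = \left\lfloor \frac{qa}{2^{64}} \right\rfloor \in
  \left\{ \left\lfloor \frac{a}{n} \right\rfloor - 1,\ \left\lfloor \frac{a}{n} \right\rfloor \right\}, \\
&0 \le a - p_{hi}\, n < 2n \qquad \text{and} \qquad
  a - p_{hi}\, n \le a < 2^{64} \quad (\text{since } p_{hi} \ge 0).
\end{aligned}
```

The last bound, a - p_hi*n <= a, is what is special to the operand 1: for a general fixed operand w, the exact value w*a - p_hi*n is only known to lie in [0, 2n), which fits in a limb only when n < 2**63.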

NMOD_RED2 has an additional add_ssaaaa, some shifts here and there, and two correction steps.

The mulmod_shoup approach seems beneficial: for example, running nmod_vec/profile/p-reduce on zen4, I get the following values in cycles per limb (c/l); left is the existing code, right is using n_mulmod_shoup as shown below:

bits    NMOD_RED (c/l)    n_mulmod_shoup (c/l)
53      4.0               2.4
54      4.1               2.4
55      4.1               2.4
56      4.1               2.4
57      4.1               2.4
58      4.1               2.4
59      4.1               2.4
60      4.1               2.4
61      4.1               2.4
62      4.1               2.4
63      4.1               2.4
64      2.8               2.4

(2.4 goes down to 2.0 when the loop is explicitly unrolled; unrolling does not seem to help the NMOD_RED variant.)

These timings were obtained by simply changing the reduce function, adding the precomputation and calling n_mulmod_shoup instead of NMOD_RED:

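```c
void _nmod_vec_reduce(nn_ptr res, nn_srcptr vec, slong len, nmod_t mod)
{
    /* floor(2^64 / n): depends only on the modulus, computed once per call */
    const ulong one_precomp = n_mulmod_precomp_shoup(1L, mod.n);
    slong i;
    for (i = 0; i < len; i++)
    {
        /* previously: NMOD_RED(res[i], vec[i], mod); */
        /* Shoup multiplication by the fixed operand 1, i.e. vec[i] mod n */
        res[i] = n_mulmod_shoup(1L, vec[i], one_precomp, mod.n);
    }
}
```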
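The unrolled variant behind the 2.0 figure is not shown in the issue. As a hypothetical illustration only, this is what a 4-way manual unrolling of the loop above could look like, written without FLINT: ulong and n_mulmod_shoup are replaced by uint64_t and an inline helper using unsigned __int128 (GCC/Clang), the function names are made up, the unrolling factor 4 is an arbitrary choice, and n >= 2 is assumed.

```c
#include <stddef.h>
#include <stdint.h>

typedef unsigned __int128 u128;  /* GCC/Clang extension */

/* a mod n via the Shoup product 1 * a; one_precomp = floor(2^64 / n) */
static inline uint64_t red1_shoup(uint64_t a, uint64_t one_precomp, uint64_t n)
{
    uint64_t p_hi = (uint64_t) (((u128) one_precomp * a) >> 64);
    uint64_t res = a - p_hi * n;
    return (res >= n) ? res - n : res;
}

/* Reduce vec[0..len) mod n into res, four entries per iteration. */
static void vec_reduce_unrolled(uint64_t *res, const uint64_t *vec,
                                size_t len, uint64_t n)
{
    const uint64_t pre = (uint64_t) (((u128) 1 << 64) / n);  /* assumes n >= 2 */
    size_t i = 0;
    for (; i + 4 <= len; i += 4)
    {
        /* four independent reductions, no dependency between them */
        res[i]     = red1_shoup(vec[i],     pre, n);
        res[i + 1] = red1_shoup(vec[i + 1], pre, n);
        res[i + 2] = red1_shoup(vec[i + 2], pre, n);
        res[i + 3] = red1_shoup(vec[i + 3], pre, n);
    }
    for (; i < len; i++)  /* leftover entries when len is not a multiple of 4 */
        res[i] = red1_shoup(vec[i], pre, n);
}
```

Whether such unrolling pays off depends on the compiler and target, which may already unroll the plain loop on its own.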

There are probably other places where this could be beneficial. For more general use, however, it would probably help to add this one_precomp to the nmod_t struct, but I suspect there are good reasons to avoid enlarging this struct?

fredrik-johansson commented 2 months ago

Impressive!

See also #1950, #1926, #1823. What we want ultimately is to introduce an n_mod context object that is passed by reference rather than by value and which is not stored inside matrices and polynomials. This will give us more flexibility to store additional useful precomputed data like this.

Can you also speed up reduction from 2 or 3 limbs building on the same trick?

vneiger commented 2 months ago

> Impressive!
>
> See also #1950, #1926, #1823. What we want ultimately is to introduce an n_mod context object that is passed by reference rather than by value and which is not stored inside matrices and polynomials. This will give us more flexibility to store additional useful precomputed data like this.

Sounds good indeed!

> Can you also speed up reduction from 2 or 3 limbs building on the same trick?

It can be used to reduce from more limbs, at least with naive approaches (*), but I'm not sure at the moment how efficient that would be, and it may restrict n to 63 bits. I would be surprised if it turned out faster than NMOD_RED2 (i.e., when a_hi is already reduced), but less surprised for the general 2-limb and 3-limb cases (the current NMOD2_RED2 and NMOD_RED3).

(*) For example, precompute not only q = floor(2**64 / n) but also the corresponding remainder r = 2**64 - q*n, and then compute (a_hi * 2**64 + a_lo) mod n as ((a_hi * r) mod n + a_lo mod n) mod n. Since r is fixed, the product (a_hi * r) mod n can also use mulmod_shoup (which accepts any limb a_hi). In fact, this works for several limbs, and I guess the precomputations would then benefit from the new n_mulmod_and_precomp_shoup introduced in #2055; but I have no idea how it would compete with more elaborate algorithms in the many-limb case.
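A minimal standalone sketch of this footnote's approach, extended to three limbs. Assumptions: n has at most 63 bits (2 <= n < 2**63), so that Shoup multiplication by r and the sum of two residues cannot overflow; 2**128 mod n is obtained as r*r mod n; the precomputations use plain 128-bit division rather than n_mulmod_and_precomp_shoup from #2055; and all names (red_ctx, red2, red3, ...) are hypothetical. The context struct passed by pointer only loosely illustrates the by-reference context idea above; this is not FLINT code, nor the prcp22/prcp3 variants that were timed.

```c
#include <stddef.h>
#include <stdint.h>

typedef unsigned __int128 u128;  /* GCC/Clang extension */

/* Hypothetical context: modulus plus precomputed data, passed by pointer. */
typedef struct
{
    uint64_t n;        /* modulus, assumed 2 <= n < 2^63 */
    uint64_t one_pre;  /* floor(2^64 / n) */
    uint64_t r1;       /* 2^64 mod n */
    uint64_t r1_pre;   /* floor(r1 * 2^64 / n) */
    uint64_t r2;       /* 2^128 mod n */
    uint64_t r2_pre;   /* floor(r2 * 2^64 / n) */
} red_ctx;

/* Shoup precomputation floor(w * 2^64 / n); requires w < n */
static uint64_t shoup_pre(uint64_t w, uint64_t n)
{
    return (uint64_t) (((u128) w << 64) / n);
}

/* w * t mod n for any limb t; exact when w < n < 2^63, or when w = 1 */
static uint64_t mulmod_shoup(uint64_t w, uint64_t t, uint64_t w_pre, uint64_t n)
{
    uint64_t q = (uint64_t) (((u128) w_pre * t) >> 64);
    uint64_t r = w * t - q * n;
    return (r >= n) ? r - n : r;
}

/* (x + y) mod n for x, y < n < 2^63: the sum cannot overflow */
static uint64_t addmod(uint64_t x, uint64_t y, uint64_t n)
{
    uint64_t s = x + y;
    return (s >= n) ? s - n : s;
}

static void red_ctx_init(red_ctx *ctx, uint64_t n)
{
    ctx->n = n;
    ctx->one_pre = shoup_pre(1, n);
    /* r1 = 2^64 - q*n, computed with wraparound (gives 0 if q*n == 2^64) */
    ctx->r1 = (uint64_t) 0 - ctx->one_pre * n;
    ctx->r1_pre = shoup_pre(ctx->r1, n);
    ctx->r2 = mulmod_shoup(ctx->r1, ctx->r1, ctx->r1_pre, n);  /* r1^2 mod n */
    ctx->r2_pre = shoup_pre(ctx->r2, n);
}

/* (a_hi * 2^64 + a_lo) mod n; a_hi need not be reduced */
static uint64_t red2(uint64_t a_hi, uint64_t a_lo, const red_ctx *ctx)
{
    uint64_t hi = mulmod_shoup(ctx->r1, a_hi, ctx->r1_pre, ctx->n);
    uint64_t lo = mulmod_shoup(1, a_lo, ctx->one_pre, ctx->n);
    return addmod(hi, lo, ctx->n);
}

/* (a_hi * 2^128 + a_me * 2^64 + a_lo) mod n; a_hi need not be reduced */
static uint64_t red3(uint64_t a_hi, uint64_t a_me, uint64_t a_lo,
                     const red_ctx *ctx)
{
    uint64_t hi = mulmod_shoup(ctx->r2, a_hi, ctx->r2_pre, ctx->n);
    return addmod(hi, red2(a_me, a_lo, ctx), ctx->n);
}
```

Each further limb would cost one more Shoup multiplication by a precomputed power 2**(64k) mod n and one modular addition.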

vneiger commented 2 months ago

Some attempts at 2 and 3 limbs are rather promising (assuming I compared everything properly; this was done a bit quickly). The numbers below are cycles/limb, up to some constant factor, for reducing an array of length 1000.

First 4 columns: reducing each entry of the array modulo n, where n has the number of bits given in the bits column. Next 3 columns: reducing each block of two consecutive entries, seen as (a_hi, a_lo), modulo n. Last 4 columns: reducing each block of three consecutive entries, seen as (a_hi, a_me, a_lo), modulo n. A value of 0.00 means that the algorithm does not apply.

red, 2red2 and red3 are NMOD_RED, NMOD2_RED2 and NMOD_RED3. The others are various versions based on the idea discussed above in this issue. A column named xx_ur is the same function as the column just before it, with some manual unrolling. prcp3 and the xx_ur column after it do not assume that a_hi is reduced, whereas red3 does.

bits    check   red     xx_ur   prcp    xx_ur   ||      2red2   prcp22  p22_bis ||      red3    xx_ur   prcp3   xx_ur
58      pass    4.03    4.21    2.11    1.89    ||      4.68    2.86    2.06    ||      3.27    3.18    2.73    2.07
59      pass    4.03    4.15    2.12    2.04    ||      4.78    3.01    2.09    ||      3.19    3.16    2.73    2.07
60      pass    4.03    4.17    2.12    1.89    ||      4.90    2.94    2.06    ||      3.20    3.16    2.97    2.08
61      pass    4.05    4.16    2.12    1.91    ||      4.82    2.88    2.05    ||      3.20    3.16    2.73    2.07
62      pass    4.04    4.17    2.14    1.95    ||      4.84    2.87    2.07    ||      3.20    3.16    2.73    2.08
63      pass    4.04    4.22    2.17    1.90    ||      4.68    2.87    2.07    ||      3.20    3.16    2.98    2.14
64      pass    2.79    2.98    2.11    1.89    ||      2.92    0.00    0.00    ||      2.08    2.20    0.00    0.00

I'll have a look at the many-limb case when possible.