flintlib / flint

FLINT (Fast Library for Number Theory)
http://www.flintlib.org
GNU Lesser General Public License v3.0

Elementary functions worklist #1818

Open fredrik-johansson opened 2 years ago

fredrik-johansson commented 2 years ago

Things to do in the future:

fredrik-johansson commented 8 months ago

Plan to speed up elementary functions:

Implement table-based argument reduction as in https://hal.science/hal-04454093

The m-bitwise method should be used by default up to a few hundred bits, and the bitwise method at higher precision. For the lowest precisions, things need to be hardcoded for best performance. Beyond that, the parameters could be configurable at runtime, so that one can generate custom tables when one needs 10^7 function evaluations, etc. (though there should still be static default tables like now).

For the bitwise method, the idea is to conditionally subtract log(1+2^-i) for i = 0, ..., r. We can do this using r/2 mpn_sub_n calls on average. However, it should be possible to get better performance by converting to a smaller radix (e.g. 2^56) and doing the subtractions using carry-free operations. Presumably it will be fastest to do r masked subtractions with SIMD. This will be a bit tricky to implement.

For the reconstruction where one multiplies by (1+2^-i), combined shift-and-add mpn functions would be useful.
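A combined shift-and-add could look like the following sketch. It multiplies an n-limb fixed-point number by (1+2^-i) in a single pass, instead of an mpn_rshift into a temporary followed by an mpn_add_n. The function shift_add_1 is hypothetical; it is not an existing GMP or FLINT function. It uses plain uint64_t limbs in little-endian order and restricts the shift to 1 <= i <= 63. Larger shifts would additionally skip whole limbs.

```c
#include <stddef.h>
#include <stdint.h>

/* Hypothetical helper: in place, y <- y + floor(y / 2^i) for an n-limb
   little-endian y and 1 <= i <= 63, i.e. y * (1 + 2^-i) truncated to n
   limbs. Returns the carry out of the top limb (0 or 1). */
static uint64_t shift_add_1(uint64_t *y, size_t n, unsigned i)
{
    uint64_t carry = 0;
    for (size_t k = 0; k < n; k++)
    {
        /* limb k of (y >> i); y[k + 1] has not been overwritten yet */
        uint64_t s = y[k] >> i;
        if (k + 1 < n)
            s |= y[k + 1] << (64 - i);
        uint64_t t = y[k] + s;
        uint64_t c1 = (t < s);        /* carry from y[k] + s */
        uint64_t u = t + carry;
        uint64_t c2 = (u < carry);    /* carry from adding the old carry */
        y[k] = u;
        carry = c1 + c2;              /* at most one of c1, c2 is set */
    }
    return carry;
}
```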

On to the series evaluation, which is essentially evaluating a polynomial with 1-word integer coefficients.
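A truncated exponential series becomes such a polynomial after one rescaling. Multiplying sum_{k<n} x^k/k! by (n-1)! turns the coefficients into the integers (n-1)!/k!. These fit in one 64-bit word as long as n <= 21, because 20! < 2^64 < 21!. The helper exp_int_coeffs below is hypothetical and shows one natural normalization, not necessarily the one used in the gist.

```c
#include <stdint.h>

/* Hypothetical helper: fill c[k] = (n-1)!/k! for k = 0 .. n-1, so that
   sum_{k<n} x^k/k! = (sum_{k<n} c[k] x^k) / c[0].
   Returns 0 (and writes nothing) if the coefficients would not all fit
   in one 64-bit word. */
static int exp_int_coeffs(uint64_t *c, int n)
{
    if (n < 1 || n > 21)   /* 20! < 2^64 < 21! */
        return 0;
    c[n - 1] = 1;
    for (int k = n - 2; k >= 0; k--)
        c[k] = c[k + 1] * (uint64_t) (k + 1);   /* (n-1)!/k! */
    return 1;
}
```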

Ideally, we can assume that the reduced argument satisfies |x| < 2^-r with r >= 64, so that the precision drops by one limb per series term. At higher precision, we will have |x| < 2^-(b*64), so that the precision drops by b limbs per term. r = 64 is probably a bit larger than optimal in the few-limb regime. Still, making the argument reduction efficient enough to allow r = 64 would be a good goal to shoot for, since this greatly simplifies the series evaluation.

Up to a few limbs, it will be best to do straight Horner with mulhighs; the additions are basically free, and the multiplications are cheap enough that something more clever won't save time.
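For the one-limb case, Horner's rule with a 64x64 -> high-64 multiplication could look like the sketch below. It works on plain uint64_t, with a portable mulhi64 standing in for umul_ppmm or flint_mpn_mulhigh_n. It uses the nested form x(1 + x/2(1 + x/3(...))) with a small-integer division per step. That form is a swapped-in variant, not the integer-coefficient form with a final division by the factorial that the sample code uses. The input encodes the fraction x/2^64 with value below 1/2, and the output is expm1(x) in the same format. The names mulhi64 and expm1_horner_1limb are hypothetical.

```c
#include <stdint.h>

/* Portable high 64 bits of a 64x64-bit product (schoolbook on 32-bit
   halves). Stands in for umul_ppmm / a 1x1 mulhigh. */
static uint64_t mulhi64(uint64_t a, uint64_t b)
{
    uint64_t a_lo = (uint32_t) a, a_hi = a >> 32;
    uint64_t b_lo = (uint32_t) b, b_hi = b >> 32;
    uint64_t p0 = a_lo * b_lo, p1 = a_lo * b_hi;
    uint64_t p2 = a_hi * b_lo, p3 = a_hi * b_hi;
    /* middle column, including the high half of p0; cannot overflow */
    uint64_t mid = (p0 >> 32) + (uint32_t) p1 + (uint32_t) p2;
    return p3 + (p1 >> 32) + (p2 >> 32) + (mid >> 32);
}

/* One-limb Horner sketch. x encodes x / 2^64 and must be below 2^63, so
   x + mulhi64(x, u) < 2x cannot overflow. Iterating u <- (x + x*u) / k
   for k = m, ..., 1 yields sum_{k=1}^{m} x^k / k! in the same format.
   Every step truncates, so the result may be a few units of 2^-64 low. */
static uint64_t expm1_horner_1limb(uint64_t x, int m)
{
    uint64_t u = 0;
    for (int k = m; k >= 1; k--)
        u = (x + mulhi64(x, u)) / (uint64_t) k;
    return u;
}
```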

Beyond a few limbs, we want to do rectangular splitting like _arb_exp_taylor_rs, but there is a lot to speed up there:

- using mulhigh;
- gradually adjusting the precision (a 1-limb increment per term if |x| < 2^-64, as noted above);
- hardcoding things to minimize overhead.
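The block structure of rectangular splitting is easiest to see with doubles. This sketch (poly_eval_rs is a hypothetical name) has no mulhigh and no precision adjustment, which is exactly what an mpn version would add. It computes the powers x^1..x^m once and evaluates each block of m coefficients as a dot product with those powers. It then combines the blocks by Horner's rule in x^m. With n = 10 and m = 4 it evaluates the grouping (a0 + ... + a3 x^3) + x^4 ((a4 + ... + a7 x^3) + x^4 (a8 + a9 x)).

```c
#include <stddef.h>

/* Rectangular splitting sketch: evaluates sum_{k<n} c[k] x^k using the
   powers x^1 .. x^m, one dot product per block of m coefficients, and
   Horner's rule in x^m across blocks. Assumes 1 <= m <= 16. */
static double poly_eval_rs(const double *c, size_t n, double x, size_t m)
{
    double pw[17];                /* pw[j] = x^j for j = 0 .. m */
    pw[0] = 1.0;
    for (size_t j = 1; j <= m; j++)
        pw[j] = pw[j - 1] * x;

    size_t nblocks = (n + m - 1) / m;
    double s = 0.0;
    for (size_t b = nblocks; b-- > 0; )
    {
        /* block b holds c[b*m .. b*m + m - 1] (the last one may be short) */
        double blk = 0.0;
        for (size_t j = 0; j < m && b * m + j < n; j++)
            blk += c[b * m + j] * pw[j];
        s = s * pw[m] + blk;      /* Horner step in x^m */
    }
    return s;
}
```

Choosing m trades the m - 1 multiplications for the powers against one multiplication by x^m per block. This is the same trade-off as computing powers up to x^3 versus x^4 in the sample code.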

Here is some sample code for n = 10 (see https://gist.github.com/fredrik-johansson/c1ab3ce215e1cf80115c422a103cf31f for the complete test program, including some more variants):

/* (a0 + a1 x + a2 x^2 + a3 x^3) + x^4 * ((a4 + a5 x + a6 x^2 + a7 x^3) + x^4 * (a8 + a9 x)) */
void
mpn_exp_series_10_rs4(mp_ptr res, mp_srcptr c, mp_srcptr x)
{
    mp_limb_t x2[8];
    mp_limb_t x3[8];
    mp_limb_t x4p[8];
    mp_limb_t s[11];
    mp_limb_t t[11];

    /* zero-pad so that we can fudge the final nx(n-1) mul as an nxn mul */
    mp_ptr x4 = x4p + 1;
    x4p[0] = 0;

    flint_mpn_sqrhigh(x2, x + 1, 8);
    flint_mpn_mulhigh_n(x3, x + 2, x2 + 1, 7);
    flint_mpn_sqrhigh(x4, x2 + 2, 6);

    umul_ppmm(t[1], t[0], x[8], c[9]);  /* t[1] = mpn_mul_1(t, x + 9 - 1, 1, c[9]); */
    t[2] = c[8];
    flint_mpn_mulhigh_n(s, t, x4 + 6 - 3, 3);
    s[3] = mpn_addmul_1(s, x3 + 7 - 3, 3, c[7]);
    s[4] = mpn_addmul_1(s, x2 + 8 - 4, 4, c[6]);
    s[5] = mpn_addmul_1(s, x  + 9 - 5, 5, c[5]);
    s[6] = c[4];
    flint_mpn_mulhigh_n(t, s, x4 + 6 - 7, 7);
    t[7] = mpn_addmul_1(t, x3, 7, c[3]);

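    /* divide by the factorial; c0inv (presumably a precomputed fixed-point
       reciprocal of c[0]) is defined in the full test program in the gist */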
    flint_mpn_mulhigh_n(res, t, c0inv + 3, 8);

    /* + x^2 / 2 (9 limbs) */
    mpn_rshift(t, x2, 8, 1);
    res[8] = mpn_add_n(res, res, t, 8);
    /* + x (10 limbs) */
    res[9] = mpn_add_n(res, res, x, 9);
    /* + 1 */
    res[10] = 1;
}

This runs in 8.24e-8 seconds, compared to 2.67e-7 seconds for _arb_exp_taylor_rs (3.2x speedup!). Computing powers up to x^4 is just barely faster than computing up to x^3 here.

This is close to as fast as I can make it. Just a few things I can think of: