python / cpython

The Python programming language
https://www.python.org

Quadratic time internal base conversions #90716

Open tim-one opened 2 years ago

tim-one commented 2 years ago
BPO 46558
Nosy @tim-one, @cfbolz, @sweeneyde
Files
  • todecstr.py
  • todecstr.py
  • Note: these values reflect the state of the issue at the time it was migrated and might not reflect the current state.


    GitHub fields:
    ```python
    assignee = None
    closed_at =
    created_at =
    labels = ['interpreter-core', 'performance']
    title = 'Quadratic time internal base conversions'
    updated_at =
    user = 'https://github.com/tim-one'
    ```
    bugs.python.org fields:
    ```python
    activity =
    actor = 'tim.peters'
    assignee = 'none'
    closed = True
    closed_date =
    closer = 'tim.peters'
    components = ['Interpreter Core']
    creation =
    creator = 'tim.peters'
    dependencies = []
    files = ['50593', '50595']
    hgrepos = []
    issue_num = 46558
    keywords = []
    message_count = 9.0
    messages = ['411962', '411966', '411969', '411971', '412120', '412122', '412172', '412191', '412192']
    nosy_count = 3.0
    nosy_names = ['tim.peters', 'Carl.Friedrich.Bolz', 'Dennis Sweeney']
    pr_nums = []
    priority = 'normal'
    resolution = 'wont fix'
    stage = 'resolved'
    status = 'closed'
    superseder = None
    type = 'performance'
    url = 'https://bugs.python.org/issue46558'
    versions = []
    ```

    tim-one commented 1 year ago

    What is preventing a pure C implementation from being checked in?

    Nobody willing to endure the pain, for essentially 0 gain. Neil's PR leaves the code in Python, where the code is very short, elegant, & clear. These "very big int" algorithms spend essentially all their time inside integer multiplication (whether CPython's int Karatsuba, or the decimal module's fancier scheme), and those are already coded in C.

    The new asymptotically quicker divmod() is much more involved, and could probably benefit more from being coded in C (there is, e.g., no efficient way to pick a bit out of a bigint in Python - things like (value >> j) & 1 take O(value.bit_length() - j) time). Maybe that could, e.g., save a factor of log(n) in the asymptotics.

    But, as already said, after decades of this I'm unwilling this time to once again let ambition kill making major improvements. Optimality is explicitly not a goal.

    BTW, your "standalone C implementation to convert from decimal <=> binary" wouldn't really help, because of the "standalone" part: if you don't also have better-than-quadratic-time bigint multiplication (or, for some algorithms, division) to build on, the overall base conversion algorithm will remain quadratic time.

    The usual "divide and conquer" base conversion algorithms run in O(M(n) * log(n)) time, where M(n) is the time to multiply two n-digit integers.
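    A minimal sketch of that divide-and-conquer shape for int -> str follows. It is offered only as an illustration: it is not the code in Neil's PR, and `cutoff` is a hypothetical tuning knob. It splits on a power of 10 near the middle, converts both halves recursively, and zero-pads the low half. The cost satisfies T(n) = 2T(n/2) + O(D(n)), where D(n) is the cost of one big divmod. The log(n) factor therefore only pays off if divmod (and the multiplication under it) is itself sub-quadratic. CPython's built-in divmod() is not, which is the point made above about "standalone" code.

```python
def int_to_str(n, cutoff=1000):
    """Divide-and-conquer int -> decimal string (illustrative sketch)."""
    if n < 0:
        return '-' + int_to_str(-n, cutoff)
    pow10 = {}  # cache of 10**k, since the same split sizes recur

    def conv(x, ndigits):
        # Return exactly `ndigits` characters (zero-padded); needs x < 10**ndigits.
        if ndigits <= cutoff:
            return str(x).zfill(ndigits)
        half = ndigits // 2
        if half not in pow10:
            pow10[half] = 10 ** half
        hi, lo = divmod(x, pow10[half])
        return conv(hi, ndigits - half) + conv(lo, half)

    # 0.30103 > log10(2), so this never underestimates the digit count.
    ndigits = n.bit_length() * 30103 // 100000 + 1
    return conv(n, ndigits).lstrip('0') or '0'
```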

    tim-one commented 1 year ago

    @pochmann, some food for thought about your str->int algorithm using decimal. The asymptotics look great, but the constant factor not so great, probably reflecting that decimal's huge-precision division has a constant factor much higher than its huge-precision multiplication. Another hidden cost is that, because divisors are reused, decimal incurs the (hidden) expense of computing the same reciprocals repeatedly too.

    So how about trying without division? Multiply by pre-computed reciprocals instead, as floating-point values. The key insight here is that most of the troubles that come with that method don't arise in this context, because reciprocals of powers of 2 are all exactly representable in decimal. Computing reciprocals in this context would not itself be a source of numeric error.

    A floating result would have to be "quantized" (with round-to-0) to extract the integer part, and then there are a couple of ways to get the remainder (all ways I can think of require another multiply).

    In effect, I think a divmod() could be replaced by 2 multiplies (+ some fiddling). Reciprocals can be precomputed simply, starting with 1/Decimal(2)**1024, and then repeatedly squaring.
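    Here is a small sketch of that idea. It is an illustration of the suggestion, not code from the thread. It assumes a decimal context with a fixed 1000-digit precision (the hypothetical `PREC` below), which is enough to keep every product exact for the sizes used here. Reciprocals of powers of 2 are exact because 2**-k == 5**k / 10**k, and squaring keeps them exact. A truncated product therefore gives the true quotient, and one more multiply gives the remainder.

```python
from decimal import Decimal, ROUND_DOWN, localcontext

PREC = 1000  # hypothetical: enough digits to keep every product below exact


def reciprocal_pairs(bits, count):
    """Return [(2**(bits * 2**i), 2**-(bits * 2**i)) for i in range(count)].

    Both members are built by repeated squaring; the reciprocals are exact
    because 2**-k == 5**k / 10**k has a finite decimal expansion.
    """
    with localcontext() as ctx:
        ctx.prec = PREC
        div = Decimal(2) ** bits
        rdiv = 1 / div
        pairs = []
        for _ in range(count):
            pairs.append((div, rdiv))
            div *= div
            rdiv *= rdiv
        return pairs


def divmod_by_reciprocal(x, div, rdiv):
    """divmod(x, div) for an integral Decimal x >= 0, using two multiplies."""
    with localcontext() as ctx:
        ctx.prec = PREC
        q = (x * rdiv).to_integral_value(rounding=ROUND_DOWN)  # truncate toward 0
        return q, x - q * div


pairs = reciprocal_pairs(64, 3)  # divisors 2**64, 2**128, 2**256
q, r = divmod_by_reciprocal(Decimal(3 ** 300), *pairs[1])
assert (int(q), int(r)) == divmod(3 ** 300, 2 ** 128)
```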

    tim-one commented 1 year ago

    So how about trying without division? Multiply by pre-computed reciprocals instead, as floating-point values.

    Haha 😉. A quick stab at that wasn't very promising. At first it ran 5 times slower. Head-scratching convinced me that was because, the way I was breaking a digit into quotient and remainder, decimal was convinced that the 0 digits after the remainder's decimal point were "significant", and so explicitly stored them (and so also crawled over them in later calculations).
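    Here is one way the stored zeros can arise. This is an assumption, since the code that hit the problem isn't shown. If the remainder is rebuilt from the fractional part of digit * rdiv, decimal keeps the exponent of that product, so the remainder carries trailing zeros after its decimal point. Computing the remainder by subtraction does not. The two values compare equal, but one has a much longer coefficient.

```python
from decimal import Decimal, ROUND_DOWN

digit = Decimal(1000)
div = Decimal(256)
rdiv = 1 / div                                   # exactly 0.00390625
prod = digit * rdiv                              # Decimal('3.90625000')
q = prod.to_integral_value(rounding=ROUND_DOWN)  # Decimal('3')

r_frac = (prod - q) * div   # remainder rebuilt from the fractional part
r_sub = digit - q * div     # remainder by subtraction

print(r_frac, r_sub)        # 232.00000000 232
print(len(r_frac.as_tuple().digits), len(r_sub.as_tuple().digits))  # 11 3
```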

    Using a different way got around that, but then the result was only about 10% faster than the original. All in all, it appeared to have the same asymptotics (unsurprising), but not enough benefit in the constant factor to be worth the extra conceptual and code complexity.

    pochmann commented 1 year ago

    decimal's huge-precision division has a constant factor much higher than its huge-precision multiplication

    Well, four times higher in my mentioned testing, if I'm not thinking about this wrong.

    And the reciprocals are ~2.32 times longer. 2^1024 has 309 decimal digits, and 1/2^1024 has 716 significant decimal digits. Smaller example:

    2^30 = 1073741824
    2^-30 = 9.31322574615478515625E-10
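    The digit counts above are easy to check. 2**k has about k*log10(2) digits. 2**-k == 5**k / 10**k has as many significant digits as 5**k, about k*log10(5). The ratio log10(5)/log10(2) is about 2.32. The quick check below uses a context precision (an arbitrary 2*k + 10 here) large enough that the reciprocal comes out exact.

```python
from decimal import Decimal, localcontext


def digit_counts(k):
    """(decimal digits in 2**k, significant digits in the exact value of 2**-k)"""
    with localcontext() as ctx:
        ctx.prec = 2 * k + 10   # ample: 2**-k needs only about 0.7 * k digits
        recip = Decimal(1) / Decimal(2) ** k
    return len(str(2 ** k)), len(recip.as_tuple().digits)


print(digit_counts(1024))   # (309, 716)
print(digit_counts(30))     # (10, 21)
```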

    tim-one commented 1 year ago

    And the reciprocals are longer. 2^1024 has 309 decimal digits, and 1/2^1024 has 716 decimal digits.

    I guess (but don't know - guessing based on what other packages do) that decimal computes a reciprocal approximation good to about the same number of decimal digits as the numerator. In context, 2**1024 is used to break up "digits" no larger than 2**2048, which can have 617 decimal digits.

    Python code could precompute the same narrower reciprocal approximations - but then it would have to deal with that as a new source of potential numeric error too.

    Then the code would certainly become more delicate and way harder to understand - I wasn't interested, myself, in trying to do "clever error analysis" at all here. Or it could simply truncate (toward 0) everywhere, and clearly end up with a quotient guess guaranteed to be "not too large". Then, via repeated trial, so long as the remainder was >= the power of 2 (T) used to split, increment the quotient by 1 and decrement the remainder by T. No fancy analysis needed then - but also no a priori guarantee that such a loop wouldn't need to go around a great many times 😉.
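    The same "truncate, then correct upward" shape can be sketched with plain ints. The sketch below uses a fixed-point reciprocal in the style of Barrett reduction, an integer analogue swapped in for illustration rather than the decimal code discussed here. Because the reciprocal is truncated, the quotient guess can't be too large, so only the upward loop is needed. The hypothetical `shift` parameter sets the reciprocal's precision. A too-small `shift` shows how the loop can go around a great many times.

```python
def barrett_divmod(x, t, shift):
    """divmod(x, t) for x >= 0, t > 0 via a truncated fixed-point reciprocal.

    Returns (q, r, loops), where loops counts the upward corrections.
    """
    recip = (1 << shift) // t        # truncated approximation of 2**shift / t
    q = (x * recip) >> shift         # never larger than the true quotient
    r = x - q * t                    # so r >= 0 here
    loops = 0
    while r >= t:                    # the "repeated trial" correction
        q += 1
        r -= t
        loops += 1
    return q, r, loops


x, t = 3 ** 200, (1 << 100) + 12345
print(barrett_divmod(x, t, 400)[2])      # generous shift: 0 or 1 corrections
print(barrett_divmod(10**6, 7, 0)[2])    # shift far too small: 142857 corrections
```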

    pochmann commented 1 year ago

    Then, via repeated trial, so long as the remainder was >= the power of 2 (T) used to split, increment the quotient by 1 and decrement the remainder by T. No fancy analysis needed then - but also no a priori guarantee that such a loop wouldn't need to go around a great many times 😉.

    Reminds me of this loop in Mark's divmod. I thought it might need to go around a great many times, but then I couldn't find any case where it ran more than once (or twice? Can't remember.)

    tim-one commented 1 year ago

    Reminds me of this loop in Mark's divmod.

    That's pretty shallow 😉. Each loop in a call to _divmod_pos(a, b) goes around about a.bit_length() / b.bit_length() times. It's viewing a as a sequence of digits in base 2**b.bit_length(). If, e.g., a <= b**2, never more than twice. Pass a large power of b for a, and watch them go around as often as you care to force it.
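    To make the "digits in base 2**b.bit_length()" view concrete, here is schoolbook long division in that base. It is a simplified, hypothetical illustration, not the actual _divmod_pos code. It does one small divmod per base-2**n digit of a, so the step count is about a.bit_length() / b.bit_length(), and at most 2 when a <= b**2.

```python
def chunked_divmod(a, b):
    """divmod(a, b) for a >= 0, b > 0, one base-2**n digit of a at a time,
    where n = b.bit_length(). Returns (q, r, steps)."""
    n = b.bit_length()
    mask = (1 << n) - 1
    steps = max(1, -(-a.bit_length() // n))   # number of base-2**n digits in a
    q = r = 0
    for i in reversed(range(steps)):
        digit = (a >> (i * n)) & mask          # next digit, most significant first
        r = (r << n) | digit                   # r < b before the shift, so r < b * 2**n
        qd, r = divmod(r, b)                   # hence this quotient digit is < 2**n
        q = (q << n) | qd
    return q, r, steps


b = 3 ** 50
for k in (2, 5, 10):
    print(k, chunked_divmod(b ** k, b)[2])     # steps grow with the power of b
```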

    pochmann commented 1 year ago

    You're not talking about the loop that I linked to / highlighted, are you? The while r < 0: one.

    tim-one commented 1 year ago

    You're not talking about the loop that I linked to / highlighted, are you?

    I thought I was, but on my screen I see now that the highlighting is barely visible and my eye was drawn to the middle of the screen. The loop it appears you're really talking about:

    The while r < 0: one.

    shows up as the very topmost 3 lines of the screen.

    Yup, that's not obvious (to me) at all. In Burnikel and Ziegler's paper, there is no loop here, but instead what you'd get if you unrolled Mark's loop twice. I assume (but don't know) Mark just wanted to cut the code bulk, and didn't care about making a third useless but dirt-cheap r < 0 test in the unlikely case q had to be decremented twice.

    But you'll have to read their paper for proof that twice is enough. I never fought my way to understanding it, and found their exposition difficult to follow.
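    The "at most twice" claim can at least be spot-checked. Below is a sketch of the 3-halves-by-2-halves step as described in the Burnikel-Ziegler paper. It is written with plain ints and uses Python's divmod() for the inner division instead of recursing. It assumes the divisor b is normalized (exactly 2*h bits, top bit set) and that a < b << h; `h` is a hypothetical half-size in bits. The `while r < 0` loop is the correction discussed above. For a normalized divisor, Knuth's Theorem B (TAOCP vol. 2, 4.3.1) gives the same at-most-2 bound for this kind of quotient estimate.

```python
import random


def d3n2n(a, b, h):
    """Divide a by b, where b has exactly 2*h bits and a < b << h.

    Returns (q, r, corrections).
    """
    assert b.bit_length() == 2 * h and 0 <= a < b << h
    mask = (1 << h) - 1
    a1, a2, a3 = a >> (2 * h), (a >> h) & mask, a & mask
    b1, b2 = b >> h, b & mask
    top = (a1 << h) | a2
    if a1 < b1:
        q, r1 = divmod(top, b1)      # estimate q from the leading halves
    else:
        q = mask                     # here a1 == b1: use the largest half-size value
        r1 = top - q * b1
    r = (r1 << h) + a3 - q * b2      # r == a - q*b, possibly negative
    corrections = 0
    while r < 0:                     # the loop discussed above
        q -= 1
        r += b
        corrections += 1
    return q, r, corrections


worst = 0
for _ in range(10_000):
    h = 32
    b = random.getrandbits(2 * h) | (1 << (2 * h - 1))   # normalized divisor
    a = random.randrange(b << h)
    q, r, c = d3n2n(a, b, h)
    assert (q, r) == divmod(a, b)
    worst = max(worst, c)
print('most corrections seen:', worst)
```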

    tim-one commented 1 year ago

    Settling for computing reciprocal approximations does buy real speedup, saving about a third of the original code's cycles. This is somewhat aggressive, rounding the approximation of 1/T to just one more decimal digit of precision than T has. Nevertheless, the correction loops following don't find much in need of correcting. When they do, corrections in both directions can be needed, but so far I haven't seen an adjustment (to the quotient) of more than one needed.

    Alas, this still takes over twice the time of the PR code at 1 million bits, although the gap is shrinking.

    # at 10 million bits
    29.88990306854248 str_to_int_using_decimal
    20.070303916931152 dec2 # the code below, now the clear winner
    24.412405014038086 str_to_int # in Neil's PR

    And the code:

    def dec2inner(s, ctx, bits):
        from decimal import Decimal, ROUND_DOWN, MAX_PREC
        asint = Decimal(1)
        d = Decimal(s)
        div = Decimal(2) ** bits
        rdiv = 1 / div
        divs = []
        while div <= d:
            ctx.prec = div.adjusted() + 2
        divs.append((div, +rdiv)) # `+rdiv` rounds back rdiv
            ctx.prec = MAX_PREC
            div *= div
            rdiv *= rdiv
        digits = [d]
        for div, rdiv in reversed(divs):
            newdigits = []
            for digit in digits:
                q = (digit * rdiv).quantize(asint, rounding=ROUND_DOWN)
                r = digit - q * div
                if r < 0:
                    while True:
                        q -= 1
                        r += div
                        if r >= 0:
                            break
                else:
                    while r >= div:
                        q += 1
                        r -= div
                newdigits.append(q)
                newdigits.append(r)
            digits = newdigits
            if not digits[0]:
                del digits[0]
        b8 = bits // 8
        b = b''.join(int(digit).to_bytes(b8, 'big')
                     for digit in digits)
        return int.from_bytes(b, 'big')
    
    def dec2(s, nbits=1024):
        import decimal
        with decimal.localcontext() as ctx:
            ctx.prec = decimal.MAX_PREC
            ctx.Emax = decimal.MAX_EMAX
            ctx.Emin = decimal.MIN_EMIN
            return dec2inner(s, ctx, nbits)
    tim-one commented 1 year ago

    Surprising observation: in that code, for the 10-million digit case, 5% of the total time is spent on the final execution of rdiv *= rdiv, whose result is never used. That's easy to avoid, albeit obscure.

    The new version here does that, and sets the context to ROUND_DOWN so that the first q guess is never too large. That simplifies the correction code, but doesn't speed it. FYI, a correction appears to be needed in about a third of a percent of cases.

    def dec2inner(s, ctx, bits):
        from decimal import Decimal, MAX_PREC
        d = Decimal(s)
        div = Decimal(2) ** bits
        rdiv = 1 / div
        divs = []
        needrdiv = False
        while div <= d:
            if needrdiv:
                rdiv *= rdiv
            ctx.prec = div.adjusted() + 2
            # `+rdiv` rounds back rdiv; context is ROUND_DOWN
            divs.append((div, +rdiv))
            ctx.prec = MAX_PREC
            div *= div
            needrdiv = True
        digits = [d]
        for div, rdiv in reversed(divs):
            newdigits = []
            for digit in digits:
                q = (digit * rdiv).to_integral_value() # ctx truncates
                r = digit - q * div
                assert r >= 0
                while r >= div:
                    q += 1
                    r -= div
                newdigits.append(q)
                newdigits.append(r)
            digits = newdigits
            if not digits[0]:
                del digits[0]
        b8 = bits // 8
        b = b''.join(int(digit).to_bytes(b8, 'big')
                     for digit in digits)
        return int.from_bytes(b, 'big')
    
    def dec2(s, nbits=1024):
        import decimal
        with decimal.localcontext() as ctx:
            ctx.prec = decimal.MAX_PREC
            ctx.Emax = decimal.MAX_EMAX
            ctx.Emin = decimal.MIN_EMIN
            ctx.rounding = decimal.ROUND_DOWN
            return dec2inner(s, ctx, nbits)
    tim-one commented 1 year ago

    EDIT: changed int(Decimal) to int(str(Decimal)) because the latter is faster(!); removed the tail-recursion "optimization" because it added some obscurity (and an indentation level) for no measurable gain.

    I've run out of ideas for speeding up the decimal version of string->int, so I will just leave the current state here for posterity. It's about 15% faster than the last version at 10 million digits, considerably better still at some other digit lengths, and now takes less than twice the time of the code in Neil's PR at 1 million digits.

    Major changes:

    CAUTION: this needs rigorous error analysis. "Haven't seen" doesn't mean "can't happen" ☹️.

    At 100 million digits, it's well over twice as fast as the original now:

    # at 100 million digits
    928.7468802928925  str_to_int in Neil's PR
    458.11627554893494 str_to_int_using_decimal
    289.3329448699951  dec2 # last version posted
    194.87010216712952 dec3 # the code below
    Decimal `str->int` code:
    ```py
    def dec3(s, nbytes=128):
        from collections import defaultdict
        import decimal
        from decimal import Decimal, MAX_PREC

        cache = {}
        rcache = {}

        def cpow(n):
            if (result := cache.get(n)) is None:
                if n - 1 in cache:
                    div = cache[n-1][0] * D256
                    rdiv = rcache[n-1] * rD256
                elif n <= nbytes:
                    div = D256 ** n
                    rdiv = 1 / div
                else:
                    div = cpow(n1 := n >> 1)[0] * cpow(n2 := n - n1)[0]
                    rdiv = rcache[n1] * rcache[n2]
                rcache[n] = rdiv
                ctx.prec = div.adjusted() + 3
                cache[n] = result = div, (+ rdiv)
                ctx.prec = MAX_PREC
            return result

        # `spread` maps a correction loop count to the number of times
        # that count was needed; all code involving it should be removed
        # when this is all debugged
        spread = defaultdict(int)

        result = bytearray()
        add = result.extend

        def inner(x, n):
            # assert 0 <= x < D256 ** n
            if n <= nbytes:
                # XXX Stefan Pochmann discovered that, for 1024-bit
                # ints, `int(Decimal)` took 2.5x longer than
                # `int(str(Decimal))`. So simplify this code to the
                # former if/when that gets repaired.
                add(int(str(x)).to_bytes(n, 'big'))
                return
            # If `n` is odd it's vital that we split on n//2 + 1 for the
            # approximations to be justified. If we split on the smaller
            # n//2, the correction loop can go around a lot more times,
            # unless we also keep more digits in `rdiv`, `x`, and the
            # multiplication, than `div.adjusted() + 3`. About 3 more
            # digits than that appear to suffice, but the more digits we
            # use the more expensive arithmetic.
            n2 = (n + 1) >> 1
            div, rdiv = cpow(n2)
            ctx.prec = div.adjusted() + 3
            q = (+x * rdiv).to_integral_value() # ctx truncates
            ctx.prec = MAX_PREC
            x -= q * div
            assert x >= 0
            count = 0
            while x >= div:
                count += 1
                q += 1
                x -= div
            spread[count] += 1
            inner(q, n - n2)
            inner(x, n2)

        with decimal.localcontext() as ctx:
            ctx.prec = decimal.MAX_PREC
            ctx.Emax = decimal.MAX_EMAX
            ctx.Emin = decimal.MIN_EMIN
            ctx.rounding = decimal.ROUND_DOWN

            D256 = Decimal(256)
            rD256 = 1 / D256
            x = Decimal(s)
            ctx.prec = 20
            n = int(x.ln() / D256.ln() + 2)
            ctx.prec = MAX_PREC
            inner(x, n)
            print(sorted(spread.items()))
            del cache, rcache
            return int.from_bytes(result, 'big')
    ```
    oscarbenjamin commented 1 year ago

    I've just taken the time to read through all of the above. I like the idea of implementing some of these operations in Python (especially if it means introducing FFT multiplication).

    That said, the main complaint in the CVE was str -> int, and I have a PR (gh-97550) that implements that in C using the obvious algorithm, which reduces the problem to integer multiplication. I haven't timed all the various alternatives offered above, but I did compare with gh-96673: runtime looks basically the same, and I think the C implementation has lower memory overhead.
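    For reference, "the obvious algorithm" can be sketched in a few lines of Python. This is an illustration only, not the C code in gh-97550, and `cutoff` is a hypothetical threshold. It splits the digit string, converts both halves recursively, and combines them with one multiplication by a cached power of 10. The running time is then governed by the cost of that multiplication.

```python
def str_to_int(s, cutoff=1000):
    """Divide-and-conquer decimal string -> int (illustrative sketch)."""
    pow10 = {}  # cache of 10**k for the split sizes that occur

    def conv(lo, hi):
        n = hi - lo
        if n <= cutoff:
            return int(s[lo:hi])
        half = n // 2
        if half not in pow10:
            pow10[half] = 10 ** half
        # value = (leading digits) * 10**half + (trailing `half` digits)
        return conv(lo, hi - half) * pow10[half] + conv(hi - half, hi)

    return conv(0, len(s))


digits = str(7 ** 3000)
assert str_to_int(digits, cutoff=50) == 7 ** 3000
```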

    oscarbenjamin commented 1 year ago

    Here comes the updated version. This is a proper implementation of Schönhage-Strassen following these notes.

    Very nice. A note for anyone wanting to run timings with this: the parameter n (aka K = 2^k in other texts) should be chosen based on the bit size of the integers being multiplied rather than always set to 256. Larger integers should usually use larger values (although look at the "Fine-Grained Tuning" slide in the linked notes). GMP tunes thresholds for this empirically on a per-architecture basis, but broadly n should scale like sqrt(N) if N is the bit length of the inputs. In other words, you approximately want to split a 1000000-bit input into 1000 digits, each of which has 1000 bits. It is this square-rooting of the problem size that leads to the log(log(N)) factor in the complexity of SSA.

    For example, when multiplying a = 12345**500000; b = 13245**500000, these are the timings:

    $ python ssa.py 
    Multiplying 6795820 numbers with n= 256
    SSA took 2.7558290379993196
    Multiplying 6795820 numbers with n= 512
    SSA took 0.7624775220001538
    Multiplying 6795820 numbers with n= 1024
    SSA took 0.6304274069998428
    Multiplying 6795820 numbers with n= 2048
    SSA took 0.5857483529998717
    Multiplying 6795820 numbers with n= 4096
    SSA took 0.709619131999716

    Note here that 2048**2 == 4194304 which is almost the closest power of 2 to the input size.
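    The sqrt(N) rule of thumb can be turned into a starting point for n. The helper below is a hypothetical heuristic, not GMP's tuning: take the power of 2 nearest to sqrt(N) on a log scale. For the 6795820-bit inputs timed above it picks 2048, the fastest of the values tried. Real code would still refine this with measured thresholds.

```python
import math


def ssa_split(nbits):
    """Hypothetical starting point: (n, bits per piece) with n ~ sqrt(nbits)."""
    if nbits < 2:
        return 1, nbits
    n = 1 << round(math.log2(nbits) / 2)   # nearest power of 2 on a log scale
    return n, -(-nbits // n)               # ceiling division


print(ssa_split(6795820))   # (2048, 3319)
print(ssa_split(10**6))     # (1024, 977)
```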

    You can see an example of the GMP thresholds here: https://gmplib.org/repo/gmp/file/tip/mpn/x86_64/skylake/gmp-mparam.h#l83

    pochmann commented 1 year ago

    @tim-one You could maybe save a little more time by converting small Decimals to int not directly but via str. For 1024-bit values, that seems about 2.5 times as fast.

    Test code, results, a note

    Code:
    ```
    from timeit import timeit
    from random import getrandbits
    import decimal

    nbits = 1024
    with decimal.localcontext() as ctx:
        ctx.prec = decimal.MAX_PREC
        ctx.Emax = decimal.MAX_EMAX
        ctx.Emin = decimal.MIN_EMIN
        for _ in range(3):
            digit = decimal.Decimal(getrandbits(nbits))
            direct = timeit(lambda: int(digit), number=10000)
            viastr = timeit(lambda: int(str(digit)), number=10000)
            print(f'{direct / viastr = }')
    ```
    Results:
    ```
    direct / viastr = 2.5690248907712365
    direct / viastr = 2.5946376697481535
    direct / viastr = 2.6202966323441648
    ```
    I noticed this shortly after posting my original version, but didn't bother with it because that version was only relevant for very large numbers, where this speedup of the linear-time part of the complexity was insignificant in the bigger picture. Maybe now, with your improvements, it would be worth doing.
    oscarbenjamin commented 1 year ago

    I don't understand why the emphasis here is on using the decimal module rather than making use of SSA (i.e. FFT multiplication) as shown above. Every other bigint library solves all of these problems by reducing every operation to multiplication and then focusing all optimisations on multiplication which means using SSA for large integers. If SSA was implemented then the other algorithms would all be trivial and integer multiplication would be faster (and then I don't think anyone would contemplate using decimal for any of this).

    tim-one commented 1 year ago

    I don't understand why the emphasis here is on using the decimal module

    Because decimal already implements NTT for multiplication of large values, and also division building on that with the same O() behavior. The code for that is already written in C, and is in a mature, widely used, and extensively tested external library (libmpdec) we don't maintain. It's here, right now, and has been for years. A competitive, production-quality SSA for CPython's bigints is still just a wish.

    I won't repeat yet again that ambition has forever been the fatal enemy of making real progress visible to CPython users here.

    tim-one commented 1 year ago

    @pochmann discovers that int(Decimal) takes about 2.5x longer than int(str(Decimal)) for 1024-bit ints

    Thanks - but sheesh 😉. The original entry in this report noted that int <-> decimal.Decimal conversions are also quadratic time now, but, alas, nothing we've done so far appears to have addressed that (although Neil's PR contains a fast int -> decimal.Decimal function, CPython's decimal implementation doesn't use it).

    Making the change does help a little, but, as you noted, tends to get lost in the noise on "really big" inputs. It cuts some tens of thousands of digits off the smallest value at which the str -> int here ties the one in Neil's PR - but that's still over 3 million digits on my desktop box.

    tim-one commented 1 year ago

    Notes on SSA and base conversion:

    For an example of the last, I got much better str_to_int-adapted-to-use-SSA behavior after adding this to the start of SSA():

        if x1.bit_length() <= 10000 or x2.bit_length() <= 10000:
            p = x1 * x2
            return shift(p, M, 0) if M is not None else p

    By pursuing small things like that by hand for each different digit-string-length tried, it was easy enough to contrive a str_to_int_using_ssa (Neil's PR's str_to_int() with the two bigint multiplies replaced by SSA() calls) that beats dec3(). However, on digit strings of length one million (the target Stefan initially aimed at), the current str_to_int() is still a little faster. At 2 million digits, SSA beats str_to_int() (while dec3() doesn't win until reaching about 3.4 million digits).

    So:

    oscarbenjamin commented 1 year ago
    • But SSA appears to be hurting in at least these two respects: (1) as Oscar showed, the precise power-of-2 splitting factor used can have dramatic effects on speed, the hard-coded 256 is rarely "best", and it's unclear how to make a good dynamic choice

    I think this problem can be solved in the same way as for GMP although perhaps done more simply. Just run a bunch of timings and choose thresholds like:

    if N < 500:
        n = 64
    elif N < 1000:
        n = 128
    # etc.

    That's how GMP used to work. The way GMP does it now is more complicated but that change in GMP is only credited with a smallish factor improvement. The paper describing that and other improvements is here: https://web.archive.org/web/20110720104333/http://www.loria.fr/~gaudry/publis/issac07.pdf
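    A threshold chain like the one above can also be kept as a table, so that retuning only edits data. Here is a minimal sketch with `bisect`. The boundaries and split values are placeholders, not measured numbers.

```python
from bisect import bisect_right

# Placeholder tuning data: N < 500 -> 64, N < 1000 -> 128, ... (not measured)
THRESHOLDS = [500, 1000, 4000, 16000]
SPLITS = [64, 128, 256, 512, 1024]


def choose_split(nbits):
    """Pick the SSA split count n for inputs of `nbits` bits from the table."""
    return SPLITS[bisect_right(THRESHOLDS, nbits)]


print(choose_split(100), choose_split(999), choose_split(10**6))  # 64 128 1024
```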

    The appropriate comparison when considering whether it is worth using something like the SSA implementation above is not really with GMP but with the existing CPython implementation. In other words, the bar to meet is that the code should be correct, its speed should be no worse in any case, and it should be significantly better in at least some cases (to justify making any change at all).

    • For str->int the inputs it requires to beat the current PR's approach are so large that it would be astonishing if "DoS vulnerability" types cared one whit.

    Fair enough. Well, if Neil's PR currently gives the fastest str->int, then let's merge it. That would also make it easier to move other operations to pure Python later (e.g. for SSA).

    vstinner commented 1 year ago

    I like the Lib/_pylong.py feature: it was long awaited. Cool.

    I'm just curious about _pylong._DEBUG private attribute: is it really useful to ship it in a released Python version? Or is it only useful to people who designed and hacked Lib/_pylong.py? Can it be removed now?

    nascheme commented 1 year ago

    Can [the _DEBUG flag] be removed now?

    Yeah, I think it could be removed now since it is not very useful.

    vstinner commented 1 year ago

    Yeah, I think it could be removed now since it is not very useful.

    Ok, I created PR #99063 to remove _pylong._DEBUG flag.

    byeongkeunahn commented 1 year ago

    Hello, I have written a fast O(n lg n) integer multiplication routine in Rust, using the number-theoretic transform (NTT) with 64-bit primes. Code is available here.
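    For readers unfamiliar with the technique, here is a toy NTT multiplication in Python. It is only a sketch. It uses a single 30-bit prime (998244353 = 119 * 2**23 + 1, with primitive root 3) and base-256 digits, rather than the 64-bit primes and larger radices described above. As a result it is exact only while every convolution coefficient stays below the prime, which an assert checks.

```python
MOD = 998244353  # 119 * 2**23 + 1, so transform lengths up to 2**23 work
ROOT = 3         # a primitive root modulo MOD


def ntt(a, invert=False):
    """In-place iterative radix-2 NTT over integers mod MOD; len(a) is a power of 2."""
    n = len(a)
    j = 0
    for i in range(1, n):            # bit-reversal permutation
        bit = n >> 1
        while j & bit:
            j ^= bit
            bit >>= 1
        j |= bit
        if i < j:
            a[i], a[j] = a[j], a[i]
    length = 2
    while length <= n:               # butterfly passes
        w_len = pow(ROOT, (MOD - 1) // length, MOD)
        if invert:
            w_len = pow(w_len, MOD - 2, MOD)
        half = length // 2
        for start in range(0, n, length):
            w = 1
            for k in range(start, start + half):
                u, v = a[k], a[k + half] * w % MOD
                a[k] = (u + v) % MOD
                a[k + half] = (u - v) % MOD
                w = w * w_len % MOD
        length *= 2
    if invert:
        n_inv = pow(n, MOD - 2, MOD)
        for i in range(n):
            a[i] = a[i] * n_inv % MOD


def ntt_mul(x, y):
    """Multiply non-negative ints by convolving their base-256 digits."""
    assert x >= 0 and y >= 0
    if x == 0 or y == 0:
        return 0
    a = list(x.to_bytes((x.bit_length() + 7) // 8, 'little'))
    b = list(y.to_bytes((y.bit_length() + 7) // 8, 'little'))
    # Each coefficient is a sum of at most min(len) products below 256**2.
    assert min(len(a), len(b)) * 255 * 255 < MOD, "inputs too large for one prime"
    size = len(a) + len(b) - 1
    n = 1
    while n < size:
        n *= 2
    assert n <= 1 << 23
    a += [0] * (n - len(a))
    b += [0] * (n - len(b))
    ntt(a)
    ntt(b)
    c = [u * v % MOD for u, v in zip(a, b)]
    ntt(c, invert=True)
    out = bytearray()               # propagate carries, one base-256 digit at a time
    carry = 0
    for coeff in c[:size]:
        carry += coeff
        out.append(carry & 0xFF)
        carry >>= 8
    while carry:
        out.append(carry & 0xFF)
        carry >>= 8
    return int.from_bytes(out, 'little')
```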

    I ran some benchmarks, and surprisingly, it seems to outperform GMP 6.2.1 on my computer (Ryzen 7 2700X). Explanations of the implementation details can be found in the pull request I have opened in the rust-num/num-bigint repository, where I have also attached a graph comparing the speed of this new code with GMP.

    Importantly, the code does not use any inline assembly or architecture-specific intrinsics. Although some integer operations unavailable in portable C have been used, it should be possible to replace them with a few if-else statements at the cost of a performance hit of at most 2-4x (or much less, depending on compiler optimization), which is suboptimal but still much better than the current CPython implementation.
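    One example of such an operation is the high half of a product. Rust's u128 (or a compiler's 128-bit type in C) gives the high 64 bits of a 64 x 64-bit product directly, while portable C must assemble them from four 32 x 32 -> 64-bit products. The sketch below mirrors that decomposition in Python, where it can be checked against exact arithmetic. It shows one common approach, not the code from the linked implementation.

```python
MASK32 = (1 << 32) - 1


def mul_64x64(a, b):
    """Return (hi, lo) 64-bit halves of a*b using only 32x32 -> 64-bit products."""
    a_lo, a_hi = a & MASK32, a >> 32
    b_lo, b_hi = b & MASK32, b >> 32
    ll = a_lo * b_lo
    hl = a_hi * b_lo
    lh = a_lo * b_hi
    hh = a_hi * b_hi
    # Middle column: at most 2**64 - 1 even in the worst case, so no overflow.
    cross = (ll >> 32) + (hl & MASK32) + lh
    hi = hh + (hl >> 32) + (cross >> 32)
    lo = ((cross & MASK32) << 32) | (ll & MASK32)
    return hi, lo


hi, lo = mul_64x64(2**64 - 1, 2**64 - 1)
assert (hi << 64) | lo == (2**64 - 1) ** 2
```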

    Note that the code correctly handles unbalanced multiplication according to the strategy explained here. Also, the padding is minimized (at most 7%) by using non-power-of-two radices.

    Although CPython currently uses 30-bit digits internally, these can be repacked into 64-bit words, run through the new code, and transformed back to 30-bit digits. Since this transformation takes O(n) time while the multiplication takes O(n lg n), the repacking overhead should be minimal.
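    At the Python level the internal digits aren't visible (sys.int_info.bits_per_digit reports their size). The same kind of linear-time repacking into 64-bit words can still be sketched by going through a little-endian byte string. The helper names below are hypothetical.

```python
import sys


def to_words(x, word_bytes=8):
    """Non-negative int -> list of little-endian words (64-bit by default)."""
    nwords = max(1, -(-x.bit_length() // (8 * word_bytes)))
    raw = x.to_bytes(nwords * word_bytes, 'little')
    return [int.from_bytes(raw[i:i + word_bytes], 'little')
            for i in range(0, len(raw), word_bytes)]


def from_words(words, word_bytes=8):
    """Inverse of to_words."""
    raw = b''.join(w.to_bytes(word_bytes, 'little') for w in words)
    return int.from_bytes(raw, 'little')


print(sys.int_info.bits_per_digit)   # size of CPython's internal digits on this build
x = 3 ** 1000
assert from_words(to_words(x)) == x
```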

    There should be no license issue, since I took care to use only MIT-licensed code. I'd like to port the code to C if there is enough interest in adopting my implementation in CPython.

    Please let me know. Thanks.

    oscarbenjamin commented 1 year ago

    I have written a fast O(n lg n) integer multiplication code in Rust

    Fantastic!

    I'd like to port the code to C if there is enough interest for adopting my implementation in CPython. Although some integer operations unavailable in portable C have been used, it should be possible to replace them with a few if-else statements at the expense of performance hit at most 2-4x

    I am not a CPython core developer myself, but I would personally be very interested to see CPython have something like this. As you rightly imply, the balance for CPython is best served by a simple, portable implementation that does not need state-of-the-art performance but does have good asymptotic complexity to avoid "surprising slowness" (essentially the root of the security concerns that precipitated this thread).

    I expect that the Python core development team would prefer for this to be something maintained outside of CPython e.g. like a library that CPython could depend on. That would also benefit many other projects as well because I think that there is a wide need for at least reasonable large integer performance that is more portable and has a more liberal license than GMP.

    Would you consider maintaining this as an independent C library?

    I ran some benchmarks, and surprisingly, it seems to outperform GMP 6.2.1 on my computer (Ryzen 7 2700X). Importantly, the code does not use any inline assemblies or architecture-specific intrinsics.

    I wonder if this is related to compiling directly for your local CPU rather than making a generic binary that is redistributable. One of the most important features of GMP is that it can bundle many different hand-written assembly versions for different CPU architectures in a "fat" build (--enable-fat) that selects the fastest version at runtime. In the context of Python packaging, this makes it possible to ship, say, a single x86-64 Windows wheel for a Python package that still uses CPU-specific routines at runtime.

    (Probably this last point about fat builds is not relevant for CPython.)

    gpshead commented 1 year ago

    This issue is about non-binary base conversions more so than multiplication. Regardless, it is ideal for that kind of thing to be a Rust and/or C library of its own. If we wanted to use it for anything, we could depend upon that or vendor our own copy.

    oscarbenjamin commented 1 year ago

    This issue is about non-binary base conversions more so than multiplication.

    The way that large integer algorithms usually work is that pretty much all operations that are not trivially O(n) (meaning n bits) are reduced to multiplication, and then multiplication gets heavily optimised. The benefits of that approach are obviously limited by the quality of the underlying multiplication algorithm, though, and CPython's multiplication algorithm scales poorly to very large integers. The suggestion from @byeongkeunahn here is to improve the multiplication algorithm, which then improves everything else (division, divmod, base conversion, ...).

    byeongkeunahn commented 1 year ago

    Yeah, it would be of course fine to wrap the code into an external library.

    I wonder if this is related to compiling directly for your local CPU rather than making a generic binary that is redistributable.

    I have tested the same executable on two more platforms (i5-9400 and Ryzen 5 7500F). The relative speed advantage stayed the same for the multiplication of very large integers, although below a few million bits GMP was 30-40% faster on those platforms. I'm not sure whether gmpy2 used architecture-specific routines; I just copied the Python benchmark code above.

    byeongkeunahn commented 11 months ago

    It seems the num-bigint crate binds to Python 3 quite well via PyO3. The following image shows the performance of the new multiplication implementation on a Ryzen 7 2700X, averaged over 10 runs. For small numbers, the num-bigint crate uses naive, Karatsuba, and Toom-3 multiplication before switching to the number-theoretic transform. The crossover point relative to CPython's native multiplication (version 3.11.3) is around 3,000 bits.

    [Image: benchmark-PyO3]

    The binding I used is very simple. Further optimizations may be possible.

    use pyo3::prelude::*;
    use num_bigint::BigInt;
    
    #[pyfunction]
    fn mul(x: BigInt, y: BigInt) -> BigInt {
        &x * &y
    }
    
    #[pymodule]
    fn num_bigint_pyo3(_py: Python, m: &PyModule) -> PyResult<()> {
        m.add_function(wrap_pyfunction!(mul, m)?)?;
        Ok(())
    }

    import num_bigint_pyo3
    a, b = 10**100000, 9**100000
    x = num_bigint_pyo3.mul(a, b)
    assert x == a*b
    byeongkeunahn commented 11 months ago

    Here are the str -> int conversion timings with the new multiplication routine.

    Code:
    ```python
    import sys
    sys.set_int_max_str_digits(0)

    import num_bigint_pyo3
    from decimal import *
    from time import time
    from random import choices
    import numpy as np

    setcontext(Context(prec=MAX_PREC, Emax=MAX_EMAX, Emin=MIN_EMIN))

    pow5 = [5]
    while len(pow5) <= 23:
        pow5.append(num_bigint_pyo3.mul(pow5[-1], pow5[-1]))

    def str_to_int_new_mul(s):
        def _str_to_int(l, r):
            if r - l <= 3000:
                return int(s[l:r])
            lg_split = (r - l - 1).bit_length() - 1
            split = 1 << lg_split
            return (num_bigint_pyo3.mul(_str_to_int(l, r - split), pow5[lg_split]) << split) + _str_to_int(r - split, r)
        return _str_to_int(0, len(s))

    def str_to_int(s):
        def _str_to_int(l, r):
            if r - l <= 3000:
                return int(s[l:r])
            lg_split = (r - l - 1).bit_length() - 1
            split = 1 << lg_split
            return ((_str_to_int(l, r - split) * pow5[lg_split]) << split) + _str_to_int(r - split, r)
        return _str_to_int(0, len(s))

    def str_to_int_using_decimal(s, bits=1024):
        d = Decimal(s)
        div = Decimal(2) ** bits
        divs = []
        while div <= d:
            divs.append(div)
            div *= div
        digits = [d]
        for div in reversed(divs):
            digits = [x for digit in digits for x in divmod(digit, div)]
            if not digits[0]:
                del digits[0]
        b = b''.join(int(digit).to_bytes(bits//8, 'big') for digit in digits)
        return int.from_bytes(b, 'big')

    # Benchmark against int() on a random 1-million-digit string
    funcs = [int, str_to_int_using_decimal, str_to_int, str_to_int_new_mul]
    s = ''.join(choices('123456789', k=1_000_000))
    expect = None
    for f in funcs:
        t_list = []
        for _ in range(10):
            t = time()
            result = f(s)
            t_elapsed = time() - t
            if expect is None:
                expect = result
            assert result == expect
            t_list.append(t_elapsed)
        mean, std = np.mean(t_list), np.std(t_list)
        print("{0}: {1:.9f} ± {2:.9f} sec".format(f.__name__.ljust(30), mean, std))
    ```
    oscarbenjamin commented 11 months ago

    Here are the str -> int conversion timings with the new multiplication routine.

    It looks very nice to me. The question is what you would propose to do going on from here. If the suggestion is that CPython might depend on this (even in an optional way) then there are many details that would need to be worked out before that could be considered.

    byeongkeunahn commented 11 months ago

    It looks very nice to me. The question is what you would propose to do going on from here. If the suggestion is that CPython might depend on this (even in an optional way) then there are many details that would need to be worked out before that could be considered.

    Yes, I'd love to see CPython's multiplication improved asymptotically. It would also be great to improve other arithmetic operations too, but as a first step, I think it would be appropriate to limit the scope to multiplication.

    Currently the binding through PyO3 receives the raw bits of the integer through _PyLong_AsByteArray and sends them back to Python via _PyLong_FromByteArray. I think that mechanism can still work in C code, which might help lessen the maintenance burden, since the existing APIs wouldn't need substantial modification.

    On the other hand, more deliberation is needed to port the primitive operations from Rust to C. Rust provides portable primitives such as add-with-carry and 128-bit integers. Although most C compilers offer a way to get the high 64 bits of a 64 x 64-bit multiplication, this is not portable. Switching to plain portable C (without compiler intrinsics) is possible, but it comes with performance degradation.
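    Add-with-carry is one of the primitives that would need rewriting in portable C. The usual trick is to detect unsigned wraparound by comparison. Here is a Python sketch of that approach, with masking to emulate 64-bit words and a multi-word addition built on it. The names are hypothetical.

```python
MASK64 = (1 << 64) - 1


def add_with_carry(a, b, carry_in):
    """64-bit add with carry, detecting wraparound by comparison as portable C would."""
    s = (a + b) & MASK64
    c1 = s < a                      # wrapped while adding b
    s2 = (s + carry_in) & MASK64
    c2 = s2 < s                     # wrapped while adding the incoming carry
    return s2, int(c1 or c2)


def add_words(xs, ys):
    """Add two equal-length little-endian lists of 64-bit words."""
    out = []
    carry = 0
    for a, b in zip(xs, ys):
        s, carry = add_with_carry(a, b, carry)
        out.append(s)
    out.append(carry)
    return out


print(add_words([MASK64, MASK64], [1, 0]))   # [0, 0, 1]
```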

    For memory allocation, one call of malloc (and a corresponding call of free) would suffice. The temporary buffer has a size linear in the input integer and doesn't need to be a PyObject instance.

    Please let me know of other details that need to be worked out. Thanks!

    serhiy-storchaka commented 1 month ago

    Is there anything left to do in this issue, or can it be closed?

    tim-one commented 1 month ago

    It's hard to tell because of the length and complexity of the discussion, but decimal.Decimal <-> int conversions remain quadratic-time. It's only str <-> int conversions that were effectively addressed.