python / cpython

The Python programming language
https://www.python.org

Avoid overflow/underflow in math.prod() #85630

Open rhettinger opened 4 years ago

rhettinger commented 4 years ago
BPO 41458
Nosy @tim-one, @rhettinger, @mdickinson, @vedgar, @pablogsal, @websurfer5
Files
  • sum_and_prod.py: Low memory variant with fast path
  • Note: these values reflect the state of the issue at the time it was migrated and might not reflect the current state.



    rhettinger commented 4 years ago

    For float inputs, the math.prod() function could be made smarter about avoiding overflow and underflow. That would also improve commutativity.

    Other tools like math.hypot() already take measures to avoid overflow/underflow and to improve commutativity. This makes them nicer to use than naïve implementations.

    The recipe that I've been using for years is shown below. It certainly isn't the best way, but it is O(n) and it avoids overflow and underflow whenever possible. It has made for a nice primitive when implementing other functions that need to be as robust as possible. For example, in the quadratic formula, √(b²-4ac) factors into |b|·√(1-4ac/b²), and the rightmost term gets implemented in Python as product([4.0, a, c, 1.0/b, 1.0/b]).

    >>> from math import prod, fabs
    >>> from collections import deque
    >>> from itertools import permutations

    >>> def product(seq):
        s = deque()
        for x in seq:
            s.appendleft(x) if fabs(x) < 1.0 else s.append(x)
        while len(s) > 1:
            x = s.popleft() * s.pop()
            s.appendleft(x) if fabs(x) < 1.0 else s.append(x)
        return s[0] if s else 1.0
    
    >>> data = [2e300, 2e200, 0.5e-150, 0.5e-175]
    >>> for factors in permutations(data):
        print(product(factors), prod(factors), sep='\t')

    1e+175                      inf
    1.0000000000000001e+175     inf
    1e+175                      inf
    1e+175                      1e+175
    1.0000000000000001e+175     inf
    1.0000000000000001e+175     1.0000000000000001e+175
    1.0000000000000001e+175     inf
    1e+175                      inf
    1.0000000000000001e+175     inf
    1.0000000000000001e+175     1.0000000000000001e+175
    1e+175                      inf
    1e+175                      1e+175
    1e+175                      inf
    1e+175                      1e+175
    1.0000000000000001e+175     inf
    1.0000000000000001e+175     1.0000000000000001e+175
    1e+175                      0.0
    1.0000000000000001e+175     0.0
    1.0000000000000001e+175     inf
    1.0000000000000001e+175     1.0000000000000001e+175
    1e+175                      inf
    1e+175                      1e+175
    1.0000000000000001e+175     0.0
    1e+175                      0.0

    For math.prod(), I think a better implementation would be to run normally until an underflow or overflow is detected, then back up a step and switch over to pairing low-magnitude and high-magnitude values. Since this is fallback code, it would only affect speed to the extent that we test for overflow or underflow at every step. Given how branch prediction works, the extra tests might even be free or at least very cheap.

    The math module functions usually go the extra mile to be more robust (and often faster) than a naïve implementation. These are the primary reasons we teach people to prefer sqrt(x) over x**0.5, log1p(x) over log(1+x), prod(seq) over reduce(mul, seq, 1.0), log2(x) over log(x, 2), fsum(seq) over sum(seq), and hypot(x, y) over sqrt(x**2 + y**2). In each case, the library function is some combination of more robust, more accurate, more commutative, and/or faster than what a user can easily create for themselves.

    mdickinson commented 4 years ago

    This message from Tim, starting "I'd like to divorce prod() from floating-point complications", seems relevant here: https://bugs.python.org/issue35606#msg333090

    fe5a23f9-4d47-49f8-9fb5-d6fbad5d9e38 commented 4 years ago

    Yes, fprod would be nice [though if you just want to avoid over/underflows, a much easier solution is to first determine the sign, then sum the logarithms of the absolute values, and exponentiate that]. But I agree with Tim that it should be a separate function, for the same reason that sum is not fsum. (The reason prod is in math is bureaucratic, not ontological.)
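
    As a rough sketch of that log-based idea (my own illustration, not code from this thread; it needs special-casing for zeros, infinities, and NaNs, and its accuracy degrades as the summed exponents grow):

        from math import exp, fsum, log

        def logprod(seq):
            # Product via summed logarithms: the running value can never
            # overflow or underflow, though the final exp() still raises
            # OverflowError (or underflows toward 0.0) when the true result
            # is genuinely out of range.  log(0.0) raises ValueError here.
            sign = 1.0
            logs = []
            for x in seq:
                if x < 0.0:
                    sign = -sign
                logs.append(log(abs(x)))
            return sign * exp(fsum(logs))

    The relative error grows with the magnitude of the summed logarithms, so this is noticeably less accurate than multiplying directly.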

    rhettinger commented 4 years ago

    The existing math.prod() already has a separate code path for floats. The proposal is to add an overflow/underflow check to that existing path so that we don't get nonsense like 0.0, 1e+175, or Inf depending on the data ordering. That doesn't warrant a separate function.

    FWIW, fsum() is a separate function for several reasons, none of which apply to the current proposal: 1) we didn't already have a math.sum(). 2) All input types get converted to float. 3) Even in common cases, it is measurably slower than sum(). 4) It has a different signature than sum().

    pablogsal commented 4 years ago

    I think what Raymond proposes makes sense and it will certainly add value, especially given the mentioned expectations of what an implementation in the stdlib should have. The only thing I would like to know is how much code and measured performance impact this will have. I expect this not to be a problem, but I think it is an important factor to help us decide.

    rhettinger commented 4 years ago

    Here's a variant (minimally tested) that has a fast path for the common case (no overflow or underflow), uses a list instead of a deque, and keeps memory use small (only storing pending values that can't be combined without triggering an overflow or underflow).

    ----------------------------

    from math import fabs, prod, isfinite
    
    def product(seq, start=1.0):
        total = start
        s = []                 # values that would overflow the total
        s_side = False         # true if s_i increase the magnitude of the product
        for x in seq:
            old_total = total
            total *= x
            underflow = not total and old_total and not x
            if isfinite(total) and not underflow:
                continue       # fast-path for multiplies that doesn't overflow/underflow
            total = old_total
    
            side = fabs(x) > 1.0
            if not s or side == s_side:
                s.append(x)
                s_side = side
                continue
    
            t = [x, total]
            while s and t:      # opposite sides:  matter and antimatter
                x = s.pop() * t.pop()
                side = fabs(x) > 1.0
                s.append(x) if side == s_side else t.append(y)
            if t:
                s = t
                s_side = not s_side
            total = s.pop()
        return prod(s, start=total)
    rhettinger commented 4 years ago

    Two edits:

    - s.append(x) if side == s_side else t.append(y)
    + s.append(x) if side == s_side else t.append(x)

    - return prod(s, start=total)
    +    for x in s:
    +        total *= x
    +    return total
    fe5a23f9-4d47-49f8-9fb5-d6fbad5d9e38 commented 4 years ago

    s.append(x) if side == s_side else t.append(x)

    To me, (s if side == s_side else t).append(x) seems much better. Not only is more factored out, but .append is treated as a procedure (returning None, mutating its object), and the if-expression is used for its value rather than for its side effect.

    mdickinson commented 4 years ago

    If we want to do this (and I'm still not convinced that we do), I think there's a simpler way: use frexp to decompose each float into a fraction and an exponent, multiply the fractions (which barring zeros will all be in [0.5, 1.0)), and keep track of the accumulated exponents separately. Then combine everything at the end.

    There's a possibility of the accumulated product of the fractions underflowing, but only after every 1000 floats or so, so it's enough to check every 500 floats (say) whether the product is smaller than 2**-500, and scale by 2**1000 (adjusting the exponent correspondingly) if it is.

    mdickinson commented 4 years ago

    Here's code to illustrate the idea. It doesn't yet handle zeros, infinities or nans; that support would need to be added.

    import math
    
    def fprod(numbers):
        # Product of numbers, avoiding intermediate underflow and overflow.
        # Does not handle zeros, infinities or nans
        # Running product is acc_m * 2**acc_e
        acc_m, acc_e = float.fromhex("1p1000"), -1000
        count = 0
        for number in numbers:
            m, e = math.frexp(number)
            acc_m *= m
            acc_e += e
            if count == 1000:
                if acc_m < 1.0:
                    acc_m = math.ldexp(acc_m, 1000)
                    acc_e -= 1000
                count = 0
    
        return math.ldexp(acc_m, acc_e)
    mdickinson commented 4 years ago

    Whoops. There's a missing count += 1 in there, of course.
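
    With that increment added, the loop would presumably read (my reconstruction of the obvious fix, not code posted in the thread):

        for number in numbers:
            m, e = math.frexp(number)
            acc_m *= m
            acc_e += e
            count += 1
            if count == 1000:
                if acc_m < 1.0:
                    acc_m = math.ldexp(acc_m, 1000)
                    acc_e -= 1000
                count = 0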

    tim-one commented 4 years ago

    I'm skeptical of the need for - and wisdom of - this. Where does it come up? I can't think of any context where this would have been useful, or of any other language or package that does something like this. Long chains of mults are unusual outside of integer combinatorics to begin with.

    Closest I can recall came up in the Spambayes project. There we needed to compute the sum of the logs of a potentially large number of probabilities. But log (lack of!) speed proved to be a bottleneck, so I changed it to compute the log of the product of the probabilities. That product was all but certain to underflow to 0.0. But "for real" - no rearrangement of terms would have made any difference to that. Instead, akin to what Mark spelled out, every now & again frexp() was used to push the product bits back into [0.5, 1.0), and the power-of-2 exponent was tracked apart from that.

    fsum() is primarily aimed at a very different problem: improving the low-order bits.

    How to explain what this change to prod() would do? It may or may not stop spurious overflows or underflows, and the result depends on the order of the multiplicands. But in what way? If we can't/won't define what it does, how can other implementations of Python reproduce CPython's result?

    While I don't (yet?) see a real need for it, one thing that could work: compute the product left to right. Full speed - no checks at all. If the result is a 0, a NaN, or an infinity (which is bound to be rare in real life), do it again left to right with the frexp() approach. Then it can be explained: the result is what you'd get from left-to-right multiplication if the hardware had an unbounded exponent field, suffering overflow/underflow at the end only if the true exponent has magnitude too large for the hardware.
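
    A minimal sketch of that two-pass idea (my illustration of the description above, not an official patch; zeros, infinities, and NaNs in the input would still need explicit handling):

        from math import frexp, isfinite, ldexp, prod

        def prod_two_pass(xs):
            # Fast path: plain left-to-right multiplication, no checks at all.
            xs = list(xs)
            total = prod(xs)
            if total and isfinite(total):
                return total
            # Rare path: redo it left to right, carrying the power-of-2 exponent
            # separately so the running mantissa product cannot spuriously
            # overflow or underflow.
            acc_m, acc_e = 1.0, 0
            for x in xs:
                m, e = frexp(x)
                acc_m *= m
                acc_e += e
                if 0.0 < abs(acc_m) < 0.5:      # renormalize back into [0.5, 1.0)
                    acc_m, adj = frexp(acc_m)
                    acc_e += adj
            # ldexp() overflows or underflows only if the true result is out of range.
            return ldexp(acc_m, acc_e)

    Mark's version above rescales only once every 1000 elements rather than renormalizing at each step; this sketch trades that optimization for simplicity.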

    rhettinger commented 4 years ago

    [Uncle Timmy]

    I'm skeptical of the need for - and wisdom of - this.

    Granted, the need is not common. On the other hand, if we can do it cheaply, why wouldn't we? Unnecessary overflow or underflow is never a desirable outcome. Currently, users do not have an easy and fast way to avoid that outcome.

    This comes up a lot when I use or teach how to apply Hypothesis to test floating point code. That tool makes you acutely aware of how much your functions have to be constrained to assure a useful output. And if you aren't satisfied with those constraints, it can be hard to fix unless you have robust primitives. Presumably that is why numpy uses pairwise summation instead of straight addition for example.

    To me, this is in the same category as fixing the overflow error in bisect or using scaling in hypot() to avoid overflow or underflow. The core idea is to make the primitives as robust as possible if it can be done with only a minor performance impact.

    rhettinger commented 4 years ago

    FWIW, the occasions where this mattered all involved a mix of multiplications and divisions that mostly cancel out.

    The quadratic formula example is typical: product([4.0, a, c, 1.0/b, 1.0/b]).

    Or a floating point implementation of comb(): product([1000, 999, 998, 997, 1/4, 1/3, 1/2, 1/1])

    Or terms in series expansions where both the numerator and denominator have many factors.
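
    To make the ordering sensitivity concrete, here is a small comb()-style illustration (my own example, not code from this thread): the same factor list either overflows or stays comfortably in range depending only on how the factors are arranged.

        from math import prod

        def comb_factors(n, k):
            # Factors of a floating-point comb(n, k):  n/1 * (n-1)/2 * ... * (n-k+1)/k,
            # with each numerator followed by the reciprocal of its denominator.
            factors = []
            for i in range(1, k + 1):
                factors.append(float(n - i + 1))
                factors.append(1.0 / i)
            return factors

        factors = comb_factors(1000, 500)
        big = [f for f in factors if f >= 1.0]     # the 500 large numerators
        small = [f for f in factors if f < 1.0]    # the 500 reciprocal denominators

        print(prod(big + small))    # inf: the numerators alone overflow
        print(prod(factors))        # ~2.7e299: the interleaved ordering stays in range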

    rhettinger commented 4 years ago

    I attached a file with the latest (and more tested) recipe. The underflow test was fixed: "not total and old_total and x". Also, the final call to math.prod() was in-lined so that I could time it with PyPy.

    tim-one commented 4 years ago

    See "wisdom" earlier ;-) It's ad hoc trickery that seemingly can't be explained without showing the precise implementation in use today. As already mentioned, frexp() trickery _can_ be explained: exactly what you'd get if left-to-right HW multiplication were performed with an unbounded exponent, over- and under-flowing if and only if the infinite precision exponent at the end "doesn't fit". It completely eliminates spurious over-/under-flow. If someone wants that, it would be suitable for an fprod() function (which, like fsum(), makes strong guarantees at real cost).

    Precisely which cases does this other thing protect against? Is there any way to characterize them beyond "well, try it and see"?

    If there's an actual problem here, doesn't it deserve to be fixed? Just making it "less common" in some unquantifiable, unexplainable way seems unprincipled, and _possibly_ even harmful (e.g., by making it harder to provoke the underlying still-there problem without white box adversarial testing).

    Presumably that is why numpy uses pairwise summation instead of straight addition for example.

    Loss of precision in fp summation is an extremely common real-life problem with real-life consequences. That's why you can find dozens of papers, over the decades, on schemes to improve that. Adding a vector of a million floats is common; multiplying a vector of even a hundred floats is rare.

    Similarly, scaled hypot implementations have been re-invented dozens of times. Etc. When there's "a real" problem in fp life, it's hard _not_ to find mounds of prior art trying to address it.

    You can find a few pieces on avoiding spurious under-/over-flow when chaining multiplies and divides, primarily older papers talking about 32-bit float formats with very limited dynamic range. The common advice: rearrange the handful of inputs in the source code by hand, based on your application knowledge of their likely magnitudes, to make left-to-right evaluation "safe".

    tim-one commented 4 years ago

    I may well have misread the code, believing it can still allow spurious over/underflows. On second reading of the current file, I don't know - it's more complicated than I thought.

    If it does guarantee to prevent them, then I shift from -1 to (probably) -0.

    rhettinger commented 4 years ago

    The algorithm stops all spurious overflows and underflows. If favorable cancellations exist, it finds them.

    Algorithm in Words
    ------------------

    For every x in the sequence, multiply onto the total if possible.

    If x and "total" can't be combined without overflow/underflow, then x is given a "side" depending on |x| > 1.0. This indicates whether multiplying by x would increase the magnitude or decrease it.

    Note that "total" has the same side as "x". If one would increased magnitude and the other would decreased it, then our "total *= x" would have succeeded.

    The list "s" has the other pending multiplies. Each of them are on the same side — either they all increase the magnitude or they all decrease it. The direction is stored in the "s_side" variable.

    If "x" is on the same side as everything else in "s", we just append it. No cancels are possible.

    If "x" and "t" are on the opposite side of the elements in "s", then we multiply the big/little pairs to get favorable cancellations. Any t_i will successfully combine with any s_i.

    At the end of the loop, "total" and "s" are both on the same side and no further favorable cancellations are possible.

    rhettinger commented 4 years ago

    Variable name correction:

    Extreme points will successfully combine:
    >>> from sys import float_info
    >>> float_info.max * float_info.min
    3.9999999999999996
    tim-one commented 4 years ago

    Cool! So it looks like you could also address an accuracy (not out-of-range) issue, one the frexp() method already handles as well as possible: loosen the definition of "underflow" to include losing bits to subnormal products. For example, with the inputs

    >>> xs = [1e-200, 1e-121, 1e308]
    >>> product(xs)
    9.98012604599318e-14
    >>> _.hex()
    '0x1.c17704f58189cp-44'
    >>> math.prod(xs) # same thing
    9.98012604599318e-14
    >>> prod2(xs) # a tiny implementation of the frexp() method
    1e-13
    >>> _.hex()
    '0x1.c25c268497682p-44'

    product/math.prod get only a handful of the leading result bits right, because

    >>> 1e-200 * 1e-121
    1e-321
    >>> _.hex()
    '0x0.00000000000cap-1022'

    loses all but a handful of the trailing bits due to being in the subnormal range. Changing the order can repair that:

    >>> 1e-200 * 1e308 * 1e-121
    1e-13

    IOW, so far as "underflow" goes, we _start_ losing bits when the product becomes subnormal, well before it reaches 0.

    rhettinger commented 4 years ago

    I had not put any thought into subnormals (denormals?) because they didn't arise in my use cases. But it would be nice to handle them as well as possible.

    Accuracy improvements are welcome as well. And that goes hand in hand with better commutativity.

    The frexp() approach looks cleaner than my big/little matching algorithm and it doesn't require auxiliary memory. It may be slower, though. The matching algorithm only kicks in when overflows or underflows occur; otherwise, it runs close to full speed for the common case.

    tim-one commented 4 years ago

    "Denormal" and "subnormal" mean the same thing. The former is probably still in more common use, but all the relevant standards moved to "subnormal" some years ago.

    Long chains of floating mults can lose precision too, but hardly anyone bothers to do anything about that. Unlike sums, they're just not common or in critical cores. For example, nobody ever computes the determinant of a million-by-million triangular matrix by producting ;-) the diagonal.

    I only recall one paper devoted to this, using "doubled precision" tricks like fsum employs (but much more expensive unless hardware fused mul-add is available):

    "Accurate Floating Point Product and Exponentiation" Stef Graillat

    However, that does nothing to prevent spurious overflow or underflow.

    Much simpler: presumably pairwise products can enjoy lower accumulated error for essentially the same reasons pairwise summation "works well". Yet, far as I know, nobody bothers.

    Here's a cute program:

    if 1:
        from decimal import Decimal
        from random import random, shuffle, seed
        import math
    
        def pairwise(xs, lo, hi):
            n = hi - lo
            if n == 1:
                return xs[lo]
            elif n == 2:
                return xs[lo] * xs[lo + 1]
            else:
                mid = (lo + hi) // 2
                return pairwise(xs, lo, mid) * pairwise(xs, mid, hi)
    
        N = 4000
        print(f"using {N=:,}")
        for s in (153,
                  53,
                  314,
                  271828,
                  789):
            print("\nwith aeed", s)
            seed(s)
            xs = [random() * 10.0 for i in range(N)]
            xs.extend(1.0 / x for x in xs[:])
            shuffle(xs)
            print("prod   ", math.prod(xs))
            print("product", product(xs)) # the code attached to this report
            print("frexp  ", prod2(xs))   # straightforward frexp
            print("Decimal", float(math.prod(map(Decimal, xs))))
            print("pair   ", pairwise(xs, 0, len(xs)))

    By construction, the product of xs should be close to 1.0. With N=4000 as given, out-of-range doesn't happen, and all results are pretty much the same:

    using N=4,000

    with seed 153
    prod    0.9999999999999991
    product 0.9999999999999991
    frexp   0.9999999999999991
    Decimal 1.0000000000000042
    pair    1.0000000000000016

    with seed 53
    prod    1.0000000000000056
    product 1.0000000000000056
    frexp   1.0000000000000056
    Decimal 1.0000000000000002
    pair    0.9999999999999997

    with seed 314
    prod    1.0000000000000067
    product 1.0000000000000067
    frexp   1.0000000000000067
    Decimal 1.0000000000000082
    pair    1.0000000000000002

    with seed 271828
    prod    0.9999999999999984
    product 0.9999999999999984
    frexp   0.9999999999999984
    Decimal 1.0000000000000004
    pair    1.0000000000000064

    with seed 789
    prod    0.9999999999999994
    product 0.9999999999999994
    frexp   0.9999999999999994
    Decimal 1.0
    pair    1.0000000000000069

    But boost N so that out-of-range is common, and only frexp and Decimal remain reliable. It looks almost certain that product() has serious bugs:

    using N=400,000

    with seed 153
    prod    0.0
    product 1980.1146715391837
    frexp   0.999999999999969
    Decimal 1.000000000000027
    pair    nan

    with seed 53
    prod    0.0
    product 6.595056534948324e+24
    frexp   1.0000000000000484
    Decimal 1.0000000000000513
    pair    nan

    with seed 314
    prod    0.0
    product 6.44538471095855e+60
    frexp   0.9999999999999573
    Decimal 0.999999999999934
    pair    nan

    with seed 271828
    prod    inf
    product 2556126.798990014
    frexp   0.9999999999999818
    Decimal 0.9999999999999885
    pair    nan

    with seed 789
    prod    0.0
    product 118772.89118349401
    frexp   0.9999999999999304
    Decimal 1.0000000000000053
    pair    nan

    rhettinger commented 4 years ago

    Here's a pairwise variant:

        from math import frexp, ldexp

        def prod(seq):
            stack = []
            exp = 0
            for i, x in enumerate(seq, start=1):
                m, e = frexp(x)
                exp += e
                stack += [m]
                while not i&1:
                    i >>= 1
                    x, y = stack[-2:]
                    stack[-2:] = [x * y]
            total = 1.0
            while stack:
                total *= stack.pop()
            return ldexp(total, exp)
    tim-one commented 4 years ago

    Well, that can't work: the most likely result for a long input is 0.0 (try it!). frexp() forces the mantissa into the range [0.5, 1.0). Multiply N of those, and the result _can_ be as small as 2**-N. So, as in Mark's code, every thousand times (2**-1000 is nearing the subnormal boundary) frexp() is used again to force the product-so-far back into range. That's straightforward when going "left to right".
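
    For instance, with the prod() variant from the previous comment (a hypothetical check; any sufficiently long vector shows the effect):

    >>> prod([1.0] * 4096)    # exact answer is 1.0
    0.0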

    With fancier reduction schemes, "it varies". Aiming for "obviously correct" rather than for maximal cleverness ;-) , here I'll associate each partial product with an integer e such that it's guaranteed (in the absence of infinities, NaNs, and zeros) that abs(partial_product) >= 2**-e. Then quick integer arithmetic can be used in advance to guard against a partial product underflowing:

        def frpair(seq):
            from math import frexp, ldexp
            stack = []
            exp = 0
            for i, x in enumerate(seq, start=1):
                m, e = frexp(x)
                exp += e
                stack += [(m, 1)]
                while not i&1:
                    i >>= 1
                    (x, e1), (y, e2) = stack[-2:]
                    esum = e1 + e2
                    if esum >= 1000:
                        x, e = frexp(x)
                        exp += e
                        y, e = frexp(y)
                        exp += e
                        esum = 2
                    stack[-2:] = [(x * y, esum)]
            total = 1.0
            totale = 0
            while stack:
                x, e = stack.pop()
                totale += e
                if totale >= 1000:
                    total, e = frexp(total)
                    exp += e
                    x, e = frexp(x)
                    exp += e
                    totale = 2
                total *= x
            return ldexp(total, exp)

    But I see no obvious improvement in accuracy over "left to right" for the uniformly random test cases I've tried.

    rhettinger commented 4 years ago

    But I see no obvious improvement in accuracy over "left to right" for the uniformly random test cases I've tried.

    Same here.

    tim-one commented 4 years ago

    More extensive testing convinces me that pairing multiplication is no real help at all - the error distributions appear statistically indistinguishable from left-to-right multiplication.

    I believe this has to do with the "condition numbers" of fp addition and multiplication, which are poor for fp addition and good for fp multiplication. Intuitively, fp addition systematically loses mounds of information whenever two addends in different binades are added (the lower bits of the addend in the binade closer to 0 are entirely lost). But the accuracy of fp mult couldn't care less which binades the inputs are in, provided only that the result doesn't overflow or become subnormal.

    For "random" vectors, pairing summation tends to keep addends close together in magnitude, which is "the real" reason it helps. Left-to-right summation tends to make the difference in magnitude increase as it goes along (the running sum keeps getting bigger & bigger).

    So, red herring. The one thing that _can_ be done more-or-less straightforwardly and cheaply for fp mult is to prevent spurious overflow and underflow (including spurious trips into subnormal-land).

    rhettinger commented 4 years ago

    IIRC, both Factor and Julia use pairwise multiplication. I'm guessing that the reason is that if you have an associative-reduce higher order function, you tend to use it everywhere even in cases where the benefits are negligible ;-)

    tim-one commented 4 years ago

    Or, like I did, they succumbed to an untested "seemingly plausible" illusion ;-)

    I generated 1,000 random vectors (in [0.0, 10.0)) of length 100, and for each generated 10,000 permutations. So that's 10 million 100-element products overall. The convert-to-decimal method was 100% insensitive to permutations, generating the same product (default decimal prec result rounded to float) for each of the 10,000 permutations all 1,000 times.

    The distributions of errors for the left-to-right and pairing products were truly indistinguishable. They ranged from -20 to +20 ulp (taking the decimal result as being correct). When I plotted them on the same graph, I thought I had made an error, because I couldn't see _any_ difference on a 32-inch monitor! I only saw a single curve. At each ulp the counts almost always rounded to the same pixel on the monitor, so the color of the second curve plotted almost utterly overwrote the first curve.

    As a sanity check, on the same vectors using the same driver code I compared sum() to a pairing sum. Pairing sum was dramatically better, with a much tighter error distribution with a much higher peak at the center ("no error"). That's what I erroneously expected to see for products too - although, in hindsight, I can't imagine why ;-)
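
    A rough sketch of that kind of sanity check (my own harness, not the actual driver used here; error is measured in ulps against a high-precision Decimal reference):

        from decimal import Decimal, getcontext
        from math import ulp
        from random import random, seed

        def pairwise_sum(xs, lo, hi):
            # Same divide-and-conquer shape as the pairwise() product above,
            # but for addition.
            n = hi - lo
            if n == 1:
                return xs[lo]
            mid = (lo + hi) // 2
            return pairwise_sum(xs, lo, mid) + pairwise_sum(xs, mid, hi)

        getcontext().prec = 50
        seed(123)                                  # arbitrary seed for illustration
        xs = [random() * 10.0 for _ in range(100_000)]
        exact = float(sum(map(Decimal, xs)))

        for name, value in [("sum", sum(xs)),
                            ("pairwise", pairwise_sum(xs, 0, len(xs)))]:
            print(f"{name:8s} error = {(value - exact) / ulp(exact):+.1f} ulps")

    On inputs like these, the left-to-right sum typically shows a much larger ulp error than the pairwise sum, while (per the product experiments above) no analogous gap appears for multiplication.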