golang / go

The Go programming language
https://go.dev
BSD 3-Clause "New" or "Revised" License

math/big: improve performance of nat.mulRange #65027

Open griesemer opened 6 months ago

griesemer commented 6 months ago

On Sun, Jan 7, 2024 at 5:46 PM John Jannotti [jannotti@gmail.com](mailto:jannotti@gmail.com) wrote: I enjoy bignum implementations, so I was looking through nat.go and saw that mulRange is implemented in a surprising, recursive way. In the non-base case, mulRange(a, b) returns mulRange(a, (a+b)/2) * mulRange(1+(a+b)/2, b) (lots of big.Int ceremony elided).

That's fine, but I didn't see any advantage over the straightforward (and simpler?) for loop.

z = z.setUint64(a)
for m := a + 1; m <= b; m++ {
    z = z.mul(z, nat(nil).setUint64(m))
}
return z

In fact, I suspected the existing code was slower, and allocated a lot more. That seems true. A quick benchmark, using the existing unit test as the benchmark, yields

BenchmarkRecursiveMulRangeN-10    169417    6856 ns/op    9452 B/op    338 allocs/op
BenchmarkIterativeMulRangeN-10    265354    4269 ns/op    2505 B/op    196 allocs/op

I doubt mulRange is a performance bottleneck in anyone's code! But it is exported as int.MulRange, so I guess it's viewed as having some value. And seeing as the for loop seems even easier to understand than the recursive version, maybe it's worth submitting a PR? (If so, should I create an issue first?)

griesemer commented 6 months ago

It should be possible to implement mulRange by allocating the close-to-correct amount of space up-front and then compute the product using mulAddVWW iteratively.
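A standalone sketch of that idea. Since math/big's Word type and mulAddVWW helper are internal, this uses uint64 limbs and math/bits in their place; mulAddVWW here mirrors the shape of the internal routine but is a reimplementation for illustration, not the stdlib code.

```go
package main

import (
	"fmt"
	"math/bits"
)

// mulAddVWW computes z = x*y + r over little-endian uint64 limbs and
// returns the final carry, in the style of math/big's internal helper.
func mulAddVWW(z, x []uint64, y, r uint64) uint64 {
	c := r
	for i, xi := range x {
		hi, lo := bits.Mul64(xi, y)
		lo, carry := bits.Add64(lo, c, 0)
		z[i] = lo
		c = hi + carry
	}
	return c
}

// mulRange computes a * (a+1) * ... * b with a single up-front
// allocation sized from a bit-length upper bound on the product.
func mulRange(a, b uint64) []uint64 {
	if a == 0 {
		return []uint64{0}
	}
	if a > b {
		return []uint64{1}
	}
	maxBits := uint64(bits.Len64(b)) * (b - a + 1)
	maxWords := (maxBits + 63) / 64
	z := make([]uint64, 1, maxWords) // capacity reserved once
	z[0] = a
	for m := a + 1; m <= b; m++ {
		if c := mulAddVWW(z, z, m, 0); c != 0 {
			z = append(z, c) // grows within the reserved capacity
		}
	}
	return z
}

func main() {
	// 5*6*7*8*9*10 = 151200, which fits in one limb.
	fmt.Println(mulRange(5, 10))
}
```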

jannotti commented 6 months ago

Since nat.mul has the optimization to call z.mulAddWW when y has only one Word, the following allocates only once (on 64 bit platforms).

func (z nat) mulRangeIterative(a, b uint64) nat {
    switch {
    case a == 0:
        return z.setUint64(0)
    case a > b:
        return z.setUint64(1)
    }
    maxBits := uint64(bits.Len64(b)) * (b - a + 1)
    maxWords := (maxBits + _W - 1) / _W
    z = z.make(int(maxWords))
    z = z.setUint64(a)

    var buf [2]Word
    mb := nat(buf[:])
    for m := b; m > a; m-- {
        mb = mb.setUint64(m)
        z = z.mul(z, mb)
    }
    return z
}

maxBits will overflow with pathological inputs like mulRange(0, math.MaxUint64). That overflow would not affect correctness (if the computation could ever finish); it would just start z too small. Of course, it will fail when it tries to grow z large enough.

Benchmarks:

func BenchmarkMulRangeNRecursive(b *testing.B) {
    b.ReportAllocs()
    for n := 0; n < b.N; n++ {
        r := mulRangesN[n%len(mulRangesN)]
        nat(nil).mulRange(r.a, r.b)
    }
}

func BenchmarkMulRangeNIterative(b *testing.B) {
    b.ReportAllocs()
    for n := 0; n < b.N; n++ {
        r := mulRangesN[n%len(mulRangesN)]
        nat(nil).mulRangeIterative(r.a, r.b)
    }
}

yields

BenchmarkMulRangeNRecursive-10       3400255           348.8 ns/op       589 B/op         21 allocs/op
BenchmarkMulRangeNIterative-10      11322468           104.6 ns/op        33 B/op          1 allocs/op
jannotti commented 6 months ago

Comments by Bakul Shah bakul@iitbombay.org led me to investigate very large inputs.

By building up large values on each "side" of the recursion, Karatsuba gets used for the larger multiplies and the recursive version begins to win. On mulRange(1_000, 200_000), it's more than 20 times faster than the single allocation iterative version.

I will write up a hybrid that uses the iterative for shorter spans which ought to get us the best of both worlds.

bakul commented 6 months ago

For reference: https://groups.google.com/g/golang-nuts/c/7kcFb41ARgM/m/5sSxZlKfAgAJ This can be easily parallelized though probably not worth it except for specific applications.

Some background and other interesting algorithms for factorials: http://www.luschny.de/math/factorial/FastFactorialFunctions.htm

Richard Fateman's CL implementations: https://people.eecs.berkeley.edu/~fateman/papers/factorial.pdf

jannotti commented 6 months ago

I have a simple hybrid implementation that looks like this:

func (z nat) mulRange(a, b uint64) nat {
    maxBits := uint64(bits.Len64(b)) * (b - a + 1)
    maxWords := (maxBits + _W - 1) / _W
    // use recursive algorithm if it looks like it will result in each side being
    // big enough for karatsuba.
    if maxWords > uint64(2*karatsubaThreshold) {
        return z.mulRangeRecursive(a, b)
    }
    return z.mulRangeIterative(a, b)
}

Unfortunately, the cutoff > 2 * karatsubaThreshold is a little low. In a benchmark with increasingly large ranges between a and b, there is a window of sizes, just above the cutoff, that is slightly slower. So the cutoff should be higher. However, I'm not sure whether that justifies a hand-tuned hardcoded value, perhaps with an addition to calibration_test.go.

Would a "magic" cutoff like maxWords > uint64(5*karatsubaThreshold/2) or similar, as determined by hand tuning, be acceptable, or would you want to see something more rigorous?