ziglang / zig

General-purpose programming language and toolchain for maintaining robust, optimal, and reusable software.
https://ziglang.org
MIT License

Investigate removing branch from `ArrayList.ensureTotalCapacity` #15574

Closed Validark closed 1 year ago

Validark commented 1 year ago

I noticed this code in ArrayList.ensureTotalCapacity: https://github.com/ziglang/zig/blob/91b4729962ddec96d1ee60d742326da828dae94a/lib/std/array_list.zig#L363-L380

Specifically:

https://github.com/ziglang/zig/blob/91b4729962ddec96d1ee60d742326da828dae94a/lib/std/array_list.zig#L374-L377

I hope this could somehow be removed in favor of a branchless option, so what follows is an investigation into this possibility, just for fun.
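For readers without the linked source open, here is a minimal standalone sketch of the loop in question. It assumes the loop grows the capacity by half plus 8 with saturating addition, which is the recurrence the rest of this post works from; the free function `growLoop` is a hypothetical name, not the std API.

```zig
const std = @import("std");

/// Hypothetical standalone version of the growth loop (not the std API).
/// Repeatedly grows `capacity` by half plus 8 until it reaches `new_capacity`.
fn growLoop(capacity: usize, new_capacity: usize) usize {
    if (capacity >= new_capacity) return capacity;

    var better_capacity = capacity;
    while (true) {
        // Saturating add, so the loop terminates even near maxInt(usize).
        better_capacity +|= better_capacity / 2 + 8;
        if (better_capacity >= new_capacity) break;
    }
    return better_capacity;
}

test "growLoop walks 0 -> 8 -> 20 -> 38 -> 65 -> 105" {
    try std.testing.expect(growLoop(0, 1) == 8);
    try std.testing.expect(growLoop(0, 9) == 20);
    try std.testing.expect(growLoop(0, 101) == 105);
}
```

The backwards branch at the bottom of the `while` is the one this issue is about.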

First, I will disregard the fact that we are dealing with 64-bit integers, the fact that integer division by 2 rounds the true quotient down by 0.5 when the dividend is odd, and the saturating arithmetic, and just write this as a recursive sequence.

$$\large \begin{equation} \begin{split}
U_0 &= \texttt{capacity} \\
U_n &= U_{n-1} \times 1.5 + 8
\end{split} \end{equation}$$

If you remember your pre-calculus class, this recursive sequence is called "shifted geometric", because it has a multiply that is being shifted by an addition. For $\large U_0 = c$, the expansion of this recursive sequence looks like:

$$\large \begin{equation} \begin{split}
U_1 &= c \times 1.5 + 8 \\
U_2 &= (c \times 1.5 + 8) \times 1.5 + 8 \\
U_3 &= ((c \times 1.5 + 8) \times 1.5 + 8) \times 1.5 + 8 \\
U_4 &= (((c \times 1.5 + 8) \times 1.5 + 8) \times 1.5 + 8) \times 1.5 + 8 \\
U_5 &= ((((c \times 1.5 + 8) \times 1.5 + 8) \times 1.5 + 8) \times 1.5 + 8) \times 1.5 + 8
\end{split} \end{equation}$$

To get the general equation, let's replace $\large 1.5$ with $\large r$ and $\large 8$ with $\large d$

$$\large \begin{equation} \begin{split}
U_0 &= c \\
U_1 &= c \times r + d \\
U_2 &= (c \times r + d) \times r + d \\
U_3 &= ((c \times r + d) \times r + d) \times r + d \\
U_4 &= (((c \times r + d) \times r + d) \times r + d) \times r + d \\
U_5 &= ((((c \times r + d) \times r + d) \times r + d) \times r + d) \times r + d
\end{split} \end{equation}$$

Let's apply the distributive property of multiplication:

$$\large \begin{equation} \begin{split}
U_1 &= cr^1 + dr^0 \\
U_2 &= cr^2 + dr^1 + dr^0 \\
U_3 &= cr^3 + dr^2 + dr^1 + dr^0 \\
U_4 &= cr^4 + dr^3 + dr^2 + dr^1 + dr^0 \\
U_5 &= cr^5 + dr^4 + dr^3 + dr^2 + dr^1 + dr^0
\end{split} \end{equation}$$

The pattern here is pretty obvious. We can express it using $\large \Sigma$ notation:

$$\large U_n = cr^n + \sum_{i=1}^{n} dr^{i-1}$$

You may notice that the $\large \Sigma$ term is the "sum of a finite geometric sequence". Replacing that term with the well-known formula for that allows us to write an explicit function:

$$\large f(n) = cr^n + d \left(\frac{1 - r^n}{1 - r}\right)$$

Let's put $\large 1.5$ back in for $\large r$ and $\large 8$ back in for $\large d$ and assess the damage:

$$\large f(n) = c \times 1.5^n + 8 \left(\frac{1 - 1.5^n}{1 - 1.5}\right)$$

Luckily, we can simplify $\large (1 - 1.5)$ to $\large -0.5$. Dividing by $\large -0.5$ is equivalent to multiplying by $\large -2$, which we can combine with the $\large 8$ term to get $\large -16$:

$$\large f(n) = c \times 1.5^n + -16 (1 - 1.5^n)$$

We could stop here, but let's distribute the $\large -16$:

$$\large f(n) = c \times 1.5^n - 16 + 16 \times 1.5^n$$

Since we have two terms being added which each are multiplied by $\large 1.5^n$, we factor it out like so:

$$\large f(n) = (c+16) \times 1.5^n - 16$$
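As a quick sanity check of the closed form, here is a small sketch that iterates the recurrence in floating point (so there is no integer flooring) and compares it against $\large f(n)$. The function names `recurrence` and `closedForm` are made up for this illustration.

```zig
const std = @import("std");

/// Iterates U_n = U_{n-1} * 1.5 + 8 in floating point, starting from U_0 = c.
fn recurrence(c: f64, n: usize) f64 {
    var u = c;
    var i: usize = 0;
    while (i < n) : (i += 1) u = u * 1.5 + 8.0;
    return u;
}

/// Closed form f(n) = (c + 16) * 1.5^n - 16.
fn closedForm(c: f64, n: f64) f64 {
    return (c + 16.0) * std.math.pow(f64, 1.5, n) - 16.0;
}

test "closed form agrees with the recurrence" {
    // Starting from c = 0 the sequence is 8, 20, 38, 65, 105.5, ...
    try std.testing.expectApproxEqAbs(@as(f64, 105.5), recurrence(0.0, 5), 1e-9);
    try std.testing.expectApproxEqAbs(@as(f64, 105.5), closedForm(0.0, 5.0), 1e-9);
}
```

Starting from $\large c = 0$, the fifth term comes out as 105.5 here, half a unit above what integer division would give, because no 0.5 is dropped along the way.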

This looks how we probably expected it would, and it is relatively easy to deal with. Now let's apply this to our original problem. We want the smallest $\large n$ for which $\large f(n)$ reaches $\large x$, where $\large x$ is the requested `new_capacity`. To find it, start from the boundary condition $\large x \ge f(n)$ and isolate $\large n$ on the right-hand side; rounding the resulting bound up gives the $\large n$ we are after:

$$\large \begin{equation} \begin{split}
x &\ge (c+16) \times 1.5^n - 16 \\
& \small \texttt{(+16 to both sides)} \\
x + 16 &\ge (c+16) \times 1.5^n \\
& \small \texttt{(divide by (c+16) on both sides)} \\
\frac{x + 16}{c+16} &\ge 1.5^n \\
& \small \texttt{(take the log of both sides)} \\
\log{\left(\frac{x + 16}{c+16}\right)} &\ge \log{(1.5^n)} \\
& \small \texttt{(property of logarithms on the right-hand side)} \\
\log{\left(\frac{x + 16}{c+16}\right)} &\ge n\log{(1.5)} \\
& \small \texttt{(divide each side by log(1.5))} \\
\frac{ \log{\left(\frac{x + 16}{c+16}\right)}}{\log{(1.5)}} &\ge n \\
& \small \texttt{(property of logarithms on the left-hand side)} \\
\log_{1.5}{\left(\frac{x + 16}{c+16}\right)} &\ge n \\
& \small \texttt{(property of logarithms on the left-hand side)} \\
\log_{1.5}{(x + 16)} - \log_{1.5}{(c + 16)} &\ge n
\end{split} \end{equation}$$

Now this is usable for our problem. We can compute $\large n$ by doing $\large \lceil\log_{1.5}{(x + 16)} - \log_{1.5}{(c + 16)}\rceil$, then plug that in to $\large n$ in $\large f(n) = (c+16) \times 1.5^n - 16$. Together, that's:

$$\large (c+16) \times 1.5^{\lceil(\log_{1.5}{(x + 16)} - \log_{1.5}{(c + 16)})\rceil} - 16$$

For those of you who skipped ahead, $\large c$ is self.capacity and $\large x$ is new_capacity, and this formula gives you the better_capacity. Note that this formula will give numbers a bit higher than the original while loop, because the original while loop loses some 0.5's when dividing an odd number by 2.
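To make the formula concrete, here is a floating-point sketch of it. It is only an illustration of the math, not something you would want in `ArrayList`; `stepsNeeded` and `betterCapacity` are hypothetical names, and I avoided inputs that land exactly on a step boundary, since float rounding could tip the ceiling either way there.

```zig
const std = @import("std");

/// n = ceil(log_1.5(x + 16) - log_1.5(c + 16)), computed with natural logs.
fn stepsNeeded(c: f64, x: f64) f64 {
    return @ceil((@log(x + 16.0) - @log(c + 16.0)) / @log(@as(f64, 1.5)));
}

/// (c + 16) * 1.5^n - 16 with n taken from stepsNeeded.
fn betterCapacity(c: f64, x: f64) f64 {
    return (c + 16.0) * std.math.pow(f64, 1.5, stepsNeeded(c, x)) - 16.0;
}

test "formula lands slightly above the integer loop" {
    // Growing from 0 to hold 101 items takes 5 steps and gives 105.5.
    try std.testing.expect(stepsNeeded(0.0, 101.0) == 5.0);
    try std.testing.expectApproxEqAbs(@as(f64, 105.5), betterCapacity(0.0, 101.0), 1e-9);
}
```

Stepping the integer recurrence by hand from 0 gives 8, 20, 38, 65, 105, so the 105.5 here shows the "a bit higher" effect mentioned above.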


Now, the remaining question is how to compute the previous expression, or rather, an approximation of it, efficiently.

Sadly, there is no cheap way to compute the base 1.5 logarithm of an integer. If we were allowed to change the original problem so that we could use the base 2 logarithm instead, that would be much easier to compute: it's just `@typeInfo(@TypeOf(c)).Int.bits - 1 - @clz(c)` (this is an integer, so we should be careful about how flooring the true answer affects rounding error). Let's use this information to make an approximation. Using the change of base property of logarithms, we can rewrite the expression for $\large n$ like so:

$$\large \frac{\log_2{(x + 16)}}{\log_2{1.5}} - \frac{\log_2{(c + 16)}}{\log_2{1.5}}$$

$\large \frac{1}{\log_2{1.5}} \approx 1.7095112913514547$, so we can approximate the above expression like so:

$$\large (\log_2{(x + 16)} - \log_2{(c + 16)}) \times 1.7095112913514547$$

As hinted at earlier, we can find $\large \lceil\log_2{(x + 16)}\rceil - \lceil\log_2{(c + 16)}\rceil$ by doing `@clz(c + 15) - @clz(x + 15)`. Note that the terms are now in reverse order because, for a 64-bit integer, the answer returned by `@clz(b)` is actually $\large 63 - \lfloor\log_2{b}\rfloor$. We also subtracted 1 from 16 because we probably want the ceiling of the base 2 logarithm instead, and the algorithm for that is `64 - @clz(x - 1)`. `(64 - @clz((x + 16) - 1)) - (64 - @clz((c + 16) - 1))` reduces to `@clz(c + 15) - @clz(x + 15)`. That's slightly different from what we want, which is to take the ceiling only after multiplying by $\large 1.7095112913514547$, but if we're careful about which way the rounding works, we should be fine.
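Here is a small sketch of that `@clz` identity for `u64`. The helper names are hypothetical, and it assumes `x >= c` and that adding 15 does not overflow.

```zig
const std = @import("std");

/// Ceiling of log2(v) for v >= 1, via 64 - @clz(v - 1).
fn ceilLog2(v: u64) u7 {
    return 64 - @clz(v - 1);
}

/// ceil(log2(x + 16)) - ceil(log2(c + 16)), written as a difference of @clz results.
/// Assumes x >= c and that neither addition overflows.
fn ceilLog2Diff(c: u64, x: u64) u7 {
    return @clz(c + 15) - @clz(x + 15);
}

test "clz difference matches the difference of ceil-log2 values" {
    // ceil(log2(117)) = 7 and ceil(log2(16)) = 4.
    try std.testing.expect(ceilLog2(117) == 7);
    try std.testing.expect(ceilLog2(16) == 4);
    try std.testing.expect(ceilLog2Diff(0, 101) == 3);
}
```

For `c = 0` and `x = 101` this gives 3; scaled by 1.7095 that is about 5.13, while the exact $\large \log_{1.5}$ difference is about 4.91, which shows the kind of rounding slack we have to be careful about.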


The other thing I notice is that $\large 1.5^{n}$ is equivalent to $\large \frac{3^{n}}{2^{n}}$. Of course, dividing by $\large 2^{n}$ is just a right shift, which means we could do the following once we determine the value of $\large n$.

$$\large (((c+16) \times 3^{n}) \gg n) - 16$$

Of course, this has additional overflow potential: the multiply can overflow even when the right shift would have taken us back into the range of `usize`. Maybe we could expand to 128 bits for the multiply. Alternatively, for powers of 1.5 where the fractional part is less relevant, we'd probably be fine with a lookup table or something, so our code could be `((c + 16) * powers[...]) - 16`.
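A minimal sketch of the multiply-then-shift idea with a 128-bit intermediate follows. The name `growBy3Over2` is hypothetical, and it assumes `n` is small enough that `(c + 16) * 3^n` still fits in `u128`; it returns `u128` rather than narrowing back to `usize`.

```zig
const std = @import("std");

/// Computes (((c + 16) * 3^n) >> n) - 16 using a 128-bit intermediate.
/// Assumes n is small enough that the product fits in u128.
fn growBy3Over2(c: usize, n: usize) u128 {
    var scaled: u128 = @as(u128, c) + 16;
    var i: usize = 0;
    while (i < n) : (i += 1) scaled *= 3;
    // std.math.shr accepts any integer shift amount, so no cast is needed.
    return std.math.shr(u128, scaled, n) - 16;
}

test "multiply by 3^n, shift right by n" {
    // 16 * 243 = 3888, 3888 >> 5 = 121, 121 - 16 = 105.
    try std.testing.expect(growBy3Over2(0, 5) == 105);
    try std.testing.expect(growBy3Over2(50, 2) == 132);
}
```

Note that this floors once at the end, whereas the original loop floors at every step, so the two are not guaranteed to agree for every input.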


One thing we could do is work backwards, changing $\large 1.7095112913514547$ to a nicer number like $\large 1.5$ or $\large 2$. Let's pick $\large 2$. To make it so we would multiply by $\large 2$ instead, we would change our recursive sequence to:

$$\large \begin{equation} \begin{split}
U_0 &= \texttt{capacity} \\
U_n &= U_{n-1} \times \sqrt 2 + 8
\end{split} \end{equation}$$

This works because $\large \frac{1}{\log_2{\sqrt 2}}$ is $\large 2$. This is still pretty close to our original formula, as $\large \sqrt 2 \approx 1.41421$ and $\large 1.41421 \approx 1.5$. If we did the same steps as before, $\large \frac{8}{\sqrt 2 - 1} \approx 19.313708498984756$ would be in all the places where we had $\large 16$ in our original equations. Let's round that up to $\large 20$ this time, since we rounded $\large 1.5$ down to $\large \sqrt 2$. To do that, we change the common difference of $\large 8$ to $\large -20 (1 - \sqrt 2)$, which is about $\large 8.2842712474619$. Reminder: the point here is that when we divide this value by $\large (1 - \sqrt 2)$, we get $\large -20$ rather than the $\large -16$ we had earlier.

$$\large \begin{equation} \begin{split}
U_0 &= \texttt{capacity} \\
U_n &= U_{n-1} \times \sqrt 2 - 20 (1 - \sqrt 2) \\
U_n &\approx U_{n-1} \times 1.41421 + 8.2842712474619
\end{split} \end{equation}$$

By the same steps shown above, this gives us the coveted:

$$\large (c+20) \times \sqrt 2^{\lceil 2(\log_2{(x + 20)} - \log_2{(c + 20)})\rceil} - 20$$

I.e.:

$$\large (c+20) \times \sqrt 2^{\lceil \log_{\sqrt 2}{(x + 20)} - \log_{\sqrt 2}{(c + 20)}\rceil} - 20$$

As mentioned before, we can find $\large \lceil\log_2{(x + 20)}\rceil - \lceil\log_2{(c + 20)}\rceil$ by doing `@clz(c + 19) - @clz(x + 19)`. However, this is not close enough to $\large \lceil \log_{\sqrt 2}{(x + 20)} - \log_{\sqrt 2}{(c + 20)}\rceil$ for our use-case because we need at least the granularity of a $\large \log_{\sqrt 2}$ either way (ideally, we could use even more precision in some cases). This could be accomplished via a lookup table, or via another approximation. As an approximation, we could pretend that each odd power of $\large \sqrt 2$ is half-way between powers of $\large 2$ that fall on even powers of $\large \sqrt 2$. If you think about it, this is kind of semantically in line with what we are doing when we subtract the `@clz` of two numbers, now with slightly more granularity. By AND'ing the bit directly under the most significant bit with the most significant bit, then moving it to the 1's place, we can add it (or OR it) with double the bit index of the highest set bit:

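fn log_sqrt_2_int(x: u64) u64 {
    // Index of the highest set bit, i.e. floor(log2(x)). Requires x != 0.
    const fls = 63 - @clz(x);
    // Double it, then OR in the bit directly under the MSB as the 1's place.
    return fls * 2 | std.math.shr(u64, x & (x << 1), fls);
}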

This is kind of what we are looking for, with a bit more accuracy than before. We can also scale this up even more if desired, by multiplying by 4, moving the bit directly under the MSB to the 2's place, and moving the bit two positions below the MSB to the one's place. However, because we are taking the clz of values that have 19 added to them, we can guarantee there will always be at least 4 bits under the most significant bit that we can use to interpolate between powers of $\large \sqrt 2$, so we can scale to 4 extra bits to improve accuracy:

// Kinda an approximation of 16 log2(x). Will be divided by 8 to approximate 2 log2(x).
fn log_approx_helper(x: usize) usize {
    const COMPLEMENT = @typeInfo(usize).Int.bits - 1;
    const BITS_TO_PRESERVE = @as(comptime_int, COMPLEMENT - @clz(@as(usize, 19)));

    const fls = @intCast(std.math.Log2Int(usize), COMPLEMENT - @clz(x)); // min 4
    const x_with_msb_unset = @as(usize, 1) << fls ^ x;
    const pack_bits_under_old_msb = x_with_msb_unset >> fls - BITS_TO_PRESERVE;
    return @as(usize, fls) << BITS_TO_PRESERVE | pack_bits_under_old_msb;
}

// usage:
const n = 1 + (log_approx_helper(x + 19) - log_approx_helper(c + 19)) / 8;
// i.e.:
const n = 1 + (log_approx_helper(new_capacity + 19) - log_approx_helper(self.capacity + 19)) / 8;

Now that we have calculated $\large n$, the last problem is approximating $\large \sqrt 2^n$. Again, this can be done with a lookup table, or we could pretend once more that odd powers of $\large \sqrt 2$ are directly in the middle of powers of $\large 2$. Let's try that.

fn approx_sqrt_2_pow(y: u7) u64 {
    // y is basically a fixed point integer, with the 1's place being after the decimal point
    const shift = @intCast(u6, y >> 1);
    return (@as(u64, 1) << shift) | (@as(u64, y & 1) << (shift -| 1));
}

And here are the estimates versus what we would get from std.math.pow(f64, std.math.sqrt2, n):

pow: est vs double  <- format
√2^1: 1 vs 1.4142135623730951
√2^3: 3 vs 2.8284271247461907
√2^5: 6 vs 5.656854249492383
√2^7: 12 vs 11.313708498984768
√2^9: 24 vs 22.627416997969544
√2^11: 48 vs 45.254833995939094
√2^13: 96 vs 90.50966799187822
√2^15: 192 vs 181.01933598375646
√2^17: 384 vs 362.038671967513
√2^19: 768 vs 724.0773439350261
√2^21: 1536 vs 1448.1546878700526
√2^23: 3072 vs 2896.3093757401057
√2^25: 6144 vs 5792.618751480213
√2^27: 12288 vs 11585.237502960428
√2^29: 24576 vs 23170.475005920864
√2^31: 49152 vs 46340.950011841735
√2^33: 98304 vs 92681.9000236835
√2^35: 196608 vs 185363.80004736703
√2^37: 393216 vs 370727.60009473417
√2^39: 786432 vs 741455.2001894685
√2^41: 1572864 vs 1482910.4003789374
√2^43: 3145728 vs 2965820.800757875
√2^45: 6291456 vs 5931641.601515752
√2^47: 12582912 vs 11863283.203031506
√2^49: 25165824 vs 23726566.406063017
√2^51: 50331648 vs 47453132.81212604
√2^53: 100663296 vs 94906265.62425211
√2^55: 201326592 vs 189812531.24850425
√2^57: 402653184 vs 379625062.4970086
√2^59: 805306368 vs 759250124.9940174
√2^61: 1610612736 vs 1518500249.9880352
√2^63: 3221225472 vs 3037000499.976071
√2^65: 6442450944 vs 6074000999.952143
√2^67: 12884901888 vs 12148001999.904287
√2^69: 25769803776 vs 24296003999.808582
√2^71: 51539607552 vs 48592007999.61717
√2^73: 103079215104 vs 97184015999.23438
√2^75: 206158430208 vs 194368031998.46878
√2^77: 412316860416 vs 388736063996.9377
√2^79: 824633720832 vs 777472127993.8755
√2^81: 1649267441664 vs 1554944255987.7512
√2^83: 3298534883328 vs 3109888511975.503
√2^85: 6597069766656 vs 6219777023951.008
√2^87: 13194139533312 vs 12439554047902.018
√2^89: 26388279066624 vs 24879108095804.043
√2^91: 52776558133248 vs 49758216191608.09
√2^93: 105553116266496 vs 99516432383216.22
√2^95: 211106232532992 vs 199032864766432.47
√2^97: 422212465065984 vs 398065729532865.06
√2^99: 844424930131968 vs 796131459065730.2
√2^101: 1688849860263936 vs 1592262918131461
√2^103: 3377699720527872 vs 3184525836262922.5
√2^105: 6755399441055744 vs 6369051672525847
√2^107: 13510798882111488 vs 12738103345051696
√2^109: 27021597764222976 vs 25476206690103400
√2^111: 54043195528445952 vs 50952413380206810
√2^113: 108086391056891904 vs 101904826760413630
√2^115: 216172782113783808 vs 203809653520827300
√2^117: 432345564227567616 vs 407619307041654700
√2^119: 864691128455135232 vs 815238614083309600
√2^121: 1729382256910270464 vs 1630477228166619600
√2^123: 3458764513820540928 vs 3260954456333240000
√2^125: 6917529027641081856 vs 6521908912666482000
√2^127: 13835058055282163712 vs 13043817825332965000

Well, that's pretty much all I have so far. With a little polishing, this is the code I ended up with:

const std = @import("std");
const BITS_TO_PRESERVE = @as(comptime_int, (@typeInfo(usize).Int.bits - 1) - @clz(@as(usize, 20)));

// Kinda an approximation of 16 log2(x). Will be divided by 8 to approximate 2 log2(x).
fn log_approx_helper(x: usize) usize {
    const fls = @intCast(std.math.Log2Int(usize), (@typeInfo(usize).Int.bits - 1) - @clz(x)); // [4, 63]
    const x_with_msb_unset = x ^ @as(usize, 1) << fls;
    const pack_bits_under_old_msb = x_with_msb_unset >> fls - BITS_TO_PRESERVE;
    return @as(usize, fls) << BITS_TO_PRESERVE | pack_bits_under_old_msb; // [16, 1023] on 64-bit
}

/// Modify the array so that it can hold at least `new_capacity` items.
/// Invalidates pointers if additional memory is needed.
export fn ensureTotalCapacity(capacity: usize, new_capacity: usize) usize {
//    if (@sizeOf(T) == 0) { 
//         self.capacity = math.maxInt(usize); 
//         return; 
//     } 
    const power = 1 + (log_approx_helper(new_capacity +| 20) -| log_approx_helper(capacity +| 20)) / 8;
    const shift = @intCast(std.math.Log2Int(usize), power >> 1);
    const approx_sqrt_2_power = (@as(usize, 1) << shift) | (@as(usize, power & 1) << (shift -| 1));
    return @max(capacity +| (capacity / 2 + 8), (capacity +| 20) *| approx_sqrt_2_power - 20);
}

// side note: I decided to just always use 20 instead of 19 where applicable, because it is a mostly trivial difference
// and we can reuse `capacity +| 20` in 2 locations.

Here is the godbolt link: https://zig.godbolt.org/z/fTeaY1Gfn
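For anyone who wants to poke at the arithmetic without the export or the cast builtins, here is a hedged restatement of the same calculation. It swaps the `@intCast` calls and raw shifts for `std.math.shl`/`std.math.shr` and uses `@bitSizeOf` instead of `@typeInfo`, so it is a sketch of the idea rather than the exact code on godbolt; `logApproxHelper` and `branchlessCapacity` are hypothetical names, and the expected values in the test were worked out by hand.

```zig
const std = @import("std");

// floor(log2(20)): number of bits kept below the most significant bit.
const bits_to_preserve = 4;

/// Roughly 16 * log2(x). Requires x >= 16 so that fls >= bits_to_preserve.
fn logApproxHelper(x: usize) usize {
    const fls: usize = (@bitSizeOf(usize) - 1) - @as(usize, @clz(x));
    const below_msb = x ^ std.math.shl(usize, 1, fls);
    return (fls << bits_to_preserve) | std.math.shr(usize, below_msb, fls - bits_to_preserve);
}

/// Branchless estimate of the grown capacity, following the steps in this post.
fn branchlessCapacity(capacity: usize, new_capacity: usize) usize {
    const base = capacity +| 20;
    const power = 1 + (logApproxHelper(new_capacity +| 20) -| logApproxHelper(base)) / 8;
    const shift = power >> 1;
    // Odd powers of sqrt(2) are treated as half-way between powers of 2.
    const approx = std.math.shl(usize, 1, shift) | std.math.shl(usize, power & 1, shift -| 1);
    return @max(capacity +| (capacity / 2 + 8), (base *| approx) - 20);
}

test "first few steps when growing from capacity 0" {
    try std.testing.expect(branchlessCapacity(0, 7) == 8);
    try std.testing.expect(branchlessCapacity(0, 8) == 20);
    try std.testing.expect(branchlessCapacity(0, 20) == 40);
}
```

Growing from 0, the results step through 8, 20, 40, 60, which is in the same neighbourhood as the 8, 20, 38, 65 of the original recurrence.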

I would be interested to hear if anyone else has any ideas on how to improve this (beyond optimizations already performed automagically by LLVM). I would also be happy if someone volunteered to benchmark this, either in a microbenchmark or in a large system that makes heavy or light use of this function. Thanks for reading!

- Validark

ghost commented 1 year ago

Is there any particular reason to scale the existing capacity in exact steps of $1.5c+8$? If the only purpose of the loop is to prevent reallocation with an overly small increment, one could simply do this instead:

const better_capacity = @max(new_capacity, self.capacity +| (self.capacity / 2 + 8))

Snektron commented 1 year ago

You could also just pregenerate the entire series at compile time

squeek502 commented 1 year ago

> Is there any particular reason to scale the existing capacity in exact steps of $1.5c+8$? If the only purpose of the loop is to prevent reallocation with an overly small increment, one could simply do this instead:
>
> const better_capacity = @max(new_capacity, self.capacity +| (self.capacity / 2 + 8))

This behaves differently for resizing from e.g. 0 to x.

Here's the current pattern from calling ensureTotalCapacity with current capacity of 0:

requested capacity => resulting capacity
0...8 => 8
9...20 => 20
21...38 => 38
39...65 => 65
66...105 => 105
106...165 => 165
166...255 => 255
256...390 => 390
391...593 => 593
594...897 => 897
898...1353 => 1353
1354...2037 => 2037
2038...3063 => 3063
3064...4602 => 4602
4603...6911 => 6911
6912...10374 => 10374
10375...15569 => 15569
15570...23361 => 23361
23362...35049 => 35049
35050...52581 => 52581
52582...78879 => 78879
...

The code in the OP is fairly similar:

requested capacity => resulting capacity
0...7 => 8
8...19 => 20
20...35 => 40
36...59 => 60
60...91 => 100
92...139 => 140
140...203 => 220
204...299 => 300
300...427 => 460
428...619 => 620
620...875 => 940
876...1259 => 1260
1260...1771 => 1900
1772...2539 => 2540
2540...3563 => 3820
3564...5099 => 5100
5100...7147 => 7660
7148...10219 => 10220
10220...14315 => 15340
14316...20459 => 20460
20460...28651 => 30700
28652...40939 => 40940
40940...57323 => 61420
57324...81899 => 81900
...

But with @max(new_capacity, self.capacity +| (self.capacity / 2 + 8)):

requested capacity => resulting capacity
0...8 => 8
9...10 => 10
11...12 => 12
13...14 => 14
15...16 => 16
17...18 => 18
19...20 => 20
21...22 => 22
23...24 => 24
25...26 => 26
...
78319...78320 => 78320
78321...78322 => 78322
78323...78324 => 78324
...

> You could also just pregenerate the entire series at compile time

As I understand it, this would involve an impossibly large number of possibilities, as better_capacity depends on both the self.capacity and the new_capacity. For example, here's how the current better_capacity behaves:

self.capacity -> new_capacity = better_capacity
0 -> 101 = 105
50 -> 101 = 132
100 -> 101 = 158

andrewrk commented 1 year ago

> Here is the godbolt link: https://zig.godbolt.org/z/fTeaY1Gfn

Here is a comparison of before and after:

before

ensureTotalCapacity:
        cmp     x0, x1
        b.hs    .LBB0_2
.LBB0_1:
        lsr     x8, x0, #1
        add     x8, x8, #8
        adds    x8, x0, x8
        csinv   x0, x8, xzr, lo
        cmp     x0, x1
        b.lo    .LBB0_1
.LBB0_2:
        ret

after

ensureTotalCapacity:
        adds    x8, x1, #20
        mov     w9, #63
        csinv   x8, x8, xzr, lo
        adds    x10, x0, #20
        clz     x11, x8
        csinv   x10, x10, xzr, lo
        sub     w12, w9, w11
        clz     x14, x10
        sub     w9, w9, w14
        mov     w13, #1
        mov     w15, #59
        sub     w11, w15, w11
        sub     w14, w15, w14
        and     x15, x12, #0x3f
        lsl     x12, x13, x12
        eor     x8, x12, x8
        lsl     x12, x13, x9
        and     x9, x9, #0x3f
        eor     x12, x12, x10
        lsr     x8, x8, x11
        lsr     x11, x12, x14
        orr     x8, x8, x15, lsl #4
        orr     x9, x11, x9, lsl #4
        subs    x8, x8, x9
        csel    x8, xzr, x8, lo
        lsr     x8, x8, #3
        add     x8, x8, #1
        ubfx    x9, x8, #1, #6
        lsr     x11, x8, #1
        subs    w9, w9, #1
        and     x8, x8, #0x1
        csel    w9, wzr, w9, lo
        lsl     x11, x13, x11
        lsl     x8, x8, x9
        lsr     x9, x0, #1
        orr     x8, x8, x11
        add     x9, x9, #8
        adds    x9, x0, x9
        umulh   x11, x10, x8
        csinv   x9, x9, xzr, lo
        mul     x8, x10, x8
        cmp     xzr, x11
        csinv   x8, x8, xzr, eq
        sub     x8, x8, #20
        cmp     x9, x8
        csel    x0, x9, x8, hi
        ret

The branching code already wins in terms of code size, as well as source code simplicity.

I suspect if you measured this, the existing machine code would win in terms of perf as well due to the branch being extremely predictable and the fact that it rarely branches backwards.

Also, the exact equation (n / 2 + 8) is not important. What's important is that it increases super-linearly so that appending is amortized O(1), balanced against not increasing too fast otherwise it wastes memory.

Cool math tho.