rust-lang / rust

Empowering everyone to build reliable and efficient software.
https://www.rust-lang.org
Other
96.87k stars 12.51k forks source link

LLVM generates branch soup for array partition point #129530

Open ds84182 opened 3 weeks ago

ds84182 commented 3 weeks ago

I tried this code:

/// Returns the index of the first element of `array` that equals zero, or `N`
/// when no element is zero.
///
/// This is a front-to-back linear scan with an early exit. It yields the same
/// result as the `partition_point` call noted below only when every non-zero
/// element precedes every zero element.
pub fn partition_point_short_array_n<const N: usize>(array: &[usize; N]) -> usize {
    // array.partition_point(|x| *x != 0)
    let mut i = 0;
    while i < N {
        // Stop at the first zero; `i` is then the number of leading non-zero elements.
        if array[i] == 0 { break }
        i += 1;
    }
    // Either the index of the first zero, or `N` if the loop ran to completion.
    i
}

/// Instantiates `partition_point_short_array_n` for a 24-element array.
///
/// Returns the index of the first zero element, or 24 when there is none.
pub fn partition_point_short_array(array: &[usize; 24]) -> usize {
    partition_point_short_array_n(array)
}

I expected to see this happen: Assembly composed of cmov and/or adox instructions. Or at least mov + je to a single exit branch.

Instead, this happened:

playground::partition_point_short_array:
    cmpq    $0, (%rdi)
    je  .LBB0_1
    cmpq    $0, 8(%rdi)
    je  .LBB0_3
    cmpq    $0, 16(%rdi)
    je  .LBB0_5
    cmpq    $0, 24(%rdi)
    je  .LBB0_7
    cmpq    $0, 32(%rdi)
    je  .LBB0_9
    cmpq    $0, 40(%rdi)
    je  .LBB0_11
    cmpq    $0, 48(%rdi)
    je  .LBB0_13
    cmpq    $0, 56(%rdi)
    je  .LBB0_15
    cmpq    $0, 64(%rdi)
    je  .LBB0_17
    cmpq    $0, 72(%rdi)
    je  .LBB0_19
    cmpq    $0, 80(%rdi)
    je  .LBB0_21
    cmpq    $0, 88(%rdi)
    je  .LBB0_23
    cmpq    $0, 96(%rdi)
    je  .LBB0_25
    cmpq    $0, 104(%rdi)
    je  .LBB0_27
    cmpq    $0, 112(%rdi)
    je  .LBB0_29
    cmpq    $0, 120(%rdi)
    je  .LBB0_31
    cmpq    $0, 128(%rdi)
    je  .LBB0_33
    cmpq    $0, 136(%rdi)
    je  .LBB0_35
    cmpq    $0, 144(%rdi)
    je  .LBB0_37
    cmpq    $0, 152(%rdi)
    je  .LBB0_39
    cmpq    $0, 160(%rdi)
    je  .LBB0_41
    cmpq    $0, 168(%rdi)
    je  .LBB0_43
    cmpq    $0, 176(%rdi)
    je  .LBB0_45
    cmpq    $1, 184(%rdi)
    movl    $24, %eax
    sbbq    $0, %rax
    retq

.LBB0_1:
    xorl    %eax, %eax
    retq

.LBB0_3:
    movl    $1, %eax
    retq

.LBB0_5:
    movl    $2, %eax
    retq

.LBB0_7:
    movl    $3, %eax
    retq

.LBB0_9:
    movl    $4, %eax
    retq

.LBB0_11:
    movl    $5, %eax
    retq

.LBB0_13:
    movl    $6, %eax
    retq

.LBB0_15:
    movl    $7, %eax
    retq

.LBB0_17:
    movl    $8, %eax
    retq

.LBB0_19:
    movl    $9, %eax
    retq

.LBB0_21:
    movl    $10, %eax
    retq

.LBB0_23:
    movl    $11, %eax
    retq

.LBB0_25:
    movl    $12, %eax
    retq

.LBB0_27:
    movl    $13, %eax
    retq

.LBB0_29:
    movl    $14, %eax
    retq

.LBB0_31:
    movl    $15, %eax
    retq

.LBB0_33:
    movl    $16, %eax
    retq

.LBB0_35:
    movl    $17, %eax
    retq

.LBB0_37:
    movl    $18, %eax
    retq

.LBB0_39:
    movl    $19, %eax
    retq

.LBB0_41:
    movl    $20, %eax
    retq

.LBB0_43:
    movl    $21, %eax
    retq

.LBB0_45:
    movl    $22, %eax
    retq

First occurs in 1.19.0 with the alternative code snippet:

/// Non-generic variant of the reproducer with the array length fixed at 24.
///
/// Returns the index of the first element equal to zero, or 24 when no
/// element is zero. Exported unmangled via `#[no_mangle]`.
#[no_mangle]
pub fn partition_point_short_array(array: &[usize; 24]) -> usize {
    let mut i = 0;
    while i < 24 {
        // Early exit on the first zero element.
        if array[i] == 0 { break }
        i += 1;
    }
    // Count of leading non-zero elements.
    i
}
saethlin commented 3 weeks ago

This is loop unrolling. So...

I expected to see this happen: Assembly composed of cmov and/or adox instructions. Or at least mov + je to a single exit branch.

Why do you think this would be better? If you think the unrolled version is slower, do you have a benchmark? If you think the code size is problematic, does -Copt-level=s do what you want? Is that optimization setting a better fit for your codebase?

the8472 commented 3 weeks ago

std's partition_point implements a binary search; your simplified version is a linear scan. LLVM has no optimizations that would turn a linear search into a binary one since it wouldn't be able to know that the array is sorted/partitioned.

workingjubilee commented 3 weeks ago

Is there a reason that std's partition_point does not serve well on your code, @ds84182?

Is this code not faster for the small array size you are concerned about?

theemathas commented 3 weeks ago

The "a bunch of jumps" approach will need fewer comparisons than the "a bunch of cmovs" approach if the zero is usually near the start of the array. Since the compiler doesn't know if this is often the case in your workload, it assumes that you know what you're doing and therefore preserves what your code does.

cvijdea-bd commented 3 weeks ago

FWIW if you want a branchless implementation, then SIMD will likely be 2-5x faster than the scalar version, depending on what target features you can afford to use. Reducing the array element size from usize would also help.

#![feature(portable_simd)]
use std::simd::cmp::SimdPartialEq;
use std::simd::Simd;

/// Finds the index of the first zero element of `array` with a single
/// 32-lane SIMD comparison instead of a loop; returns 24 when no element is zero.
///
/// Requires the nightly `portable_simd` feature.
pub fn partition_point_simd(array: &[usize; 24]) -> usize {
    // Widen to 32 lanes, padding with zeros. The padding doubles as a sentinel:
    // if none of the first 24 elements is zero, lane 24 still compares equal to
    // zero, so the result is 24.
    let mut array_zext = [0; 32];
    array_zext[..24].copy_from_slice(array);

    let array = Simd::from_array(array_zext);
    // Lane-wise `== 0` comparison against a vector of zeros.
    let mask = array.simd_eq(Simd::splat(0));
    // NOTE(review): relies on `to_bitmask` placing lane 0 in the least-significant
    // bit, so the count of trailing zero bits is the index of the first zero lane.
    mask.to_bitmask().trailing_zeros() as usize
}

However, in the latest nightly, the standard library's slice binary search (`slice::binary_search_by`, which `partition_point` is built on) is also branchless (#128254), and it turns out to be the fastest implementation on inputs where the partition point is randomly distributed in 0..=24:

/// Returns the number of leading non-zero elements of `array` using the
/// standard library's `slice::partition_point`.
///
/// The input must be partitioned so that every non-zero element precedes every
/// zero element; otherwise the returned index is unspecified.
pub fn partition_point_std(array: &[usize; 24]) -> usize {
    let is_nonzero = |value: &usize| *value != 0;
    array.as_slice().partition_point(is_nonzero)
}

It's also remarkably small:

playground::partition_point_std:
    mov rax, qword ptr [rcx + 96]
    test rax, rax
    mov edx, 12
    cmove rdx, rax
    lea rax, [rdx + 6]
    mov r8d, eax
    cmp qword ptr [rcx + 8*r8], 0
    cmovne rdx, rax
    lea rax, [rdx + 3]
    mov r8d, eax
    cmp qword ptr [rcx + 8*r8], 0
    cmovne rdx, rax
    lea r8, [rdx + 1]
    cmp qword ptr [rcx + 8*rdx + 8], 0
    cmove r8, rdx
    lea rax, [r8 + 1]
    cmp qword ptr [rcx + 8*r8 + 8], 0
    cmove rax, r8
    cmp qword ptr [rcx + 8*rax], 1
    sbb rax, -1
    ret

[usize; 24] arg with -Ctarget-cpu=x86-64

test bench_partition_point_branchless ... bench:   1,209,903.75 ns/iter (+/- 683,113.25) = 15869 MB/s
test bench_partition_point_linear     ... bench:   1,889,390.00 ns/iter (+/- 371,780.75) = 10162 MB/s
test bench_partition_point_simd       ... bench:     792,851.25 ns/iter (+/- 637,814.88) = 24216 MB/s
test bench_partition_point_std        ... bench:     558,115.62 ns/iter (+/- 306,352.19) = 34401 MB/s

[u32; 24] arg with -Ctarget-cpu=x86-64-v3

test bench_partition_point_branchless ... bench:     312,764.69 ns/iter (+/- 104,581.22) = 30694 MB/s
test bench_partition_point_linear     ... bench:   1,153,777.50 ns/iter (+/- 239,229.50) = 8320 MB/s
test bench_partition_point_simd       ... bench:     198,122.81 ns/iter (+/- 39,717.12) = 48454 MB/s
test bench_partition_point_std        ... bench:     235,001.56 ns/iter (+/- 98,562.62) = 40850 MB/s

partition_point_branchless is the scalar equivalent of partition_point_simd - it's faster in the second case because it gets auto-vectorized, but badly.

/// Returns the index of the first zero element of `array`, or `N` when no
/// element is zero.
///
/// Every element is inspected (there is no early exit): a bit is set in a
/// 32-bit mask for each zero element, and the position of the lowest set bit
/// gives the answer.
///
/// # Panics
///
/// Panics if `N` is greater than 32, because the positions are recorded in a
/// `u32` bitmask and larger indices cannot be represented.
fn partition_point_branchless<const N: usize>(array: &[u32; N]) -> usize {
    // `N` is a compile-time constant, so this check is resolved per instantiation.
    assert!(
        N <= u32::BITS as usize,
        "partition_point_branchless supports at most 32 elements"
    );

    let mut mask: u32 = 0;
    // The indexed loop is kept deliberately: this is the shape that was benchmarked.
    for i in 0..N {
        if array[i] == 0 {
            mask |= 1 << i;
        }
    }

    // With no zero element the mask is 0 and `trailing_zeros` reports 32, not `N`;
    // clamp so the "not found" result agrees with the other implementations.
    (mask.trailing_zeros() as usize).min(N)
}