rust-lang / rust


Optimize away bounds check in loop indexing into slice, given an assertion #71997

Open joshtriplett opened 4 years ago

joshtriplett commented 4 years ago

I wrote a simple loop that indexes into a slice, to test rustc's ability to optimize away bounds checks when it knows an index is in bounds. Even with this very simple test case, I can't get Rust to omit the bounds checks, no matter what assert! I add. (I know I could trivially write this code with iterators instead, but I'm trying to figure out Rust's ability to optimize here.)

Test case (edited since original posting to augment the assert! further):

#[no_mangle]
fn f(slice: &[u64], start: usize, end: usize) -> u64 {
    let mut total = 0;
    assert!(start < end && start < slice.len() && end <= slice.len());
    for i in start..end {
        total += slice[i];
    }
    total
}

I put that into Compiler Explorer with -O, and the resulting assembly looks like this:

f:
        push    rax
        cmp     rdx, rcx
        jae     .LBB5_8
        cmp     rsi, rdx
        jbe     .LBB5_8
        cmp     rsi, rcx
        jb      .LBB5_8
        xor     eax, eax
.LBB5_4:
        cmp     rdx, rsi
        jae     .LBB5_7
        add     rax, qword ptr [rdi + 8*rdx]
        add     rdx, 1
        cmp     rcx, rdx
        jne     .LBB5_4
        pop     rcx
        ret
.LBB5_7:
        lea     rax, [rip + .L__unnamed_5]
        mov     rdi, rdx
        mov     rdx, rax
        call    qword ptr [rip + core::panicking::panic_bounds_check@GOTPCREL]
        ud2
.LBB5_8:
        call    std::panicking::begin_panic
        ud2

Based on the x86-64 System V calling convention, rdi contains the slice base address, rsi contains the slice length, rdx contains start, and rcx contains end.

So the first three comparisons check the assertion and jump to .LBB5_8, which panics, if any part of it fails.

Then inside the loop, there's still another comparison of rdx to rsi, and a jump to .LBB5_7 to panic if out of bounds.

As far as I can tell, that's exactly the same comparison. Shouldn't rustc be able to optimize away that bounds check?

Things I've tested:

Ideally, rustc should be able to optimize away the bounds check in the loop, based on the assertion. Even better would be if rustc could hoist the bounds check out of the loop even without the assertion, but that seems like a harder problem.
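For comparison, the hoisting described above can be done by hand today. This is a minimal sketch, and `sum_hoisted` is a hypothetical name, not code from this thread: slicing once with `slice[start..end]` performs a single range check before the loop, after which the loop bound is the subslice's own length, a shape LLVM usually handles without a per-iteration check (not verified against assembly here). Unlike the assert! version, it panics through the slice-range error and accepts an empty range.

```rust
/// Hypothetical `sum_hoisted`: sums `slice[start..end]`, doing the bounds
/// check once before the loop instead of once per iteration.
pub fn sum_hoisted(slice: &[u64], start: usize, end: usize) -> u64 {
    // Single range check up front: panics if start > end or end > slice.len().
    let sub = &slice[start..end];
    let mut total = 0;
    // The loop bound is sub.len(), so `j < sub.len()` already implies the
    // index is in bounds for `sub[j]`.
    for j in 0..sub.len() {
        total += sub[j];
    }
    total
}
```

For example, `sum_hoisted(&[1, 2, 3, 4, 5], 1, 4)` adds elements 1 through 3.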

ecstatic-morse commented 4 years ago

Did you mean end < slice.len()?

joshtriplett commented 4 years ago

@ecstatic-morse I tried including that as well, among other permutations of the assert condition, and it didn't change the in-loop bounds check.

the8472 commented 4 years ago

Adding std::intrinsics::assume(i < slice.len()) inside the loop eliminates the check and vectorizes it.
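`std::intrinsics::assume` is a nightly-only intrinsic. A sketch of the same idea on stable Rust, assuming the common `std::hint::unreachable_unchecked` idiom is an acceptable stand-in (`sum_with_hint` is a hypothetical name): the branch states that `i >= slice.len()` never happens, which should let the optimizer delete the check in `slice[i]`. Whether it also vectorizes the loop, as reported for `assume`, has not been checked here.

```rust
/// Hypothetical `sum_with_hint`: the original loop plus a per-iteration
/// hint that the index is in bounds, using the stable
/// `unreachable_unchecked` idiom instead of the nightly `assume` intrinsic.
pub fn sum_with_hint(slice: &[u64], start: usize, end: usize) -> u64 {
    assert!(start <= end && end <= slice.len());
    let mut total = 0;
    for i in start..end {
        if i >= slice.len() {
            // SAFETY: the assert guarantees i < end <= slice.len(), so this
            // branch never runs; telling the compiler so lets it treat
            // `i < slice.len()` as known inside the loop.
            unsafe { std::hint::unreachable_unchecked() }
        }
        total += slice[i];
    }
    total
}
```

If the hint were ever wrong, the result would be undefined behavior rather than a panic, so this only makes sense when an earlier check (here the assert!) already proves it.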

RalfJung commented 4 years ago

Looks like LLVM is unable to propagate what it learned on an earlier conditional, into the later loop?

ecstatic-morse commented 4 years ago

@joshtriplett Nevertheless, you should update your example. Without checking end < slice.len(), f(&[0], 0, 1337) will trigger an index-out-of-bounds panic without hitting the assertion.

joshtriplett commented 4 years ago

@ecstatic-morse Done; edited the code and provided the new corresponding assembly. Doesn't affect the code of the loop, which still includes the bounds check.

the8472 commented 4 years ago

This works, and since an exclusive range is used, it is presumably also the intended use of that function:

#[no_mangle]
fn f(slice: &[u64], start: usize, end: usize) -> u64 {
    let mut total = 0;
    assert!(start < end && end <= slice.len());
    let end = std::cmp::min(slice.len(), end);
    for i in start..end {
        total += slice[i];
    }
    total
}
joshtriplett commented 4 years ago

@the8472 I can confirm that that code eliminates the bounds check. And interestingly, if I reverse the two arguments to min, that does not eliminate the bounds check.

This works (no bounds check, vectorized):

#[no_mangle]
fn f(slice: &[u64], start: usize, end: usize) -> u64 {
    let mut total = 0;
    assert!(start < end && end <= slice.len());
    let end = std::cmp::min(slice.len(), end);
    for i in start..end {
        total += slice[i];
    }
    total
}

This doesn't work (includes the bounds check in the loop):

#[no_mangle]
fn f(slice: &[u64], start: usize, end: usize) -> u64 {
    let mut total = 0;
    assert!(start < end && end <= slice.len());
    let end = std::cmp::min(end, slice.len());
    for i in start..end {
        total += slice[i];
    }
    total
}
joshtriplett commented 4 years ago

The only differences between the two argument orders to min are the direction of the comparison and which argument gets returned if they compare equal. And sure enough, I can confirm that open-coding the equivalent only works if we use slice.len() in place of end in the case where end == slice.len().

This works (no bounds check in the loop):

#[no_mangle]
fn f(slice: &[u64], start: usize, mut end: usize) -> u64 {
    let mut total = 0;
    assert!(start < end && start < slice.len() && end <= slice.len());
    if end >= slice.len() { end = slice.len(); }
    for i in start..end {
        total += slice[i];
    }
    total
}

This doesn't work (bounds check in the loop):

#[no_mangle]
fn f(slice: &[u64], start: usize, mut end: usize) -> u64 {
    let mut total = 0;
    assert!(start < end && start < slice.len() && end <= slice.len());
    if end > slice.len() { end = slice.len(); }
    for i in start..end {
        total += slice[i];
    }
    total
}
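The tie-breaking behavior mentioned above can be observed directly: `std::cmp::min` is documented to return its first argument when the two compare equal. For `usize` both results are the same number, but they are still different values as far as the optimizer's analysis is concerned, which is presumably why only `min(slice.len(), end)` helps. The sketch below uses a hypothetical `Tagged` type whose ordering ignores a tag, so the chosen argument is visible.

```rust
use std::cmp::{min, Ordering};

/// Hypothetical type whose ordering looks only at `key`, so two values can
/// compare equal while carrying different `tag`s.
#[derive(Debug, Clone, Copy)]
struct Tagged {
    key: usize,
    tag: &'static str,
}

// Equality and ordering both use only `key`, keeping Eq consistent with Ord.
impl PartialEq for Tagged {
    fn eq(&self, other: &Self) -> bool {
        self.key == other.key
    }
}
impl Eq for Tagged {}

impl PartialOrd for Tagged {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for Tagged {
    fn cmp(&self, other: &Self) -> Ordering {
        self.key.cmp(&other.key)
    }
}

/// Returns the tag of whichever argument `std::cmp::min` picks.
fn min_tag(a: Tagged, b: Tagged) -> &'static str {
    min(a, b).tag
}
```

With equal keys, `min_tag(len, end)` reports `"len"` and `min_tag(end, len)` reports `"end"`, mirroring the two argument orders above.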
joshtriplett commented 4 years ago

It looks like there are two separate bugs here: 1) rustc should optimize away the bounds check given just the assert!, without needing the redundant call to min or the equivalent open-coded change to end. 2) The optimization shouldn't require assigning end = slice.len() in the case where end == slice.len().

joshtriplett commented 4 years ago

I produced a corresponding set of (naively translated) C++ test cases, and confirmed the same behavior from clang trunk.

The following C++ code does not optimize away the bounds check:

#include <cassert>
#include <cstdint>
#include <vector>
using namespace std;

uint64_t f(vector<uint64_t> slice, size_t start, size_t end)
{
    uint64_t total = 0;
    assert(start < end && start < slice.size() && end <= slice.size());
    for (size_t i = start; i < end; i++) {
        total += slice.at(i);
    }
    return total;
}

Nor does this:

#include <cassert>
#include <cstdint>
#include <vector>
using namespace std;

uint64_t f(vector<uint64_t> slice, size_t start, size_t end)
{
    uint64_t total = 0;
    assert(start < end && start < slice.size() && end <= slice.size());
    if (end > slice.size())
        end = slice.size();
    for (size_t i = start; i < end; i++) {
        total += slice.at(i);
    }
    return total;
}

But this does (note the >= in the if):

#include <cassert>
#include <cstdint>
#include <vector>
using namespace std;

uint64_t f(vector<uint64_t> slice, size_t start, size_t end)
{
    uint64_t total = 0;
    assert(start < end && start < slice.size() && end <= slice.size());
    if (end >= slice.size())
        end = slice.size();
    for (size_t i = start; i < end; i++) {
        total += slice.at(i);
    }
    return total;
}
the8472 commented 4 years ago

A more intuitive way to achieve the desired result:

#[no_mangle]
fn f(slice: &[u64], start: usize, end: usize) -> u64 {
    let mut total = 0;
    assert!(start < end && start < slice.len() && end <= slice.len());
    for i in (start..end).take_while(|&i| i < slice.len()) {
        total += slice[i];
    }
    total
}
joshtriplett commented 4 years ago

@the8472 As mentioned in the original comment, I understand that I could rewrite the loop in a way that Rust can figure out how to optimize. However, I'd like to see Rust optimizing the original code, which it has enough information to do; I expect that doing so will substantially improve quite a bit of existing Rust code.

tmandry commented 4 years ago

Here's another example:

pub fn copy_t(dest: &mut [u8], src: &[u8]) {
    let len = std::cmp::min(dest.len(), src.len());
    for i in 0..len {
        dest[i] = src[i]
    }
}

which compiles to

example::copy_t:
        push    rax
        cmp     rsi, rcx
        mov     r8, rsi
        cmova   r8, rcx
        test    r8, r8
        je      .LBB2_5
        xor     r9d, r9d
.LBB2_2:
        cmp     rcx, r9
        je      .LBB2_6
        cmp     rsi, r9
        je      .LBB2_7
        movzx   eax, byte ptr [rdx + r9]
        mov     byte ptr [rdi + r9], al
        add     r9, 1
        cmp     r9, r8
        jb      .LBB2_2
.LBB2_5:
        pop     rax
        ret
.LBB2_6:
        lea     rdx, [rip + .L__unnamed_1]
        mov     rdi, rcx
        mov     rsi, rcx
        call    qword ptr [rip + core::panicking::panic_bounds_check@GOTPCREL]
        ud2
.LBB2_7:
        lea     rdx, [rip + .L__unnamed_2]
        mov     rdi, rsi
        call    qword ptr [rip + core::panicking::panic_bounds_check@GOTPCREL]
        ud2

I tried different combinations of assertions from this issue and none of them worked for this one. Using unchecked indexes optimizes the loop to a memcpy:

pub fn copy_s(dest: &mut [u8], src: &[u8]) {
    let len = std::cmp::min(dest.len(), src.len());
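    // Note: the bound here is src.len(), not len (so `len` is unused); the
    // unchecked accesses are only sound when dest.len() >= src.len().
    for i in 0..src.len() {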
        unsafe {
            *dest.get_unchecked_mut(i) = *src.get_unchecked(i);
        }
    }
}
example::copy_s:
        test    rcx, rcx
        je      .LBB1_2
        push    rax
        mov     rsi, rdx
        mov     rdx, rcx
        call    qword ptr [rip + memcpy@GOTPCREL]
        add     rsp, 8
.LBB1_2:
        ret
tesuji commented 4 years ago

How about this? https://godbolt.org/z/TQSMyv

pub fn copy_c(dest: &mut [u8], src: &[u8]) {
    let len = std::cmp::min(dest.len(), src.len());
    let (left, _) = dest.split_at_mut(len);
    left.copy_from_slice(&src[..len]);
}
tmandry commented 4 years ago

Yes, copy_from_slice works, but only if you know ahead of time that it's going to be a simple byte-for-byte copy.

What I'm personally interested in is trying to get generic code that (after inlining) compiles down to the equivalent of the above example to optimize.

nikic commented 4 years ago

@tmandry Here is what goes into induction variable simplification for your case, after a bit of cleanup:

define void @test([0 x i8]* nocapture nonnull align 1 %dest.0, i64 %dest.1, [0 x i8]* noalias nocapture nonnull readonly align 1 %src.0, i64 %src.1) {
start:
  %i = icmp ugt i64 %dest.1, %src.1
  %umin = select i1 %i, i64 %src.1, i64 %dest.1
  %i1 = icmp eq i64 %umin, 0
  br i1 %i1, label %bb7, label %bb9.preheader

bb9.preheader:                                    ; preds = %start
  br label %bb9

bb7.loopexit:                                     ; preds = %bb11
  br label %bb7

bb7:                                              ; preds = %bb7.loopexit, %start
  ret void

bb9:                                              ; preds = %bb11, %bb9.preheader
  %iv = phi i64 [ %iv.inc, %bb11 ], [ 0, %bb9.preheader ]
  %iv.inc = add nuw i64 %iv, 1
  %_23 = icmp ult i64 %iv, %src.1
  br i1 %_23, label %bb10, label %panic

bb10:                                             ; preds = %bb9
  %_26 = icmp ult i64 %iv, %dest.1
  br i1 %_26, label %bb11, label %panic1

bb11:                                             ; preds = %bb10
  %i3 = getelementptr inbounds [0 x i8], [0 x i8]* %src.0, i64 0, i64 %iv
  %_20 = load i8, i8* %i3, align 1
  %i4 = getelementptr inbounds [0 x i8], [0 x i8]* %dest.0, i64 0, i64 %iv
  store i8 %_20, i8* %i4, align 1
  %i5 = icmp ult i64 %iv.inc, %umin
  br i1 %i5, label %bb9, label %bb7.loopexit

panic:                                            ; preds = %bb9
  %iter.sroa.0.015.lcssa = phi i64 [ %iv, %bb9 ]
  tail call void @abort(i64 %iter.sroa.0.015.lcssa)
  unreachable

panic1:                                           ; preds = %bb10
  %iter.sroa.0.015.lcssa16 = phi i64 [ %iv, %bb10 ]
  tail call void @abort(i64 %iter.sroa.0.015.lcssa16)
  unreachable
}

declare void @abort(i64)

The thing to note is that %umin != 0 is checked on entry, and the loop uses a post-increment exit condition, %iv + 1 < %umin.

The relevant SCEV parts are:

  %iv = phi i64 [ %iv.inc, %bb11 ], [ 0, %bb9.preheader ]
  -->  {0,+,1}<nuw><%bb9> U: [0,-1) S: [0,-1)       Exits: ((-1 + (%dest.1 umin %src.1)) umin %dest.1 umin %src.1)      LoopDispositions: { %bb9: Computable }
  %iv.inc = add nuw i64 %iv, 1
  -->  {1,+,1}<nuw><%bb9> U: [1,0) S: [1,0)     Exits: (1 + ((-1 + (%dest.1 umin %src.1)) umin %dest.1 umin %src.1))        LoopDispositions: { %bb9: Computable }
...
  exit count for bb9: %src.1
  exit count for bb10: %dest.1
  exit count for bb11: (-1 + (%dest.1 umin %src.1))

The -1 is what obscures things here, because it could be overflowing. SCEV is not capable of retaining that a subtraction is NUW, because it canonicalizes to additions.
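To see why the -1 matters, the bb11 exit count can be evaluated in plain Rust wrapping arithmetic, which matches LLVM's modulo-2^64 integers. This is an illustration only, and `bb11_exit_count` is a hypothetical helper, not anything in LLVM: for a nonzero umin the result is umin - 1, which is below both lengths, but for umin == 0 it wraps to u64::MAX. The entry check %umin != 0 excludes that case, yet once the subtraction is stored as an addition without a no-wrap flag, SCEV cannot use it to prove %iv < %src.1 and %iv < %dest.1.

```rust
/// Hypothetical helper mirroring the SCEV exit count for bb11,
/// `(-1 + (%dest.1 umin %src.1))`, in wrapping 64-bit arithmetic.
fn bb11_exit_count(dest_len: u64, src_len: u64) -> u64 {
    let umin = dest_len.min(src_len);
    // SCEV stores "umin - 1" as "-1 + umin"; adding all-ones (u64::MAX)
    // is the same as subtracting 1 modulo 2^64.
    umin.wrapping_add(u64::MAX)
}
```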

the8472 commented 4 years ago

> What I'm personally interested in is trying to get generic code that (after inlining) compiles down to the equivalent of the above example to optimize.

Is this sufficiently generic? It compiles to a memcpy:

pub fn generic_copy<T: Clone>(dest: &mut [T], src: &[T]) {
    let len = std::cmp::min(dest.len(), src.len());
    let (dest, _) = dest.split_at_mut(len);
    let src = &src[..len];

    for i in 0..src.len() {
        dest[i] = src[i].clone()
    }
}

pub fn concrete(dest: &mut [u8], src: &[u8]) {
    generic_copy(dest, src)
}
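The same re-slicing trick is not limited to copies. A minimal sketch, assuming any per-element operation can be passed as a closure (`zip_apply` is a hypothetical name, and its assembly has not been checked): both inputs are cut to a common length first, so `i < len` implies both indexes are in bounds.

```rust
/// Hypothetical `zip_apply`: calls `f` on each pair of elements, after
/// re-slicing both inputs to a common length (the pattern from `generic_copy`).
pub fn zip_apply<T, U, F: FnMut(&mut T, &U)>(dest: &mut [T], src: &[U], mut f: F) {
    let len = dest.len().min(src.len());
    // After these two lines, dest.len() == src.len() == len.
    let dest = &mut dest[..len];
    let src = &src[..len];
    for i in 0..len {
        f(&mut dest[i], &src[i]);
    }
}
```

Elements of `dest` past the shorter length are left untouched.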
Arnavion commented 3 years ago

I found this today as part of minimizing some other code:

// Bounds check not elided
pub fn cond_inline(s: &[bool; 2], cond: bool) -> bool {
    for i in (if cond { 0..=1 } else { 0..=0 }) {
        if s[i] {
            return true;
        }
    }

    false
}

// Bounds check elided
pub fn cond_outside(s: &[bool; 2], cond: bool) -> bool {
    if cond {
        for i in 0..=1 {
            if s[i] {
                return true;
            }
        }
    }
    else {
        for i in 0..=0 {
            if s[i] {
                return true;
            }
        }
    }

    false
}

The former retains the bounds check while the latter elides it. There is even more information here than in the OP's case, because all the lengths are known at compile time, so this is even more surprising to me.

Is it covered by this issue? (It sounds like https://github.com/rust-lang/rust/issues/71997#issuecomment-625768824 ) Or should I open a separate one?

(1.48.0 stable)
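One possible workaround for this case, offered as an unverified sketch (`cond_iter` is a hypothetical name), is to compute how many leading elements to inspect and iterate over that subslice instead of indexing. The only remaining check is `n <= 2` in `s[..n]`, which the optimizer should be able to fold because `n` is either 1 or 2.

```rust
/// Hypothetical `cond_iter`: same result as `cond_inline`, but iterates over
/// a subslice instead of indexing with a conditionally chosen range.
pub fn cond_iter(s: &[bool; 2], cond: bool) -> bool {
    // Inspect both elements when `cond`, otherwise only the first.
    let n = if cond { 2 } else { 1 };
    s[..n].iter().any(|&b| b)
}
```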

jrmuizel commented 3 years ago

Adding -mllvm -enable-constraint-elimination to the first C++ example eliminates the bounds check: https://gcc.godbolt.org/z/13bMsa7hM.

It doesn't help the second example and -C passes=constraint-elimination doesn't help the Rust example.

alex commented 3 years ago

Converted to C++ (turned out basically identical to @joshtriplett's) and filed as an llvm optimizer bug: https://bugs.llvm.org/show_bug.cgi?id=49885

jrmuizel commented 3 years ago

I filed a gcc bug for fun too: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=99966