rust-lang / rust

Empowering everyone to build reliable and efficient software.
https://www.rust-lang.org

Missed optimization: loop with increasing index doesn't elide bounds check #74688

Open paolobarbolini opened 4 years ago

paolobarbolini commented 4 years ago

I tried this code and many variations of it, and found that the bounds check on &buf[i..] is not elided.

pub fn problematic(buf: &[u8]) -> &[u8] {
    let mut i = 0;
    for b in buf {
        if *b == 0x00 {
            return &buf[i..];
        }

        i += 1;
    }

    &[]
}
Assembly Output

```assembly
example::problematic:
        pushq   %rax
        leaq    .L__unnamed_1(%rip), %rax
        xorl    %edx, %edx
        testq   %rsi, %rsi
        je      .LBB0_6
        xorl    %ecx, %ecx
.LBB0_2:
        cmpb    $0, (%rdi,%rcx)
        je      .LBB0_4
        addq    $1, %rcx
        cmpq    %rcx, %rsi
        jne     .LBB0_2
.LBB0_6:
        popq    %rcx
        retq
.LBB0_4:
        cmpq    %rsi, %rcx
        ja      .LBB0_7
        subq    %rcx, %rsi
        addq    %rcx, %rdi
        movq    %rsi, %rdx
        movq    %rdi, %rax
        popq    %rcx
        retq
.LBB0_7:
        leaq    .L__unnamed_2(%rip), %rdx
        movq    %rcx, %rdi
        callq   *core::slice::slice_index_order_fail@GOTPCREL(%rip)
        ud2

.L__unnamed_3:
        .ascii  "./example.rs"

.L__unnamed_2:
        .quad   .L__unnamed_3
        .asciz  "\f\000\000\000\000\000\000\000\005\000\000\000\025\000\000"

.L__unnamed_1:
```

Manually checking that i is in bounds elides the bounds check, and the manual check itself ends up being optimized away!?

pub fn optimized(buf: &[u8]) -> &[u8] {
    let mut i = 0;
    for b in buf {
        // this gets optimized away?
        if i >= buf.len() { break; }

        if *b == 0x00 {
            return &buf[i..];
        }

        i += 1;
    }

    &[]
}
Assembly Output

```assembly
example::optimized:
        leaq    .L__unnamed_1(%rip), %rax
        testq   %rsi, %rsi
        je      .LBB0_5
        movq    %rsi, %rdx
.LBB0_2:
        cmpb    $0, (%rdi)
        je      .LBB0_3
        addq    $1, %rdi
        addq    $-1, %rdx
        jne     .LBB0_2
.LBB0_5:
        xorl    %edx, %edx
        retq
.LBB0_3:
        movq    %rdi, %rax
        retq

.L__unnamed_1:
```
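The manual check can be dropped because it can never fire: `for b in buf` runs exactly `buf.len()` times and `i` is incremented once per iteration, so `i < buf.len()` holds at the top of every iteration. Here is a minimal sketch that checks this invariant at run time. The function name `index_invariant_holds` is made up for illustration and is not part of the original report.

```rust
/// Returns true if the hand-maintained counter `i` stays strictly below
/// `buf.len()` on every iteration and equals `buf.len()` once the loop ends.
/// (Hypothetical helper, for illustration only.)
pub fn index_invariant_holds(buf: &[u8]) -> bool {
    let mut i = 0;
    for _b in buf {
        // Same condition as the manual check in `optimized`.
        if i >= buf.len() {
            return false;
        }
        i += 1;
    }
    i == buf.len()
}
```

Since the invariant implies `i <= buf.len()`, the check inside `&buf[i..]` is redundant as well. The open question in this issue is why the optimizer only proves that once the explicit comparison is present.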

This third example generates the same assembly as the second one, so I think the compiler really is optimizing away the manual check in the second example?

pub fn optimized_unsafe(buf: &[u8]) -> &[u8] {
    let mut i = 0;
    for b in buf {
        if *b == 0x00 {
            return unsafe { buf.get_unchecked(i..) };
        }

        i += 1;
    }

    &[]
}
Assembly Output

```assembly
example::optimized_unsafe:
        leaq    .L__unnamed_1(%rip), %rax
        testq   %rsi, %rsi
        je      .LBB0_5
        movq    %rsi, %rdx
.LBB0_2:
        cmpb    $0, (%rdi)
        je      .LBB0_3
        addq    $1, %rdi
        addq    $-1, %rdx
        jne     .LBB0_2
.LBB0_5:
        xorl    %edx, %edx
        retq
.LBB0_3:
        movq    %rdi, %rax
        retq

.L__unnamed_1:
```

Possibly related: #74186

95th commented 4 years ago

https://github.com/rust-lang/rust/issues/65969 is a similar but more complex version of this scenario.

the8472 commented 4 years ago

This works and has even fewer branches:

pub fn problematic(buf: &[u8]) -> &[u8] {
    for i in 0 .. buf.len() {
        if buf[i] == 0x00 {
            return &buf[i..];
        }
    }

    &[]
}

In my experience, all the variables whose bounds checks you want eliminated should be derived from the slice bounds (either directly or through arithmetic). The optimizer often doesn't seem to recognize that the value range of another variable in the loop is constrained in such a way that it also falls within the bounds.
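As a sketch of what "derived from the slice bounds through arithmetic" can look like, the function below computes both indices from `buf.len()` instead of tracking a separate counter. The name `is_palindrome` and the task itself are made up for illustration. Whether the bounds checks are actually removed depends on the rustc/LLVM version, so it is worth inspecting the generated assembly rather than assuming it.

```rust
/// Checks whether `buf` reads the same forwards and backwards.
/// (Hypothetical example, not taken from the issue.)
pub fn is_palindrome(buf: &[u8]) -> bool {
    let n = buf.len();
    // `i` comes directly from a range bounded by the slice length.
    for i in 0..n / 2 {
        // `j` is derived from the length by arithmetic; since i < n / 2,
        // n - 1 - i cannot underflow and is always < n.
        let j = n - 1 - i;
        if buf[i] != buf[j] {
            return false;
        }
    }
    true
}
```

For example, `is_palindrome(b"abcba")` is true and `is_palindrome(b"ab")` is false.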

the8472 commented 4 years ago

See also #73396

paolobarbolini commented 4 years ago

> In my experience all the variables whose bounds checks you want to have eliminated should be derived from the slice bounds (either directly or through arithmetic)

I agree, @the8472. The original issue came from something like this:

pub fn problematic(buf: &[u8]) -> &[u8] {
    for (i, b) in buf.iter().enumerate() {
        // this was a matcher, simplified it for this
        if *b == 0x00 {
            return &buf[i..];
        }
    }

    &[]
}

And the solution was to write it as:

pub fn solution(mut buf: &[u8]) -> &[u8] {
    while !buf.is_empty() {
        // just simplifying this here, in the final code it's a match checking more things
        if buf[0] == 0x00 {
            break;
        } else {
            buf = &buf[1..];
        }
    }

    buf
}
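A variation on the same idea, as a sketch: `split_first` hands back the first element and the rest of the slice together, so the loop body needs no indexing at all. The name `solution_split_first` is hypothetical, and I haven't compared the generated assembly against the version above.

```rust
/// Returns the suffix of `buf` starting at the first 0x00 byte,
/// or an empty slice if there is none.
/// (Hypothetical variant of `solution`, for illustration only.)
pub fn solution_split_first(mut buf: &[u8]) -> &[u8] {
    // `split_first` returns None once the slice is empty, ending the loop.
    while let Some((first, rest)) = buf.split_first() {
        if *first == 0x00 {
            break;
        }
        // Shrink the slice from the front instead of advancing an index.
        buf = rest;
    }

    buf
}
```

For example, `solution_split_first(&[1, 2, 0, 3])` returns `&[0, 3]`, and an input without a zero byte yields an empty slice.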

I thought the if i >= buf.len() { break; } thing would be interesting, but maybe it doesn't matter anyway since, judging from the issue you linked, it has already been fixed in LLVM?

EDIT: maybe not