rust-lang / rust

Empowering everyone to build reliable and efficient software.
https://www.rust-lang.org

slice::iter() does not preserve number of iterations information for optimizer causing unneeded bounds checks #75935

Open sdroege opened 4 years ago

sdroege commented 4 years ago

Godbolt link to the code below: https://rust.godbolt.org/z/aKf3Wq

pub fn foo1(x: &[u32], y: &[u32]) -> u32 {
    let mut sum = 0;
    let chunk_size = y.len();
    for (c, y) in y.iter().enumerate() {
        for chunk in x.chunks_exact(chunk_size) {
            sum += chunk[c] + y;
        }
    }
    sum
}

This code has a bounds check for chunk[c] although c < chunk_size by construction.

A slightly more convoluted version of the same code gets rid of the bounds check:

pub fn foo2(x: &[u32], y: &[u32]) -> u32 {
    let mut sum = 0;
    let chunk_size = y.len();
    for c in 0..chunk_size {
        let y = y[c];
        for chunk in x.chunks_exact(chunk_size) {
            sum += chunk[c] + y;
        }
    }
    sum
}

It seems like the information that 0 <= c < y.len() gets lost for the optimizer when going via y.iter().enumerate(). So this is unrelated to chunks_exact() specifically but I can't come up with an equivalent example without it.

edit: As noticed in https://github.com/rust-lang/rust/issues/75935#issuecomment-680807329, this can be worked around by defining a custom slice iterator that does counting of elements instead of working with an end pointer.

The problem is that slice::iter() works with an end pointer to know when the iteration can stop, and keeps no information around for the optimizer that it's actually going to iterate exactly N times. It's unclear to me how this information can be preserved without changing how the iterator works, which will probably have other negative effects.

edit2:

As noticed in https://github.com/rust-lang/rust/pull/77822, this happens with C/C++ too, and the testcase can also be simplified a lot on the Rust side:

pub fn foo(y: &[u32]) {
    let mut x = 0;
    for (c, _y) in y.iter().enumerate() {
        assert!(c < y.len());
        x = c + 1; // number of elements seen so far, mirroring `c++` in the C version
    }
    assert!(x == y.len());
}

#include <assert.h>
#include <stdint.h>
#include <stdlib.h>

#include <vector>

void foo1(const uint32_t *y, size_t y_len) {
  const uint32_t *y_end = y + y_len;
  size_t c = 0;
  for (const uint32_t *y_iter = y; y_iter != y_end; y_iter++, c++) {
    assert(c < y_len);
  }
  assert(c == y_len);
}

void foo2(const std::vector<uint32_t>& y) {
    size_t c = 0;
    for (auto y_iter: y) {
        assert(c < y.size());
        c++;
    }
    assert(c == y.size());
}

void foo3(const std::vector<uint32_t>& y) {
    size_t c = 0;
    for (auto y_iter = y.cbegin(); y_iter != y.cend(); y_iter++, c++) {
        assert(c < y.size());
    }
    assert(c == y.size());
}

edit3: This is now also reported to https://bugs.llvm.org/show_bug.cgi?id=48965

sdroege commented 4 years ago

I should probably add that while this code is very contrived, it's based on real code that shows the same behaviour.

tesuji commented 4 years ago

Maybe a duplicate of #74938: An upgrade to LLVM 12 or so is needed to fix the issue.

sdroege commented 4 years ago

I'll try adding only that change to the LLVM version used by rustc. Let's see if that solves anything (or works at all :) ).

sdroege commented 4 years ago

No, does not work at all. Gives a segfault in exactly that function that is changed inside LLVM.

sdroege commented 4 years ago

Got it to work. It doesn't fix this issue here, but it fixes #75936 . This one here is still valid.

sdroege commented 4 years ago

It seems like the information that 0 <= c < y.len() somehow gets lost for the optimizer when going via y.iter().enumerate().

My guess is that this is because the slice iterators don't go via a counter but instead via an end pointer, so it's not obvious anymore that it's just iterating exactly self.len() times.
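For reference, a minimal sketch of what an end-pointer iterator looks like. `EndPtrIter` is a hypothetical name used only for illustration, it handles `u32` only (no zero-sized types), and it is not the actual `slice::Iter` implementation. The point is that `next()` only compares two pointers, so nothing in the iterator state says "exactly `len` iterations":

```rust
use std::marker::PhantomData;

/// Hypothetical end-pointer iterator over `&[u32]`, for illustration only.
struct EndPtrIter<'a> {
    ptr: *const u32,
    end: *const u32,
    phantom: PhantomData<&'a [u32]>,
}

impl<'a> EndPtrIter<'a> {
    fn new(v: &'a [u32]) -> Self {
        let ptr = v.as_ptr();
        EndPtrIter {
            ptr,
            // SAFETY: a one-past-the-end pointer of the same slice is allowed.
            end: unsafe { ptr.add(v.len()) },
            phantom: PhantomData,
        }
    }
}

impl<'a> Iterator for EndPtrIter<'a> {
    type Item = &'a u32;

    fn next(&mut self) -> Option<&'a u32> {
        // Termination is a pointer comparison; no element count is stored.
        if self.ptr == self.end {
            return None;
        }
        // SAFETY: `ptr` has not reached `end`, so it points at a live element.
        unsafe {
            let item = &*self.ptr;
            self.ptr = self.ptr.add(1);
            Some(item)
        }
    }
}
```

Wrapping this in `enumerate()` adds a separate counter, and the optimizer has to prove on its own that the counter stays below the slice length.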

sdroege commented 4 years ago

Yes, going with a simple iterator that counts instead gets rid of the bounds check. Code:

```rust
pub fn foo1(x: &[u32], y: &[u32]) -> u32 {
    let mut sum = 0;
    let chunk_size = y.len();
    for (c, y) in Iter::new(y).enumerate() {
        for chunk in x.chunks_exact(chunk_size) {
            sum += chunk[c] + y;
        }
    }
    sum
}

struct Iter<'a> {
    ptr: *const u32,
    len: usize,
    phantom: std::marker::PhantomData<&'a [u32]>,
}

impl<'a> Iter<'a> {
    fn new(v: &'a [u32]) -> Iter<'a> {
        Iter {
            ptr: v.as_ptr(),
            len: v.len(),
            phantom: std::marker::PhantomData,
        }
    }
}

impl<'a> Iterator for Iter<'a> {
    type Item = &'a u32;

    fn next(&mut self) -> Option<&'a u32> {
        unsafe {
            if self.len == 0 {
                return None;
            }
            let item = &*self.ptr;
            self.ptr = self.ptr.add(1);
            self.len -= 1;
            Some(item)
        }
    }
}
```

tesuji commented 4 years ago

You could update the issue description for new information.

sdroege commented 4 years ago

> You could update the issue description for new information.

Indeed, thanks. Done!

cynecx commented 4 years ago

Is there any particular reason why the std-implementation (slice::Iter) is doing iteration through end pointer equality compared to the counting variant?

sdroege commented 4 years ago

> Is there any particular reason why the std-implementation (slice::Iter) is doing iteration through end pointer equality compared to the counting variant?

I don't know the history, but in theory one instruction less per iteration (one pointer addition vs. one pointer addition and one counter addition). And it might be taken advantage of in some specialized impls but I don't know.

Might be worth looking at what std::vector iterators in C++ are doing, I'd hope those are optimizing well with clang++.

tesuji commented 4 years ago

> Is there any particular reason why the std-implementation (slice::Iter) is doing iteration through end pointer equality compared to the counting variant?

The comment in code says that it's because of an optimization for ZST: https://github.com/rust-lang/rust/blob/118860a7e76daaac3564c7655d46ac65a14fc612/library/core/src/slice/mod.rs#L4009-L4014

sdroege commented 4 years ago

> The comment in code says that it's because of an optimization for ZST:

If you had a counter that would work basically the same way, you'd just check idx==len or remainder==0 or similar. For ZST the current encoding seems just like a way to make it possible to not worry about ZST special cases in most other places of the code.
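A rough sketch of that idea, assuming a hypothetical `CountedIter` type (not anything in std): with a `remaining` counter the termination check is the same for every element type, including zero-sized ones, so no special encoding of the end pointer is needed.

```rust
use std::marker::PhantomData;

/// Hypothetical counted iterator, generic over `T` (including zero-sized types).
struct CountedIter<'a, T> {
    ptr: *const T,
    remaining: usize,
    phantom: PhantomData<&'a [T]>,
}

impl<'a, T> CountedIter<'a, T> {
    fn new(v: &'a [T]) -> Self {
        CountedIter {
            ptr: v.as_ptr(),
            remaining: v.len(),
            phantom: PhantomData,
        }
    }
}

impl<'a, T> Iterator for CountedIter<'a, T> {
    type Item = &'a T;

    fn next(&mut self) -> Option<&'a T> {
        // The same `remaining == 0` check ends iteration for every `T`.
        if self.remaining == 0 {
            return None;
        }
        // SAFETY: `remaining > 0` means `ptr` still refers to a live element.
        // For a zero-sized `T`, `add(1)` does not change the address.
        unsafe {
            let item = &*self.ptr;
            self.ptr = self.ptr.add(1);
            self.remaining -= 1;
            Some(item)
        }
    }
}
```

The cost is the extra counter update per step, which is the trade-off discussed elsewhere in this thread.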

sdroege commented 4 years ago

I can do an implementation of slice::iter() that does counting next week, but what would be the best way to check this doesn't cause any performance regressions elsewhere? How are such things usually checked (i.e. is there some extensive benchmark suite that I could run, ...)?

sdroege commented 3 years ago

And another variant with specialization of the Enumerate iterator, which should also catch various other cases (e.g. the Chunks and ChunksExact iterators on slices). This yields the best code so far: no bounds checks, unrolled and auto-vectorized nicely.

Check assembly here, the new one is foo2().

The std::intrinsics::assume() in the specialized impl is the part that makes it work nicely.

I'll create a PR for this later.

```rust
// Nightly-only: requires #![feature(specialization, trusted_len, core_intrinsics)].
pub fn foo2(x: &[u32], y: &[u32]) -> u32 {
    let mut sum = 0;
    let chunk_size = y.len();
    for (c, y) in Enumerate::new(y.iter()) {
        for chunk in x.chunks_exact(chunk_size) {
            sum += chunk[c] + y;
        }
    }
    sum
}

struct Enumerate<I> {
    iter: I,
    count: usize,
    len: usize,
}

impl<I: Iterator> Enumerate<I> {
    fn new(iter: I) -> Self {
        EnumerateImpl::new(iter)
    }
}

impl<I> Iterator for Enumerate<I>
where
    I: Iterator,
{
    type Item = (usize, <I as Iterator>::Item);

    #[inline]
    fn next(&mut self) -> Option<(usize, <I as Iterator>::Item)> {
        EnumerateImpl::next(self)
    }
}

// Enumerate specialization trait
#[doc(hidden)]
trait EnumerateImpl<I> {
    type Item;
    fn new(iter: I) -> Self;
    fn next(&mut self) -> Option<(usize, Self::Item)>;
}

impl<I> EnumerateImpl<I> for Enumerate<I>
where
    I: Iterator,
{
    type Item = I::Item;

    default fn new(iter: I) -> Self {
        Enumerate {
            iter,
            count: 0,
            len: 0, // unused
        }
    }

    #[inline]
    default fn next(&mut self) -> Option<(usize, I::Item)> {
        let a = self.iter.next()?;
        let i = self.count;
        // Possible undefined overflow.
        self.count += 1;
        Some((i, a))
    }
}

impl<I> EnumerateImpl<I> for Enumerate<I>
where
    // FIXME: Should probably be TrustedRandomAccess because otherwise size_hint() might be expensive?
    I: std::iter::TrustedLen + ExactSizeIterator + Iterator,
{
    fn new(iter: I) -> Self {
        let len = iter.size_hint().0;
        Enumerate {
            iter,
            count: 0,
            len,
        }
    }

    #[inline]
    fn next(&mut self) -> Option<(usize, I::Item)> {
        let a = self.iter.next()?;
        unsafe {
            std::intrinsics::assume(self.count < self.len);
        }
        let i = self.count;
        // Possible undefined overflow.
        self.count += 1;
        Some((i, a))
    }
}
```

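The effect of the `assume()` call can also be approximated in user code on stable Rust with `std::hint::unreachable_unchecked`. The following is only a sketch: `foo1_hinted` is a hypothetical name, and whether the hint actually removes the bounds check depends on the compiler and LLVM version, which has not been verified here.

```rust
pub fn foo1_hinted(x: &[u32], y: &[u32]) -> u32 {
    let mut sum = 0;
    let chunk_size = y.len();
    for (c, y) in y.iter().enumerate() {
        // Tell the optimizer what is true by construction: `c < chunk_size`.
        if c >= chunk_size {
            // SAFETY: `enumerate()` over a slice iterator of length
            // `chunk_size` only yields indices in `0..chunk_size`.
            unsafe { std::hint::unreachable_unchecked() }
        }
        for chunk in x.chunks_exact(chunk_size) {
            sum += chunk[c] + y;
        }
    }
    sum
}
```

Unlike the specialized `Enumerate`, this puts the burden on every call site, so it is a local workaround rather than a fix.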
sdroege commented 3 years ago

I've updated the issue description with a minimal testcase in Rust, C/C++ and as discussed in https://github.com/rust-lang/rust/pull/77822 will report this to LLVM.

the8472 commented 3 years ago

> I can do an implementation of slice::iter() that does counting next week, but what would be the best way to check this doesn't cause any performance regressions elsewhere? How are such things usually checked (i.e. is there some extensive benchmark suite that I could run, ...)?

Requesting a perf run on the PR would be a start, which runs an extensive performance test compiling (but not benchmarking) various crates. It might miss unusual number-crunching uses of slices though.

https://perf.rust-lang.org/

sdroege commented 3 years ago

Thanks, but considering that C++ iterators work the same way, it's probably safer to keep it with this pattern and wait for LLVM to solve that. I would assume that patterns (also) used by C++ are more likely to be optimized better.

scottmcm commented 2 months ago

> Is there any particular reason why the std-implementation (slice::Iter) is doing iteration through end pointer equality compared to the counting variant?

Because in cases where the iterator doesn't get SRoA'd away, slice::Iter::next on pointers only needs to write one value to memory (the start pointer), whereas a counted implementation needs to update two (increase the pointer and decrease the count): https://rust.godbolt.org/z/9zdYnMd1E.

When computing the length of a slice iterator we use both sub nuw and udiv exact to tell LLVM that it doesn't need to worry about weird cases, so hypothetically it should have everything it needs here. Of course LLVM does tend to optimize away that information in the next it generates...
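As a rough illustration of recovering a length from a start/end pointer pair, here is a sketch using the stable `offset_from` method. `len_from_ptrs` is a hypothetical helper, not the code std uses; its safety contract (same slice, `start <= end`, byte distance an exact multiple of the element size) is meant to resemble the guarantees mentioned above.

```rust
/// Hypothetical helper: number of `T` elements between `start` and `end`.
///
/// # Safety
/// Both pointers must be derived from the same slice with `start <= end`,
/// and `T` must not be zero-sized.
unsafe fn len_from_ptrs<T>(start: *const T, end: *const T) -> usize {
    // `offset_from` returns the distance in units of `T`. It requires the
    // byte distance to be an exact multiple of `size_of::<T>()`.
    // The cast is lossless because the caller guarantees `start <= end`.
    unsafe { end.offset_from(start) as usize }
}
```

For a slice `s`, `s.as_ptr_range()` gives a suitable `start`/`end` pair.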

One thing we could do better is at least tell LLVM that the pointers are aligned even when not read, though. I wish we had a better core pointer type that could actually do that.