rust-lang / rust

Empowering everyone to build reliable and efficient software.
https://www.rust-lang.org

Flaky constant propagation results in loss of SIMD optimization #121511

Open KOROBYAKA opened 8 months ago

KOROBYAKA commented 8 months ago

We have observed an unexpected loss of vectorization when compiling the code below. The target was an amd64 machine with AVX2 support.

type T = u32;

#[inline(never)]
fn check_dumb(n: u32, max_val: T) -> Option<(T, T, T, T)> {
    for x in 1..max_val {
        for y in 1..max_val {
            for z in 0..max_val {
                if x.pow(n) + y.pow(n) == z.pow(n) {
                    return Some((x, y, z, n as T));
                }
            }
        }
    }
    None
}

#[inline(never)]
fn check_smart(n: u32, max_val: T) -> Option<(T, T, T, T)> {
    const STRIDE: usize = 24;

    let zarr_base = {
        let mut t = [0; STRIDE];
        for i in 0..STRIDE {
            t[i] = 1 + i as T;
        }
        t
    };

    for x in 1..max_val {
        // let xpow = x.pow(n); // manually "lifting" this operation out of the inner loop "fixes" the optimizer
        for y in 1..max_val {
            // let ypow = y.pow(n); // manually "lifting" this operation out of the inner loop "fixes" the optimizer
            let mut zarr = zarr_base;

            for z_i in 0..max_val / STRIDE as T {
                for i in 0..STRIDE {
                    zarr[i] += STRIDE as T;
                }

                // Note: z is always cubed here, regardless of n.
                let mut zarr3 = [0; STRIDE];
                for i in 0..STRIDE {
                    zarr3[i] = zarr[i] * zarr[i] * zarr[i];
                }

                let found = zarr3.iter().any(|&i| {
                    x.pow(n) + y.pow(n) == i
                    // xpow + ypow == i // see the comments above for why this might work
                });

                if found {
                    // Note: this reports z_i * STRIDE, not the exact z that matched.
                    return Some((x, y, z_i * STRIDE as T, n as T));
                }
            }
        }
    }
    None
}

fn main() {
    let n = 3; // this constant value should apply to the entire main() function
    {
        let s = std::time::Instant::now();
        println!("Check_dumb n={} returned {:?}", n, check_dumb(n, 1000));
        println!("Time taken {:?}", s.elapsed());
    }
    {
        // Commenting out the line below results in loss of vectorization in check_smart!!!
        let n = 3;
        let s = std::time::Instant::now();
        println!("Check with simd n={} returned {:?}", n, check_smart(n, 1000));
        println!("Time taken {:?}", s.elapsed());
    }
}

I expected to see this happen: check_smart should be vectorized with SIMD instructions, while check_dumb cannot be. I should not have to specify n=3 twice in main(), since the value of n from the outer scope should be propagated into check_smart. Vectorization of check_smart should happen for any value of n (even when n is not known ahead of time).

Instead, this happened: whether the compiler actually vectorizes check_smart seems to depend on whether x.pow(n) can panic. When the value of n is not propagated into the function, the optimizer fails to vectorize it, resulting in a massive loss of performance.
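A possible workaround, sketched here under assumptions (the names search and check_dispatch are hypothetical, and this has not been benchmarked against the functions above): branch on n once, outside the hot loops, so that common exponents each get their own monomorphized copy of the loop with a literal exponent, whether or not constant propagation from the call site succeeds. For brevity the search loop mirrors check_dumb.

```rust
type T = u32;

// Generic search loop. Each distinct closure type produces its own
// monomorphized copy, so `pow_n` is a known function in every copy.
#[inline(always)]
fn search(n: u32, max_val: T, pow_n: impl Fn(T) -> T) -> Option<(T, T, T, T)> {
    for x in 1..max_val {
        for y in 1..max_val {
            for z in 0..max_val {
                if pow_n(x) + pow_n(y) == pow_n(z) {
                    return Some((x, y, z, n as T));
                }
            }
        }
    }
    None
}

// Hypothetical dispatcher: map a runtime exponent onto specialized copies.
#[inline(never)]
fn check_dispatch(n: u32, max_val: T) -> Option<(T, T, T, T)> {
    match n {
        2 => search(n, max_val, |v: T| v.pow(2)), // literal exponent
        3 => search(n, max_val, |v: T| v.pow(3)), // literal exponent
        _ => search(n, max_val, |v: T| v.pow(n)), // runtime exponent fallback
    }
}
```

The fallback arm keeps the runtime exponent, so results for other values of n are unchanged; the trade-off is extra code size for each specialized exponent.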

Meta

rustc --version --verbose:

rustc 1.76.0 (07dca489a 2024-02-04)
binary: rustc
commit-hash: 07dca489ac2d933c78d3c5158e3f43beefeb02ce
commit-date: 2024-02-04
host: x86_64-unknown-linux-gnu
release: 1.76.0
LLVM version: 17.0.6

The nightly compiler exhibits identical behavior.

The CPU is an Intel(R) Core(TM) i5-8350U CPU @ 1.70GHz.

alexpyattaev commented 8 months ago

I just tested on a slightly different CPU, with the line setting n=3 commented out:

Check_dumb n=3 returned None
Time taken 915.227343ms
Check with simd n=3 returned None
Time taken 4.634040885s 

Wow, 4.6 seconds is slower than JavaScript! The same test with the line present:

Check_dumb n=3 returned None
Time taken 907.446979ms
Check with simd n=3 returned None
Time taken 502.535157ms

The CPU used was a 13th Gen Intel(R) Core(TM) i5-1335U, with optimization set for native:

[target.x86_64-unknown-linux-gnu]
rustflags = ["-Ctarget-cpu=native"]

Looking into the assembly of check_smart, it seems that when the compiler fails to determine a constant value for n, the code of the pow function ends up spread throughout the hot loop, killing any hope of performance. It is very odd, though, that x.pow(n) is not moved out of the loop over z; if it were, it would not matter how pow is implemented.
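To illustrate the hoisting described above, here is a sketch of check_smart with the report's two commented-out lines (xpow and ypow) applied, so pow runs once per x and once per (x, y) rather than inside the z loop. check_smart_hoisted is a hypothetical name. The sketch also uses wrapping arithmetic, starts the z search at 1, and reports the exact z that matched, which the original does not; like the original, it compares against z cubed regardless of n. It has not been benchmarked.

```rust
type T = u32;
const STRIDE: usize = 24;

// Hypothetical: check_smart with x.pow(n) and y.pow(n) hoisted by hand.
#[inline(never)]
fn check_smart_hoisted(n: u32, max_val: T) -> Option<(T, T, T, T)> {
    for x in 1..max_val {
        let xpow = x.wrapping_pow(n); // invariant in the y and z loops
        for y in 1..max_val {
            // Invariant in the z loop; wrapping ops never panic in any profile.
            let target = xpow.wrapping_add(y.wrapping_pow(n));

            // Blocks of STRIDE candidates: z_base + 1 ..= z_base + STRIDE.
            let mut z_base: T = 0;
            while z_base < max_val {
                let mut cubes = [0; STRIDE];
                for (i, c) in cubes.iter_mut().enumerate() {
                    let z = z_base + 1 + i as T;
                    *c = z.wrapping_mul(z).wrapping_mul(z); // z^3, as in the original
                }
                if let Some(i) = cubes.iter().position(|&c| c == target) {
                    return Some((x, y, z_base + 1 + i as T, n));
                }
                z_base += STRIDE as T;
            }
        }
    }
    None
}
```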

erikdesjardins commented 8 months ago

Strangely, this only seems to happen with target-cpu=native (even on Godbolt, which runs on Zen 3), not with the baseline target or when explicitly specifying target-cpu=raptorlake (for the i5-1335U) or any other specific CPU I've tried (haswell, core-avx2, znver3, etc.): https://godbolt.org/z/jqrs7bP5W.

(I was expecting this to have the same root cause as https://github.com/rust-lang/rust/issues/112478#issuecomment-1586332367, where printing a value prevents LLVM from seeing that it's constant: a pointer to it escapes into the printing machinery, which could in principle use it to mutate the value. But this seems weirder than that.)

alexpyattaev commented 8 months ago

That is mighty odd indeed. On my machine the bug persists with the raptorlake setting:

$$ ~/fermat (master)> cargo clean && RUSTFLAGS="-Ctarget-cpu=raptorlake" cargo  run --release
     Removed 11 files, 844.0KiB total
   Compiling fermat v0.1.0 (/home/headhunter/fermat)
    Finished release [optimized] target(s) in 0.26s
     Running `target/release/fermat`
Check_dumb n=3 returned None
Time taken 932.948488ms
Check with simd n=3 returned None
Time taken 3.440539885s

LLVM complains about "failed to hoist load with loop-invariant address because load is conditionally executed" near calls to pow(n), but I'm not sure how to extract more info from its optimization remarks...

I've also tried targeting some other architectures that my machines can run, with the same result:

 $$ ~/fermat (master)> cargo clean && RUSTFLAGS="-Ctarget-cpu=cascadelake" cargo  run --release
     Removed 11 files, 844.0KiB total
   Compiling fermat v0.1.0 (/home/headhunter/fermat)
    Finished release [optimized] target(s) in 0.25s
     Running `target/release/fermat`
Check_dumb n=3 returned None
Time taken 910.116577ms
Check with simd n=3 returned None
Time taken 4.317862738s
$$ ~/fermat (master) [101]> cargo clean && RUSTFLAGS="-Ctarget-cpu=znver2" cargo  run --release
     Removed 7 files, 3.3KiB total
   Compiling fermat v0.1.0 (/home/headhunter/fermat)
    Finished release [optimized] target(s) in 0.25s
     Running `target/release/fermat`
Check_dumb n=3 returned None
Time taken 924.990783ms
Check with simd n=3 returned None
Time taken 4.279846815s
$$ ~/fermat (master)> cargo clean && RUSTFLAGS="-Ctarget-cpu=znver3" cargo  run --release
     Removed 11 files, 844.0KiB total
   Compiling fermat v0.1.0 (/home/headhunter/fermat)
    Finished release [optimized] target(s) in 0.29s
     Running `target/release/fermat`
Check_dumb n=3 returned None
Time taken 954.212518ms
Check with simd n=3 returned None
Time taken 3.432712025s

alexpyattaev commented 8 months ago

I've also checked your hypothesis about printing killing the optimization, and it seems to hold up: modifying the call site of check_smart to print n just before the call results in a failure to optimize.

{
    let n = 3;
    println!("Value of n is {}", n.clone()); // removing clone() here will break the optimization, no matter how dumb that is!
    let s = std::time::Instant::now();
    println!("Check with simd n={} returned {:?}", n, check_smart(n, 1000));
    println!("Time taken {:?}", s.elapsed());
}
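Another way to express the n.clone() trick from the snippet above is to format a block expression { n }, which evaluates to a copy of n, so format_args! borrows that temporary instead of n itself. This is a sketch under the escaping-pointer hypothesis discussed in this thread; uses_n and print_then_call are hypothetical stand-ins, and whether this restores vectorization in check_smart for a given rustc version has not been verified.

```rust
type T = u32;

// Hypothetical stand-in for check_smart: any function taking `n` by value.
#[inline(never)]
fn uses_n(n: u32, max_val: T) -> T {
    (1..max_val).map(|x| x.wrapping_pow(n)).fold(0, u32::wrapping_add)
}

fn print_then_call(max_val: T) -> T {
    let n: u32 = 3;
    // `{ n }` evaluates to a fresh copy of `n`; the formatting machinery
    // receives a reference to that temporary, not to `n` itself.
    println!("Value of n is {}", { n });
    uses_n(n, max_val)
}
```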

The deeper question now is: why should the value of n matter at all? It is irrelevant to the optimizations within the inner loop over z (both x and y are constant there, so x.pow(n) can safely be hoisted out of the loop over z). Even if n could take some bizarre value that would make the call to pow panic, there is no point in keeping it within the inner loop, or am I missing something here? Does it actually matter exactly where the panic is generated? Edit: with Cargo's default release profile, u32::pow() does not panic, even on overflow (it wraps). So it should always be safe to hoist calls to it, whatever the value of n. Edit 2: with the new code structure, the optimizer remark changed to "failed to hoist load with loop-invariant address because load is conditionally executed", but the resulting asm is pretty much the same garbage, and the remark seems to have nothing to do with the actual problem at hand.
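Whether pow can panic depends on the build profile: with overflow checks enabled (the debug default) it panics on overflow, while under the default release profile it wraps. The explicit variants in the standard library make the intended behavior visible in the source. This matters for hoisting because a call that might panic cannot, in general, be moved in front of a loop that might run zero times. A minimal sketch; sum_of_powers and pow_overflow_demo are hypothetical helpers, not code from the report.

```rust
// Hypothetical helper: x^n + y^n with explicit wrapping, which never panics
// in any build profile. Under Cargo's default release profile (overflow
// checks off), plain `pow` and `+` wrap the same way.
fn sum_of_powers(x: u32, y: u32, n: u32) -> u32 {
    x.wrapping_pow(n).wrapping_add(y.wrapping_pow(n))
}

// 3^21 = 10_460_353_203, which exceeds u32::MAX.
fn pow_overflow_demo() -> (Option<u32>, u32, (u32, bool)) {
    let base: u32 = 3;
    (
        base.checked_pow(21),     // None: the result does not fit in u32
        base.wrapping_pow(21),    // result modulo 2^32
        base.overflowing_pow(21), // wrapped result plus an overflow flag
    )
}
```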

alexpyattaev commented 8 months ago

I managed to reproduce this on Godbolt. It also reproduces reliably with other targets (not just native). The performance difference on AMD is dramatic: going from 400 ms to 9.5 seconds is crazy slow.

alexpyattaev commented 7 months ago

Tested again on the latest rustc 1.78.0-nightly (c67326b06 2024-03-15); no change so far in either direction.

alexpyattaev commented 4 months ago

Tested again on 1.80: the issue no longer occurs. It is unclear what made it go away.

rustc 1.80.0-beta.3 (105fc5ccc 2024-06-14): does not have the issue
rustc 1.79.0 (129f3b996 2024-06-10): still has the issue

I'll keep watching in case it re-emerges.

alexpyattaev commented 3 months ago

The latest nightly, rustc 1.82.0-nightly (92c6c0380 2024-07-21), has the problem again, with exactly the same symptoms as before. This is a clear regression.