rust-lang / rust

Empowering everyone to build reliable and efficient software.
https://www.rust-lang.org
Other
99.2k stars 12.81k forks source link

Strange register allocation causes slowdown #121529

Open SeeSpring opened 9 months ago

SeeSpring commented 9 months ago

I tried this code:

#![feature(portable_simd)]
use std::simd::{num::SimdFloat, Simd};
/// Reproducer for rust-lang/rust#121529: a triple-nested SIMD loop whose
/// innermost loop gets compiled with redundant register-to-register moves.
///
/// Loop structure:
/// - For every `a` in `ax`, compute `ay`.
/// - For every `b` in `bx`, compute `byi`.
/// - For every `c` in `cx`, derive eight `Simd<u32, 8>` values `cy0..cy7` from
///   `c` and `byi`.
///
/// Those eight values are reinterpreted bit-for-bit as `Simd<f32, 8>`, summed in
/// two groups of four, and the combined sum is passed to `std::hint::black_box`.
/// Nothing is returned; the `black_box` call is the only observable use of the
/// computation.
///
/// The exact shape of this code is what exhibits the register-allocation
/// problem. Restructuring it may make the problem disappear, so treat it as a
/// fixed test case rather than code to clean up.
// `inline(never)` keeps `px` a standalone function, so its loops appear as-is
// in the emitted assembly instead of being merged into `main`.
#[inline(never)]
pub fn px(xx: [Simd<u32, 8>; 8], ax: &[Simd<u32, 8>], bx: &[Simd<u32, 8>], cx: &[Simd<u32, 8>]) {
    for a in ax {
        // ay[i] = (splat(0xf << i) ^ a) + xx[i], computed once per `a`.
        let ay: [_; 8] = std::array::from_fn(|i| (Simd::splat(0xf << i) ^ a) + xx[i]);
        for b in bx {
            // byi[i] = (splat(0xf << i) ^ b) + ay[i], computed once per (a, b) pair
            // and held behind a reference for the innermost loop.
            let byi: &[_; 8] = &std::array::from_fn(|i| ((Simd::splat(0xf << i) ^ b) + ay[i]));
            for &c in cx {
                // Innermost loop. Each cyK masks `c` with splat(0xf << K) and adds
                // byi[K]. Every byi load goes through `l`, so the load strategy
                // can be swapped in one place (e.g. for a volatile read).
                let cy0 = (Simd::splat(0xf << 0) & c) + l(&byi[0]);
                let cy1 = (Simd::splat(0xf << 1) & c) + l(&byi[1]);
                let cy2 = (Simd::splat(0xf << 2) & c) + l(&byi[2]);
                let cy3 = (Simd::splat(0xf << 3) & c) + l(&byi[3]);
                // cy4..cy7 additionally XOR byi[4..8] with cy0..cy3, which makes
                // them depend on the first four results.
                let cy4 = (Simd::splat(0xf << 4) & c) + (l(&byi[4]) ^ (cy0));
                let cy5 = (Simd::splat(0xf << 5) & c) + (l(&byi[5]) ^ (cy1));
                let cy6 = (Simd::splat(0xf << 6) & c) + (l(&byi[6]) ^ (cy2));
                let cy7 = (Simd::splat(0xf << 7) & c) + (l(&byi[7]) ^ (cy3));
                // Reinterpret the u32 lane bit patterns as f32 lanes (a bit cast,
                // not a numeric conversion).
                let cy = [
                    Simd::<f32, 8>::from_bits(cy0),
                    Simd::<f32, 8>::from_bits(cy1),
                    Simd::<f32, 8>::from_bits(cy2),
                    Simd::<f32, 8>::from_bits(cy3),
                    Simd::<f32, 8>::from_bits(cy4),
                    Simd::<f32, 8>::from_bits(cy5),
                    Simd::<f32, 8>::from_bits(cy6),
                    Simd::<f32, 8>::from_bits(cy7),
                ];
                // Sum cy[0..4], starting from 0.0 and visiting elements last-to-first.
                let cya = cy[..4]
                    .into_iter()
                    .rev()
                    .fold(Simd::splat(0.), |a, x| a + x);
                // Same for cy[4..8].
                let cyb = cy[4..]
                    .into_iter()
                    .rev()
                    .fold(Simd::splat(0.), |a, x| a + x);
                // Keep the otherwise-unused result alive so the loop body is not
                // optimized away.
                std::hint::black_box(cya + cyb);
            }
        }
    }
}
/// Returns a copy of the value behind `x` using an ordinary load.
///
/// `px` routes every `byi[k]` load in its innermost loop through this helper,
/// so the load strategy can be changed in one place. The alternative body
/// `unsafe { (x as *const T).read_volatile() }` performs a volatile read,
/// which the compiler may not elide or merge. The issue reports that this
/// volatile variant makes the inner loop faster than the plain load below.
fn l<T: Copy>(x: &T) -> T {
    *x
}
/// Benchmark driver.
///
/// Fills `xx` (8 vectors) and `yy` (512 vectors) with deterministic
/// pseudo-random lanes from a fixed-seed ChaCha8 RNG. It then times one call to
/// `px`, passing `yy` as all three slices, and prints the elapsed time with
/// `dbg!`.
pub fn main() {
    use rand::Rng;
    let mut rng = <rand_chacha::ChaCha8Rng as rand::SeedableRng>::seed_from_u64(808605057454428838);
    // Draw the lanes as plain integers (xx first, then yy, as before), then
    // pack them into 8-lane vectors. `Simd::from_slice` takes the first 8
    // elements in order, matching what the former `transmute` produced, but
    // without `unsafe`.
    let xx_lanes: [u32; 8 * 8] = rng.gen();
    let yy_lanes: [u32; 8 * 512] = rng.gen();
    let xx: [Simd<u32, 8>; 8] = std::array::from_fn(|i| Simd::from_slice(&xx_lanes[i * 8..]));
    let yy: Vec<Simd<u32, 8>> = yy_lanes.chunks_exact(8).map(Simd::from_slice).collect();
    let yy = yy.as_slice();
    let start = std::time::Instant::now();
    px(xx, yy, yy, yy);
    dbg!(start.elapsed());
}

Godbolt

I expected to see this happen: there are no unnecessary register-to-register moves in the innermost loop.

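Instead, this happened: the innermost loop contains unnecessary register-to-register `vmovdqa` moves (marked with `; <--` below). At the top of the loop, the values in ymm0 and ymm3–ymm6 are shifted into ymm3–ymm7, and at the bottom they are shifted back. Separately, ymm15 is copied into ymm1 and later copied back.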

.LBB0_11:
        vmovdqa ymm9, ymmword ptr [r9 + r11]
        vpand   ymm10, ymm9, ymm0
        vpaddd  ymm10, ymm10, ymmword ptr [rsp + 800]
        vpxor   xmm8, xmm8, xmm8
        vpand   ymm11, ymm9, ymm3
        vpaddd  ymm11, ymm11, ymmword ptr [rsp + 768]
        vpand   ymm12, ymm9, ymm4
        vpaddd  ymm12, ymm12, ymmword ptr [rsp + 736]
        vpand   ymm13, ymm9, ymm5
        vpaddd  ymm13, ymm13, ymmword ptr [rsp + 704]
        vpand   ymm14, ymm9, ymm6
        vmovdqa ymm7, ymm6                          ; <--
        vmovdqa ymm6, ymm5                          ; <--
        vmovdqa ymm5, ymm4                          ; <--
        vmovdqa ymm4, ymm3                          ; <--
        vmovdqa ymm3, ymm0                          ; <--
        vpxor   ymm0, ymm10, ymmword ptr [rsp + 672]
        vpaddd  ymm0, ymm14, ymm0
        vpand   ymm14, ymm9, ymm15
        vmovdqa ymm1, ymm15                         ; <--
        vpxor   ymm15, ymm11, ymmword ptr [rsp + 640]
        vpaddd  ymm14, ymm15, ymm14
        vpand   ymm15, ymm9, ymm2
        vpxor   ymm2, ymm12, ymmword ptr [rsp + 608]
        vpaddd  ymm2, ymm15, ymm2
        vpand   ymm9, ymm9, ymmword ptr [rsp + 32]
        vpxor   ymm15, ymm13, ymmword ptr [rsp + 576]
        vpaddd  ymm9, ymm15, ymm9
        vmovdqa ymm15, ymm1                         ; <--
        vaddps  ymm13, ymm13, ymm8
        vaddps  ymm12, ymm13, ymm12
        vaddps  ymm11, ymm12, ymm11
        vaddps  ymm10, ymm11, ymm10
        vpxor   xmm1, xmm1, xmm1
        vaddps  ymm9, ymm9, ymm1
        vaddps  ymm2, ymm9, ymm2
        vaddps  ymm2, ymm14, ymm2
        vaddps  ymm0, ymm2, ymm0
        vmovaps ymm2, ymmword ptr [rsp]
        vaddps  ymm0, ymm10, ymm0
        vmovaps ymmword ptr [rsp + 832], ymm0
        vmovdqa ymm0, ymm3                          ; <--
        vmovdqa ymm3, ymm4                          ; <--
        vmovdqa ymm4, ymm5                          ; <--
        vmovdqa ymm5, ymm6                          ; <--
        vmovdqa ymm6, ymm7                          ; <--
        add     r11, 32
        cmp     rax, r11
        jne     .LBB0_11

Using `.read_volatile()` in `l` (the commented-out line) improves performance. According to uiCA, the inner loop takes 12.0 cycles with the plain load versus 10.14 cycles with the volatile read, a 1.18× difference. hyperfine measures a 1.09 ± 0.22× difference when running the whole executable.

Meta

rustc --version --verbose:

rustc 1.78.0-nightly (381d69953 2024-02-24)
binary: rustc
commit-hash: 381d69953bb7c3390cec0fee200f24529cb6320f
commit-date: 2024-02-24
host: x86_64-unknown-linux-gnu
release: 1.78.0-nightly
LLVM version: 18.1.0
SeeSpring commented 9 months ago

Also see https://github.com/llvm/llvm-project/issues/66837 https://github.com/llvm/llvm-project/issues/81391