rust-lang / rust

Empowering everyone to build reliable and efficient software.
https://www.rust-lang.org
Other
96.73k stars 12.5k forks source link

Weird AVX 512 code generated with std::simd when using -Zbuild-std #129293

Open cvijdea-bd opened 3 weeks ago

cvijdea-bd commented 3 weeks ago

Discussed on Zulip: https://rust-lang.zulipchat.com/#narrow/stream/257879-project-portable-simd/topic/simd.3A.3AMask.20codegen.20on.20avx512

I tried this code (Godbolt link):

#![feature(portable_simd)]
use std::simd::u8x32;
use std::simd::Simd;
use std::simd::cmp::SimdPartialOrd;

/// Returns `idxs` with every lane that is not below 32 replaced by `u8::MAX`.
///
/// `#[inline(never)]` keeps this as its own symbol so the compare + select
/// sequence it compiles to can be inspected in the disassembly.
#[inline(never)]
pub fn test_lt_select(idxs: u8x32) -> u8x32 {
    let limit: u8x32 = Simd::splat(32u8);
    let fallback: u8x32 = Simd::splat(u8::MAX);
    // Lanes strictly below `limit` keep their value; all others take `fallback`.
    let keep = idxs.simd_lt(limit);
    keep.select(idxs, fallback)
}

I expected to see this happen: simd_lt + simd_select is compiled to a clean 3-instruction sequence (AVX512: vpcmpltub + vpcmpeqd + vmovdqu8-with-mask, AVX2: vpmaxub + vpcmpeqb + vpor) - this is the case with the pre-built std, see Godbolt

Instead, this happened: with -Zbuild-std, vpcmpltub is followed by lots of redundant shuffling of mask registers

RUSTFLAGS=-Ctarget-cpu=sapphirerapids cargo build --release -Z build-std --target x86_64-unknown-linux-gnu

Messy assembly:

```
00000000000596b0 <_ZN9test_simd14test_lt_select17h4bc8c118b05e9d0dE>:
   596b0: 62 f3 7d 28 3e 05 25  vpcmpltub k0,ymm0,YMMWORD PTR [rip+0xfffffffffffab725]  # 4de0
   596b7: b7 fa ff 01
unsafe { Self(core::intrinsics::simd::simd_bitmask(value), PhantomData) }
   596bb: c4 e3 79 31 c8 08     kshiftrd k1,k0,0x8
   596c1: c4 e3 79 31 d0 18     kshiftrd k2,k0,0x18
   596c7: c5 fb 93 c2           kmovd eax,k2
   596cb: c4 e3 79 31 d0 10     kshiftrd k2,k0,0x10
core::intrinsics::simd::simd_select_bitmask(
   596d1: c5 f9 93 ca           kmovb ecx,k2
   596d5: c5 f9 6e c8           vmovd xmm1,eax
   596d9: c5 fb 93 c1           kmovd eax,k1
   596dd: c4 e3 71 20 c8 01     vpinsrb xmm1,xmm1,eax,0x1
   596e3: c4 e2 79 31 c9        vpmovzxbd xmm1,xmm1
   596e8: c4 e2 71 47 0d 7f bb  vpsllvd xmm1,xmm1,XMMWORD PTR [rip+0xfffffffffffabb7f]  # 5270
   596ef: fa ff
   596f1: c1 e1 10              shl ecx,0x10
   596f4: c5 f9 7e c8           vmovd eax,xmm1
   596f8: 09 c8                 or eax,ecx
   596fa: c4 e3 79 16 c9 01     vpextrd ecx,xmm1,0x1
   59700: 09 c1                 or ecx,eax
   59702: c5 f9 93 c0           kmovb eax,k0
   59706: 09 c8                 or eax,ecx
   59708: c5 fb 92 c8           kmovd k1,eax
   5970c: c5 f5 76 c9           vpcmpeqd ymm1,ymm1,ymm1
   59710: 62 f1 7f 29 6f c8     vmovdqu8 ymm1{k1},ymm0
   59716: c5 fd 7f 0f           vmovdqa YMMWORD PTR [rdi],ymm1
#[inline(never)]
pub fn test_lt_select(idxs: u8x32) -> u8x32 {
    idxs.simd_lt(Simd::splat(32u8))
        .select(idxs, Simd::splat(u8::MAX))
}
   5971a: c5 f8 77              vzeroupper
   5971d: c3                    ret
```

Without -Zbuild-std, the generated LLVM IR is a beautiful icmp ult followed by select.

; test_simd::test_lt_select
; Function Attrs: mustprogress nofree noinline norecurse nosync nounwind nonlazybind willreturn memory(argmem: write) uwtable
; Clean IR (pre-built std): the `icmp ult` result (%0, a <32 x i1> mask) feeds the
; `select` (%1) directly, with no bitmask round-trip in between.
define internal fastcc void @_ZN9test_simd14test_lt_select17h0358d9f4e1cf8f67E(ptr dead_on_unwind noalias nocapture noundef writable writeonly align 32 dereferenceable(32) %_0, <32 x i8> %idxs.0.val) unnamed_addr #3 !dbg !2029 {
start:
    #dbg_declare(ptr undef, !2031, !DIExpression(), !2032)
    #dbg_declare(ptr undef, !2033, !DIExpression(), !2037)
    #dbg_declare(ptr undef, !2039, !DIExpression(), !2044)
    #dbg_value(<32 x i8> <i8 32, i8 32, i8 32, i8 32, i8 32, i8 32, i8 32, i8 32, i8 32, i8 32, i8 32, i8 32, i8 32, i8 32, i8 32, i8 32, i8 32, i8 32, i8 32, i8 32, i8 32, i8 32, i8 32, i8 32, i8 32, i8 32, i8 32, i8 32, i8 32, i8 32, i8 32, i8 32>, !2036, !DIExpression(), !2046)
  ; Lane-wise unsigned compare: idxs < 32.
  %0 = icmp ult <32 x i8> %idxs.0.val, <i8 32, i8 32, i8 32, i8 32, i8 32, i8 32, i8 32, i8 32, i8 32, i8 32, i8 32, i8 32, i8 32, i8 32, i8 32, i8 32, i8 32, i8 32, i8 32, i8 32, i8 32, i8 32, i8 32, i8 32, i8 32, i8 32, i8 32, i8 32, i8 32, i8 32, i8 32, i8 32>, !dbg !2047
    #dbg_value(<32 x i8> poison, !2042, !DIExpression(), !2048)
    #dbg_value(<32 x i8> <i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1>, !2043, !DIExpression(), !2048)
  ; Keep idxs in lanes where %0 is set, otherwise i8 -1 (0xFF, i.e. u8::MAX).
  %1 = select <32 x i1> %0, <32 x i8> %idxs.0.val, <32 x i8> <i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1>, !dbg !2049
  ; The result is returned indirectly through the %_0 out-pointer.
  store <32 x i8> %1, ptr %_0, align 32, !dbg !2049
  ret void, !dbg !2050
}

With -Zbuild-std, the LLVM IR is as much of a mess as the generated assembly:

Messy IR. The vector-constant operands were lost when this text was captured and are shown as `<...>`. The `icmp ult` and `select` constants are the same `splat(32)` / `splat(-1)` vectors as in the clean IR above.

```llvm
; test_simd::test_lt_select
; Function Attrs: mustprogress nofree noinline norecurse nosync nounwind nonlazybind willreturn memory(argmem: write) uwtable
define internal fastcc void @_ZN9test_simd14test_lt_select17h4bc8c118b05e9d0dE(ptr dead_on_unwind noalias nocapture noundef writable writeonly align 32 dereferenceable(32) %_0, <32 x i8> %idxs.0.val) unnamed_addr #3 !dbg !2383 {
start:
    #dbg_declare(ptr undef, !2385, !DIExpression(), !2386)
    #dbg_declare(ptr undef, !2387, !DIExpression(), !2391)
    #dbg_declare(ptr undef, !2393, !DIExpression(), !2398)
    #dbg_value(<32 x i8> <...>, !2390, !DIExpression(), !2400)
  %0 = icmp ult <32 x i8> %idxs.0.val, <...>, !dbg !2401
    #dbg_value(<32 x i8> poison, !2402, !DIExpression(), !2405)
    #dbg_value(<32 x i8> poison, !2407, !DIExpression(), !2410)
  %bc13 = bitcast <32 x i1> %0 to <4 x i8>, !dbg !2412
  %.sroa.05.0.extract.trunc = extractelement <4 x i8> %bc13, i64 0, !dbg !2412
  %.sroa.3.0.extract.trunc = extractelement <4 x i8> %bc13, i64 2, !dbg !2412
    #dbg_value(i8 %.sroa.05.0.extract.trunc, !2396, !DIExpression(DW_OP_LLVM_fragment, 0, 8), !2413)
    #dbg_value(i8 %.sroa.05.0.extract.trunc, !2414, !DIExpression(DW_OP_LLVM_fragment, 0, 8), !2417)
    #dbg_value(i8 %.sroa.05.0.extract.trunc, !2419, !DIExpression(DW_OP_LLVM_fragment, 0, 8), !2422)
    #dbg_value(i8 undef, !2396, !DIExpression(DW_OP_LLVM_fragment, 8, 8), !2413)
    #dbg_value(i8 undef, !2414, !DIExpression(DW_OP_LLVM_fragment, 8, 8), !2417)
    #dbg_value(i8 undef, !2419, !DIExpression(DW_OP_LLVM_fragment, 8, 8), !2422)
    #dbg_value(i8 %.sroa.3.0.extract.trunc, !2396, !DIExpression(DW_OP_LLVM_fragment, 16, 8), !2413)
    #dbg_value(i8 %.sroa.3.0.extract.trunc, !2414, !DIExpression(DW_OP_LLVM_fragment, 16, 8), !2417)
    #dbg_value(i8 %.sroa.3.0.extract.trunc, !2419, !DIExpression(DW_OP_LLVM_fragment, 16, 8), !2422)
    #dbg_value(i8 undef, !2396, !DIExpression(DW_OP_LLVM_fragment, 24, 8), !2413)
    #dbg_value(i8 undef, !2414, !DIExpression(DW_OP_LLVM_fragment, 24, 8), !2417)
    #dbg_value(i8 undef, !2419, !DIExpression(DW_OP_LLVM_fragment, 24, 8), !2422)
    #dbg_value(<32 x i8> <...>, !2397, !DIExpression(), !2413)
  %.sroa.39.0.insert.ext = zext i8 %.sroa.3.0.extract.trunc to i32, !dbg !2424
  %.sroa.39.0.insert.shift = shl nuw nsw i32 %.sroa.39.0.insert.ext, 16, !dbg !2424
  %1 = shufflevector <4 x i8> %bc13, <4 x i8> poison, <2 x i32> <...>, !dbg !2424
  %2 = zext <2 x i8> %1 to <2 x i32>, !dbg !2424
  %3 = shl nuw <2 x i32> %2, <...>, !dbg !2424
  %4 = extractelement <2 x i32> %3, i64 0, !dbg !2424
  %.sroa.39.0.insert.insert = or disjoint i32 %4, %.sroa.39.0.insert.shift, !dbg !2424
  %5 = extractelement <2 x i32> %3, i64 1, !dbg !2424
  %.sroa.28.0.insert.insert = or disjoint i32 %.sroa.39.0.insert.insert, %5, !dbg !2424
  %.sroa.07.0.insert.ext = zext i8 %.sroa.05.0.extract.trunc to i32, !dbg !2424
  %.sroa.07.0.insert.insert = or disjoint i32 %.sroa.28.0.insert.insert, %.sroa.07.0.insert.ext, !dbg !2424
  %6 = bitcast i32 %.sroa.07.0.insert.insert to <32 x i1>, !dbg !2424
  %7 = select <32 x i1> %6, <32 x i8> %idxs.0.val, <32 x i8> <...>, !dbg !2425
  store <32 x i8> %7, ptr %_0, align 32, !dbg !2425
  ret void, !dbg !2426
}
```

With -Zbuild-std, but a target-cpu without avx512 (e.g. x86-64-v3), the IR and assembly are beautiful again.

rustc --version --verbose:

rustc 1.82.0-nightly (636d7ff91 2024-08-19)
binary: rustc
commit-hash: 636d7ff91b9847d6d43c7bbe023568828f6e3246
commit-date: 2024-08-19
host: x86_64-unknown-linux-gnu
release: 1.82.0-nightly
LLVM version: 19.1.0

Reproduced with and without lto = "thin", also on Windows, and also with different target-cpu (x86-64-v4, skylake-avx512).

cvijdea-bd commented 3 weeks ago

It can be reduced to a minimal example which doesn't require -Zbuild-std (Godbolt link):

#![feature(core_intrinsics)]
#![feature(portable_simd)]

use std::simd::prelude::*;

/// Reduced reproducer that needs no `-Zbuild-std`.
///
/// It performs the same `x < 32 ? x : u8::MAX` selection as `test_lt_select`.
/// Before the final `simd_select`, it routes the comparison mask through
/// `simd_bitmask` / `simd_select_bitmask`. These are the same intrinsics that
/// appear in the `-Zbuild-std` disassembly.
///
/// `#[inline(never)]` + `#[no_mangle]` keep it as a standalone, unmangled
/// symbol so its assembly is easy to find and inspect.
#[inline(never)]
#[no_mangle]
pub fn test_lt_select_mask_raw(idxs: u8x32) -> u8x32 {
    // NOTE(review): soundness relies on the intrinsics' contracts:
    // - every vector here has 32 lanes;
    // - `[u8; 4]` holds exactly 32 mask bits;
    // - the mask handed to `simd_select` only has 0 / -1 lanes, because it comes
    //   from `simd_select_bitmask(_, splat(-1), splat(0))`.
    // Confirm against the `core::intrinsics::simd` docs.
    unsafe {
        // Lane-wise `idxs < 32` as a full-width integer mask.
        let m: i8x32 = core::intrinsics::simd::simd_lt(idxs, Simd::splat(32u8));
        // Pack the 32-lane mask into 4 bytes (one bit per lane).
        // changing to `m: u32` here gets the good codegen; using `[u8; 4]` gets the bad codegen
        let m: [u8; 4] = core::intrinsics::simd::simd_bitmask(m);
        // Expand the packed bitmask back into a full-width mask: set bit -> -1, clear bit -> 0.
        let m: i8x32 =
            core::intrinsics::simd::simd_select_bitmask(m, Simd::splat(-1), Simd::splat(0));
        // Keep `idxs` where the mask is set, `u8::MAX` elsewhere.
        core::intrinsics::simd::simd_select(m, idxs, Simd::splat(u8::MAX))
    }
}

This seems to be caused by the bitmask-based `Mask` implementation in https://github.com/rust-lang/portable-simd/blob/master/crates/core_simd/src/masks/bitmask.rs

It's worth noting that AVX2 / SSE codegen is equally bad when using the [u8; 4] simd_bitmask variant. On those targets the problem is only avoided because the mask::bitmask::Mask implementation is cfg-ed in only on avx512f (in the non-bitmask case it uses just the simd_lt + simd_select intrinsics).