llvm / llvm-project

The LLVM Project is a collection of modular and reusable compiler and toolchain technologies.
http://llvm.org
Other
28.47k stars 11.77k forks source link

[AArch64] memcmp(X, <const>) == 0 is not SIMDifed #53988

Open EgorBo opened 2 years ago

EgorBo commented 2 years ago

TL;DR (Compiler Explorer reproduction): https://godbolt.org/z/be34K3ba8

On x64, the "Expand memcmp() to load/stores" (expandmemcmp) transformation pass does a good job:

*** IR Dump Before Expand memcmp() to load/stores (expandmemcmp) ***
; Function Attrs: mustprogress nofree nounwind readonly willreturn uwtable
define dso_local noundef zeroext i1 @_Z12IsLoremIpsumPc(i8* nocapture noundef readonly %0) local_unnamed_addr #0 {
  %2 = tail call i32 @bcmp(i8* noundef nonnull dereferenceable(64) %0, i8* noundef nonnull dereferenceable(64) getelementptr inbounds ([64 x i8], [64 x i8]* @.str, i64 0, i64 0), i64 64)
  %3 = icmp eq i32 %2, 0
  ret i1 %3
}

*** IR Dump After Expand memcmp() to load/stores (expandmemcmp) ***
; Function Attrs: mustprogress nofree nounwind readonly willreturn uwtable
define dso_local noundef zeroext i1 @_Z12IsLoremIpsumPc(i8* nocapture noundef readonly %0) local_unnamed_addr #0 {
  %2 = bitcast i8* %0 to i256*
  %3 = load i256, i256* %2, align 1
  %4 = xor i256 %3, 52211099530468751664686933754856474077863816206753030648151240171596364803916
  %5 = getelementptr i8, i8* %0, i64 32
  %6 = bitcast i8* %5 to i256*
  %7 = load i256, i256* %6, align 1
  %8 = xor i256 %7, 196811072116336222543414263149244389794262714704868596226640400969530368869
  %9 = or i256 %4, %8
  %10 = icmp ne i256 %9, 0
  %11 = zext i1 %10 to i32
  %12 = icmp eq i32 %11, 0
  ret i1 %12
}

which is lowered down to (AVX2):

IsLoremIpsum(char*):                     # @IsLoremIpsum(char*)
        vmovdqu ymm0, ymmword ptr [rdi]
        vmovdqu ymm1, ymmword ptr [rdi + 32]
        vpxor   ymm1, ymm1, ymmword ptr [rip + .LCPI0_0]
        vpxor   ymm0, ymm0, ymmword ptr [rip + .LCPI0_1]
        vpor    ymm0, ymm0, ymm1
        vptest  ymm0, ymm0
        sete    al
        vzeroupper
        ret

On ARM64, however, `bcmp(...) == 0` is lowered to a very long sequence of scalar (SWAR-style) operations:

*** IR Dump Before Expand memcmp() to load/stores (expandmemcmp) ***
; Function Attrs: mustprogress nofree nounwind readonly willreturn uwtable vscale_range(1,16)
define dso_local noundef i1 @_Z12IsLoremIpsumPc(i8* nocapture noundef readonly %0) local_unnamed_addr #0 {
  %2 = tail call i32 @bcmp(i8* noundef nonnull dereferenceable(64) %0, i8* noundef nonnull dereferenceable(64) getelementptr inbounds ([64 x i8], [64 x i8]* @.str, i64 0, i64 0), i64 64)
  %3 = icmp eq i32 %2, 0
  ret i1 %3
}

*** IR Dump After Expand memcmp() to load/stores (expandmemcmp) ***
; Function Attrs: mustprogress nofree nounwind readonly willreturn uwtable vscale_range(1,16)
define dso_local noundef i1 @_Z12IsLoremIpsumPc(i8* nocapture noundef readonly %0) local_unnamed_addr #0 {
  %2 = bitcast i8* %0 to i64*
  %3 = load i64, i64* %2, align 1
  %4 = xor i64 %3, 8100041059028070220
  %5 = getelementptr i8, i8* %0, i64 8
  %6 = bitcast i8* %5 to i64*
  %7 = load i64, i64* %6, align 1
  %8 = xor i64 %7, 8028914711526208883
  %9 = getelementptr i8, i8* %0, i64 16
  %10 = bitcast i8* %9 to i64*
  %11 = load i64, i64* %10, align 1
  %12 = xor i64 %11, 7881616507232526450
  %13 = getelementptr i8, i8* %0, i64 24
  %14 = bitcast i8* %13 to i64*
  %15 = load i64, i64* %14, align 1
  %16 = xor i64 %15, 8317708033332114533
  %17 = getelementptr i8, i8* %0, i64 32
  %18 = bitcast i8* %17 to i64*
  %19 = load i64, i64* %18, align 1
  %20 = xor i64 %19, 2338060299337491301
  %21 = getelementptr i8, i8* %0, i64 40
  %22 = bitcast i8* %21 to i64*
  %23 = load i64, i64* %22, align 1
  %24 = xor i64 %23, 7594040293371503713
  %25 = getelementptr i8, i8* %0, i64 48
  %26 = bitcast i8* %25 to i64*
  %27 = load i64, i64* %26, align 1
  %28 = xor i64 %27, 3203301149241272174
  %29 = getelementptr i8, i8* %0, i64 56
  %30 = bitcast i8* %29 to i64*
  %31 = load i64, i64* %30, align 1
  %32 = xor i64 %31, 31353812700984096
  %33 = or i64 %4, %8
  %34 = or i64 %12, %16
  %35 = or i64 %20, %24
  %36 = or i64 %28, %32
  %37 = or i64 %33, %34
  %38 = or i64 %35, %36
  %39 = or i64 %37, %38
  %40 = icmp ne i64 %39, 0
  %41 = zext i1 %40 to i32
  %42 = icmp eq i32 %41, 0
  ret i1 %42
}

which is lowered down to:

IsLoremIpsum(char*):                     // @IsLoremIpsum(char*)
        ldp     x9, x11, [x0]
        mov     x8, #28492
        mov     x10, #30067
        mov     x12, #8306
        movk    x8, #25970, lsl #16
        ldp     x13, x15, [x0, #16]
        movk    x10, #8301, lsl #16
        movk    x12, #26995, lsl #16
        mov     x14, #29797
        movk    x8, #8301, lsl #32
        movk    x10, #28516, lsl #32
        movk    x12, #8308, lsl #32
        movk    x14, #8236, lsl #16
        movk    x8, #28777, lsl #48
        movk    x10, #28524, lsl #48
        movk    x12, #28001, lsl #48
        movk    x14, #28515, lsl #32
        eor     x8, x9, x8
        movk    x14, #29550, lsl #48
        eor     x9, x11, x10
        eor     x10, x13, x12
        eor     x12, x15, x14
        ldp     x11, x13, [x0, #32]
        mov     x14, #25445
        mov     x16, #25697
        movk    x14, #25972, lsl #16
        movk    x16, #28777, lsl #16
        movk    x14, #30068, lsl #32
        movk    x16, #29545, lsl #32
        movk    x14, #8306, lsl #48
        movk    x16, #26979, lsl #48
        ldp     x15, x17, [x0, #48]
        eor     x11, x11, x14
        eor     x13, x13, x16
        mov     x14, #26478
        mov     x16, #29472
        movk    x14, #25888, lsl #16
        movk    x16, #25701, lsl #16
        movk    x14, #26988, lsl #32
        movk    x16, #25632, lsl #32
        movk    x14, #11380, lsl #48
        movk    x16, #111, lsl #48
        eor     x14, x15, x14
        eor     x15, x17, x16
        orr     x8, x8, x9
        orr     x9, x10, x12
        orr     x10, x11, x13
        orr     x11, x14, x15
        orr     x8, x8, x9
        orr     x9, x10, x11
        orr     x8, x8, x9
        cmp     x8, #0
        cset    w0, eq
        ret

It is hard to imagine vectorization adding overhead even when the sequences differ. The ARM64 codegen is branchless (there is no early exit), so it always has to accumulate a final 64-bit value before it can tell that the sequences are not the same.

Expected: a NEON or SVE2 implementation, e.g. something like this for a 32-byte length: [image omitted in this copy]

llvmbot commented 2 years ago

@llvm/issue-subscribers-backend-aarch64

EgorBo commented 2 years ago

After a quick glance, the issue seems to be in `enableMemCmpExpansion`, which unconditionally uses `Options.LoadSizes = {8, 4, 2, 1};`

https://github.com/llvm/llvm-project/blob/0b41238ae7f91bcc907a577377caa70721ffc400/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp#L1944-L1947

I assume that, at least for -O3 and `IsZeroCmp == true`, it makes sense to consider using 16-byte loads as well; for example, memset already uses NEON just fine.