llvm / llvm-project

The LLVM Project is a collection of modular and reusable compiler and toolchain technologies.
http://llvm.org

`-1 << (k + 1)` => `-2 << k` Inverted shift optimizations should incorporate offsets #102946

Open Validark opened 2 months ago

Validark commented 2 months ago

Let's say I have this function:

export fn foo(m: u64) u64 {
    const x = ~m & (m << 1);
    return (@as(u64, 1) << @intCast(@popCount(x) + 1)) - 1;
}

LLVM handily optimizes it to the equivalent of this:

export fn foo(m: u64) u64 {
    const x = ~m & (m << 1);
    return ~(~@as(u64, 0) << @intCast(@popCount(x) + 1));
}

However, we can do slightly better by folding the +1 into the constant: pre-shift ~@as(u64, 0) left by 1. Since ~@as(u64, 0) << 1 is ~@as(u64, 1), we get:

export fn bar(m: u64) u64 {
    const x = ~m & (m << 1);
    return ~(~@as(u64, 1) << @intCast(@popCount(x)));
}
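For readers who prefer not to read Zig, here is a Python model of `foo` and `bar`. It is not part of the original report. It emulates u64 wraparound with a 64-bit mask and uses a hypothetical `popcount` helper. It checks that both forms agree on a few edge cases and on random inputs. For a u64, `~m & (m << 1)` has at most 32 set bits, so the shift amount in `foo` stays below 64.

```python
import random

MASK64 = (1 << 64) - 1  # emulate u64 wraparound on Python's unbounded ints


def popcount(x: int) -> int:
    """Number of set bits in a non-negative int (stand-in for @popCount)."""
    return bin(x).count("1")


def foo(m: int) -> int:
    """Model of the Zig `foo`: (1 << (popcount(x) + 1)) - 1."""
    x = ~m & (m << 1) & MASK64
    return ((1 << (popcount(x) + 1)) - 1) & MASK64


def bar(m: int) -> int:
    """Model of the Zig `bar`: the +1 is folded into the constant (~1 == -2)."""
    x = ~m & (m << 1) & MASK64
    minus_two = ~1 & MASK64
    return ~((minus_two << popcount(x)) & MASK64) & MASK64


if __name__ == "__main__":
    rng = random.Random(102946)  # fixed seed so runs are reproducible
    samples = [0, MASK64, 0x5555555555555555, 0xAAAAAAAAAAAAAAAA]
    samples += [rng.getrandbits(64) for _ in range(10_000)]
    assert all(foo(m) == bar(m) for m in samples)
    print(hex(foo(0x5555555555555555)))  # 0x1ffffffff
```

The alternating input 0x5555555555555555 is the worst case here. It produces 32 transitions, so both functions return a mask of 33 low bits.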

Here is the assembly version:

foo:
        lea     rax, [rdi + rdi]
        mov     rcx, -1
        andn    rax, rdi, rax
        popcnt  rax, rax
        inc     al ; we can remove this increment by changing `rcx` to -2
        shlx    rax, rcx, rax
        not     rax
        ret

bar:
        lea     rax, [rdi + rdi]
        mov     rcx, -2
        andn    rax, rdi, rax
        popcnt  rax, rax
        shlx    rax, rcx, rax
        not     rax
        ret

This optimization should be applicable across architectures. (Godbolt link)

dtcxzyw commented 2 months ago

@Validark Can you explain what @foo is used for?

Validark commented 2 months ago

I wrote this code because I wanted to produce a mask with the same number of bits as in x, but all concentrated in the front.

I was originally going to feed this, perhaps with some transformation, into a pdep instruction, to turn off certain bits of x, but I ended up needing to preserve some other bits that made me shift strategies.

I think the idea is pretty general-purpose still though. The idea is:

-1 << (k + 1) => -2 << k

I think this would benefit code in the wild.
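Here is a small Python check of the general rewrite. It is an illustration, not LLVM code. For a w-bit integer, `-1 << (k + 1)` equals `-2 << k` for every k with k + 1 < w. The hypothetical `shl` helper rejects out-of-range amounts. This mirrors the fact that such shifts produce poison in LLVM IR and are undefined behavior in C.

```python
def shl(value: int, amount: int, width: int) -> int:
    """Left shift of a `width`-bit unsigned value; out-of-range amounts are rejected."""
    if not 0 <= amount < width:
        raise ValueError(f"shift amount {amount} out of range for i{width}")
    return (value << amount) & ((1 << width) - 1)


def rewrite_holds(k: int, width: int) -> bool:
    """True if -1 << (k + 1) == -2 << k for this k (both as `width`-bit patterns)."""
    mask = (1 << width) - 1
    minus_one, minus_two = mask, mask - 1
    return shl(minus_one, k + 1, width) == shl(minus_two, k, width)


for width in (8, 16, 32, 64):
    # k + 1 must be a valid shift amount, so k ranges over [0, width - 2].
    assert all(rewrite_holds(k, width) for k in range(width - 1))
```

At k = width - 1, the left-hand side would shift by the full bit width. That is exactly the case an optimizer has to rule out before it rewrites.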

dtcxzyw commented 2 months ago

> -1 << (k + 1) => -2 << k

LLVM doesn't fold this pattern since it cannot prove k + 1 < bitwidth (i.e., @popCount(~m & (m << 1)) < bitwidth - 1) : https://alive2.llvm.org/ce/z/FfYEmg
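The Python model below shows why the bound matters. It mimics the shift-amount masking in the IR that Zig emits (`%7 = and i64 %6, 63`, posted later in the thread). With the mask, the shift amount wraps to 0 when k + 1 == 64. The original expression then yields all ones, while `-2 << 63` yields 0. The names `before` and `after` are illustrative only.

```python
MASK64 = (1 << 64) - 1


def shl64(value: int, amount: int) -> int:
    """u64 left shift for amounts in [0, 63]."""
    assert 0 <= amount < 64
    return (value << amount) & MASK64


def before(k: int) -> int:
    """Models `add nuw nsw i64 %5, 1`, then `and i64 %6, 63`, then `shl i64 -1, %7`."""
    return shl64(MASK64, (k + 1) & 63)


def after(k: int) -> int:
    """Models the proposed rewrite `shl i64 -2, k`."""
    return shl64(MASK64 - 1, k)


# Identical for every k in [0, 62] ...
assert all(before(k) == after(k) for k in range(63))
# ... but k = 63 makes the masked amount wrap to 0.
print(hex(before(63)), hex(after(63)))  # 0xffffffffffffffff 0x0
```

In that IR, the ctpop call carries `!range !0`, which is [0, 64). That leaves k = 63 as the single value LLVM would still have to exclude, which matches the "< bitwidth - 1" condition above.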

dtcxzyw commented 2 months ago

The pattern @popCount(~m & (m << 1)) does not seem to exist in the real-world code :(

Validark commented 2 months ago

> -1 << (k + 1) => -2 << k
>
> LLVM doesn't fold this pattern since it cannot prove k + 1 < bitwidth (i.e., @popCount(~m & (m << 1)) < bitwidth - 1) : https://alive2.llvm.org/ce/z/FfYEmg

Did you realize you were using i16s? For i16, the maximum popcount of m & ~(m << 1), ~m & (m << 1), and m & (~m << 1), as well as the versions with << turned into >>, is 8. And LLVM is able to prove this: https://alive2.llvm.org/ce/z/oNtS3G
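The i16 bound can also be checked exhaustively without a solver. This Python sketch enumerates every w-bit value for small w. It takes the maximum popcount over the six masks listed above, using a logical right shift as for unsigned Zig integers. The helper names are hypothetical.

```python
def popcount(x: int) -> int:
    return bin(x).count("1")


def edge_masks(m: int, width: int) -> list[int]:
    """The six masks from the comment, computed on `width`-bit unsigned values."""
    mask = (1 << width) - 1
    nm = ~m & mask  # ~m as a width-bit value, so >> is a logical shift
    return [
        m & ~(m << 1) & mask,   # m & ~(m << 1)
        nm & (m << 1) & mask,   # ~m & (m << 1)
        m & (nm << 1) & mask,   # m & (~m << 1)
        m & ~(m >> 1) & mask,   # m & ~(m >> 1)
        nm & (m >> 1),          # ~m & (m >> 1)
        m & (nm >> 1),          # m & (~m >> 1)
    ]


def max_edge_popcount(width: int) -> int:
    """Largest popcount of any of the six masks over all `width`-bit inputs."""
    return max(
        popcount(e)
        for m in range(1 << width)
        for e in edge_masks(m, width)
    )


assert max_edge_popcount(16) == 8  # matches the i16 bound quoted above
```

For width 8, the same search gives 4. This is consistent with the "at most half the bitstring length" observation in the next comment.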

> The pattern @popCount(~m & (m << 1)) does not seem to exist in the real-world code :(

The significance of the pattern is just that every pattern that looks like it has, by definition, a popcount of at most half the bitstring length. I don't think this specific pattern is what you should search for; it's just one example.

Also, even setting the pattern aside, when I write -1 << @intCast(k + 1), I am guaranteeing that k + 1 is in the range [0, BIT_LEN), or rather [1, BIT_LEN) once you account for the +1. That is what the @intCast() call is for: it asserts that the value is a valid u6, i.e. a valid shift amount for a u64 bitstring. Therefore this transformation should be safe. Overflow cannot occur either. The possible range of the popcount rules it out, and @popCount returns a u7, which is meant to hold the maximum popcount of 64, so adding 1 cannot overflow a u7. My @intCast then asserts that the resulting value is 63 or less. So we can safely shift the range of k + 1, [1, 63], down to k in [0, 62] and apply this optimization, even without knowing that the popcount is at most 32 for an i64.
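The range argument can be sketched in Python as well. `int_cast_u6` is a hypothetical stand-in for `@intCast` to a u6. As far as I know, Zig checks such casts in its safe build modes and treats an out-of-range value as illegal behavior in ReleaseFast. The alternating input 0x5555555555555555 reaches the popcount bound of 32, so k + 1 never exceeds 33 for this expression.

```python
import random

MASK64 = (1 << 64) - 1


def popcount(x: int) -> int:
    return bin(x).count("1")


def int_cast_u6(v: int) -> int:
    """Hypothetical model of @intCast to u6: the value must fit in [0, 63]."""
    if not 0 <= v <= 63:
        raise OverflowError(f"{v} does not fit in a u6")
    return v


def shift_amount(m: int) -> int:
    """k + 1 from the original `foo`, where k = popcount(~m & (m << 1))."""
    x = ~m & (m << 1) & MASK64
    return int_cast_u6(popcount(x) + 1)


# The alternating pattern hits the worst case: 32 transitions, so k + 1 == 33.
assert shift_amount(0x5555555555555555) == 33

rng = random.Random(0)
assert all(1 <= shift_amount(rng.getrandbits(64)) <= 33 for _ in range(10_000))
```

Either fact alone places k in [0, 62], where the rewrite is valid: the cast's guarantee that k + 1 <= 63, or the popcount bound of 32.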

dtcxzyw commented 2 months ago

> And LLVM is able to prove this: https://alive2.llvm.org/ce/z/oNtS3G

It is proven by Z3, not LLVM.

> I don't think that this specific pattern is what you should search for, it's just one example.

As I said before, we can address this issue by proving @popCount(~m & (m << 1)) < bitwidth - 1. But this pattern doesn't exist in real-world code. So I think it is not interesting.

nikic commented 2 months ago

> Hence the @intCast() call which is asserting that it is a valid u6 for shifting a u64 bitstring.

FWIW, this is not at all obvious for anyone who is not very familiar with zig. It's also not clear how zig is conveying this information to LLVM, if it is doing so at all.

Please, can you avoid using Zig when reporting LLVM issues, if possible? If you can't write LLVM IR directly, any of C, C++ and Rust is fine to use, as these are languages that are easy to understand for LLVM developers, and we know how to interact with the frontend to obtain LLVM IR at different optimization stages.

> As I said before, we can address this issue by proving @popCount(~m & (m << 1)) < bitwidth - 1. But this pattern doesn't exist in real-world code. So I think it is not interesting.

I agree.

nikic commented 2 months ago

I think this issue would be solved if a) zig lowers the @intCast to trunc nuw and b) LLVM elides the zext (trunc nuw) instead of generating an and. Without the and, the shl + add nuw pattern should fold fine.

b) will be fixed when https://github.com/llvm/llvm-project/pull/88609 lands. I don't know what zig does about a), as I don't know how to get the pre-optimization IR for this compiler.

Validark commented 2 months ago

You can get Zig to emit the LLVM IR directly like so:

zig build-obj ./src/llvm_code.zig -O ReleaseFast -target x86_64-linux -mcpu znver4 -femit-llvm-ir -fstrip

Gives:

; ModuleID = 'BitcodeBuffer'
source_filename = "llvm_code"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"
target triple = "x86_64-unknown-linux-musl"

; Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(none) uwtable
define dso_local i64 @foo(i64 %0) local_unnamed_addr #0 {
  %2 = xor i64 %0, -1
  %3 = shl i64 %0, 1
  %4 = and i64 %3, %2
  %5 = tail call i64 @llvm.ctpop.i64(i64 %4), !range !0
  %6 = add nuw nsw i64 %5, 1
  %7 = and i64 %6, 63
  %notmask = shl nsw i64 -1, %7
  %8 = xor i64 %notmask, -1
  ret i64 %8
}

; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare i64 @llvm.ctpop.i64(i64) #1

; Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(none) uwtable
define dso_local i64 @bar(i64 %0) local_unnamed_addr #0 {
  %2 = xor i64 %0, -1
  %3 = shl i64 %0, 1
  %4 = and i64 %3, %2
  %5 = tail call i64 @llvm.ctpop.i64(i64 %4), !range !0
  %6 = shl i64 -2, %5
  %7 = xor i64 %6, -1
  ret i64 %7
}

; Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(none) uwtable
define dso_local i64 @food(i64 %0) local_unnamed_addr #0 {
  %2 = tail call i64 @foo(i64 %0) #0
  ret i64 %2
}

attributes #0 = { mustprogress nofree norecurse nosync nounwind willreturn memory(none) uwtable "frame-pointer"="none" "target-cpu"="znver4" "target-features"="-16bit-mode,-32bit-mode,-3dnow,-3dnowa,+64bit,+adx,+aes,+allow-light-256-bit,-amx-bf16,-amx-complex,-amx-fp16,-amx-int8,-amx-tile,+avx,-avx10.1-256,-avx10.1-512,+avx2,+avx512bf16,+avx512bitalg,+avx512bw,+avx512cd,+avx512dq,-avx512er,+avx512f,-avx512fp16,+avx512ifma,-avx512pf,+avx512vbmi,+avx512vbmi2,+avx512vl,+avx512vnni,-avx512vp2intersect,+avx512vpopcntdq,-avxifma,-avxneconvert,-avxvnni,-avxvnniint16,-avxvnniint8,+bmi,+bmi2,+branchfusion,-ccmp,-cf,-cldemote,+clflushopt,+clwb,+clzero,+cmov,-cmpccxadd,+crc32,+cx16,+cx8,-egpr,-enqcmd,-ermsb,+evex512,+f16c,-false-deps-getmant,-false-deps-lzcnt-tzcnt,-false-deps-mulc,-false-deps-mullq,-false-deps-perm,-false-deps-popcnt,-false-deps-range,-fast-11bytenop,+fast-15bytenop,-fast-7bytenop,+fast-bextr,-fast-gather,-fast-hops,+fast-lzcnt,+fast-movbe,+fast-scalar-fsqrt,+fast-scalar-shift-masks,-fast-shld-rotate,-fast-variable-crosslane-shuffle,+fast-variable-perlane-shuffle,+fast-vector-fsqrt,-fast-vector-shift-masks,-faster-shift-than-shuffle,+fma,-fma4,+fsgsbase,+fsrm,+fxsr,+gfni,-harden-sls-ijmp,-harden-sls-ret,-hreset,-idivl-to-divb,-idivq-to-divl,+invpcid,-kl,-lea-sp,-lea-uses-ag,-lvi-cfi,-lvi-load-hardening,-lwp,+lzcnt,+macrofusion,+mmx,+movbe,-movdir64b,-movdiri,+mwaitx,-ndd,-no-bypass-delay,-no-bypass-delay-blend,-no-bypass-delay-mov,-no-bypass-delay-shuffle,+nopl,-pad-short-functions,+pclmul,-pconfig,+pku,+popcnt,-ppx,-prefer-128-bit,-prefer-256-bit,-prefer-mask-registers,-prefer-movmsk-over-vtest,-prefer-no-gather,-prefer-no-scatter,-prefetchi,-prefetchwt1,+prfchw,-ptwrite,-push2pop2,-raoint,+rdpid,+rdpru,+rdrnd,+rdseed,-retpoline,-retpoline-external-thunk,-retpoline-indirect-branches,-retpoline-indirect-calls,-rtm,+sahf,+sbb-dep-breaking,-serialize,-seses,-sgx,+sha,-sha512,+shstk,-slow-3ops-lea,-slow-incdec,-slow-lea,-slow-pmaddwd,-slow-pmulld,+slow-shld,-slow-two-mem-ops,-slow-unaligned-mem-16,-slow-unaligned-mem-32,-sm3,-sm4,-soft-float,+sse,+sse2,+sse3,+sse4.1,+sse4.2,+sse4a,-sse-unaligned-mem,+ssse3,-tagged-globals,-tbm,-tsxldtrk,-tuning-fast-imm-vector-shift,-uintr,-use-glm-div-sqrt-costs,-use-slm-arith-costs,-usermsr,+vaes,+vpclmulqdq,+vzeroupper,-waitpkg,+wbnoinvd,-widekl,+x87,-xop,+xsave,+xsavec,+xsaveopt,+xsaves" }
attributes #1 = { mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) }

!0 = !{i64 0, i64 64}

LLVM Godbolt link