llvm / llvm-project

The LLVM Project is a collection of modular and reusable compiler and toolchain technologies.
http://llvm.org

[DAG] Failure to fold select(x, sub(x, c), m) -> sub(x, and(c,m)) #66101

Closed RKSimon closed 8 months ago

RKSimon commented 1 year ago

https://godbolt.org/z/a1PczEM8a

If we're selecting between a value and a subtraction of a non-constant from it, we fold the select into an AND:

#include <x86intrin.h>
auto masked_select(__m128i a, __m128i b, __m128i x, __m128i y) {
    return _mm_blendv_epi8(a, _mm_sub_epi32(a, b), _mm_cmpgt_epi32(x,y));
}
masked_select(long long __vector(2), long long __vector(2), long long __vector(2), long long __vector(2)): # @masked_select(long long __vector(2), long long __vector(2), long long __vector(2), long long __vector(2))
  pcmpgtd %xmm3, %xmm2
  pand %xmm1, %xmm2
  psubd %xmm2, %xmm0
  retq

But for constants this fold fails, which on x86 can result in a BLENDV instruction, which is never faster than an AND:

#include <x86intrin.h>
auto masked_select_const(__m128i a, __m128i x, __m128i y) {
    __m128i b = _mm_set1_epi32(24);
    return _mm_blendv_epi8(a, _mm_sub_epi32(a, b), _mm_cmpgt_epi32(x,y));
}
masked_select_const(long long __vector(2), long long __vector(2), long long __vector(2)): # @masked_select_const(long long __vector(2), long long __vector(2), long long __vector(2))
  movdqa %xmm0, %xmm3
  movdqa .LCPI3_0(%rip), %xmm4 # xmm4 = [4294967272,4294967272,4294967272,4294967272]
  paddd %xmm0, %xmm4
  pcmpgtd %xmm2, %xmm1
  movdqa %xmm1, %xmm0
  blendvps %xmm0, %xmm4, %xmm3
  movaps %xmm3, %xmm0
  retq
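
The requested fold relies on the compare result being all-ones or all-zeros in each lane, as pcmpgtd produces. A minimal scalar sketch of that identity for one 32-bit lane follows; the helper names select_sub, folded_sub and forms_agree are hypothetical and only serve this illustration, they are not LLVM functions.

```cpp
#include <cstdint>

// Reference form: pick (x - c) when the mask lane is set, otherwise x.
// Assumption: mask is either 0xFFFFFFFF or 0, as a vector compare yields.
inline uint32_t select_sub(uint32_t mask, uint32_t x, uint32_t c) {
    return mask ? x - c : x;
}

// Folded form from the issue title: sub(x, and(c, mask)).
// With mask == 0 the AND yields 0, so x is returned unchanged.
inline uint32_t folded_sub(uint32_t mask, uint32_t x, uint32_t c) {
    return x - (c & mask);
}

// Returns true if both forms agree on a handful of sample values.
inline bool forms_agree() {
    const uint32_t masks[] = {0u, 0xFFFFFFFFu};
    const uint32_t xs[] = {0u, 1u, 23u, 24u, 100u, 0x80000000u, 0xFFFFFFFFu};
    const uint32_t c = 24u;  // same constant as masked_select_const
    for (uint32_t m : masks)
        for (uint32_t x : xs)
            if (select_sub(m, x, c) != folded_sub(m, x, c))
                return false;
    return true;
}
```

The sketch only covers masks that are fully set or fully clear; for any other mask value the two forms can differ.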

llvmbot commented 1 year ago

@llvm/issue-subscribers-backend-x86

RKSimon commented 1 year ago

CC @elhewaty

elhewaty commented 1 year ago

Assign me, please.

elhewaty commented 1 year ago

@RKSimon Is there any source I can use to understand DAG internals?

RKSimon commented 1 year ago

I'd start by seeing what the difference is between the IR being fed to DAG from masked_select vs masked_select_const - you will probably need to remove a lot of unnecessary bitcasts. Then step through the DAGCombine stages of running llc in a debugger - add breakpoints to the start of visitADD/visitSUB/visitVSELECT and see what's happening.

You can also use "llc --debug" (using a debug assertion build) to dump out everything llc has done: https://rust.godbolt.org/z/szYv5G8n9

elhewaty commented 9 months ago

Hello @RKSimon.

// select X, sub(X, C), m --> sub (X, and(C, m))
  if (N1.getOpcode() == ISD::SUB && N1.getOperand(0) == N0 && N1.hasOneUse()) {
    if (dyn_cast<ConstantSDNode>(N1.getOperand(1)))
      return DAG.getNode(ISD::SUB, DL, N0.getValueType(), N0,
                         DAG.getNode(ISD::AND, DL, N2.getValueType(),
                                     N1.getOperand(1), N2));
  }

Here's what I've reached so far: I tried to match the pattern in the visitSELECT function. Is this logic correct?

RKSimon commented 9 months ago

Yes, that looks about right - you should use isConstantIntBuildVectorOrConstantInt instead of dyn_cast<ConstantSDNode> so it can match vector constants as well.

RKSimon commented 9 months ago

Also, you need to sort out the argument order (sorry, when I reported this I was thinking of _mm_blendv_epi8 order, not select IR order).
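
The two operand orders can be compared with a scalar model of a single byte lane. This is only an illustrative sketch: blendv_lane and select_lane are hypothetical helpers, not intrinsics or LLVM APIs, and they model one lane rather than a full vector.

```cpp
#include <cstdint>

// Model of one _mm_blendv_epi8 byte lane: operands are (a, b, mask),
// and b is chosen when the top bit of the mask byte is set.
inline uint8_t blendv_lane(uint8_t a, uint8_t b, uint8_t mask) {
    return (mask & 0x80u) ? b : a;
}

// Model of IR select: operands are (cond, true_val, false_val),
// and true_val is chosen when cond holds.
inline uint8_t select_lane(bool cond, uint8_t true_val, uint8_t false_val) {
    return cond ? true_val : false_val;
}

// blendv(a, b, m) corresponds to select(m, b, a): the condition moves
// to the front and the two value operands swap places.
inline bool same_result(uint8_t a, uint8_t b, uint8_t mask) {
    return blendv_lane(a, b, mask) == select_lane((mask & 0x80u) != 0, b, a);
}
```

In these terms, the pattern from the issue title reads select(m, sub(x, c), x) once written in select IR order.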

elhewaty commented 9 months ago

@RKSimon, I used the following test case:

define <2 x i64> @masked_select_const(<2 x i64> %a, <2 x i64> %x, <2 x i64> %y) {
  %bit_a = bitcast <2 x i64> %a to <4 x i32>
  %sub.i = add <4 x i32> %bit_a, <i32 -24, i32 -24, i32 -24, i32 -24>
  %bit_x = bitcast <2 x i64> %x to <4 x i32>
  %bit_y = bitcast <2 x i64> %y to <4 x i32>
  %cmp.i = icmp sgt <4 x i32> %bit_x, %bit_y
  %sel = select <4 x i1> %cmp.i, <4 x i32> %sub.i, <4 x i32> %bit_a
  %bit_sel = bitcast <4 x i32> %sel to <2 x i64>
  ret <2 x i64> %bit_sel
}

The following code doesn't match the select:

// select m, sub(X, C), X --> sub (X, and(C, m))
  if (N1.getOpcode() == ISD::SUB && N1.getOperand(0) == N2 && N1->hasOneUse() &&
      DAG.isConstantIntBuildVectorOrConstantInt(N1.getOperand(1))) {
    return DAG.getNode(ISD::SUB, DL, N1.getValueType(), N2,
                       DAG.getNode(ISD::AND, DL, N0.getValueType(), N1.getOperand(1),
                                   N0));
  }

Any hint?

elhewaty commented 9 months ago

@RKSimon ping

RKSimon commented 8 months ago

Sorry I missed your ping.

In many cases DAG will try to fold (sub x, c) -> (add x, -c) so you will need to do this in terms of ADD:

  // select (sext m), (add X, C), X --> (add X, (and C, (sext m)))
  if (N1.getOpcode() == ISD::ADD && N1.getOperand(0) == N2 && N1->hasOneUse() &&
      DAG.isConstantIntBuildVectorOrConstantInt(N1.getOperand(1)) &&
      N0.getScalarValueSizeInBits() == N1.getScalarValueSizeInBits()) {
    return DAG.getNode(ISD::ADD, DL, N1.getValueType(), N2,
                       DAG.getNode(ISD::AND, DL, N0.getValueType(), N1.getOperand(1),
                                   N0));
  }

Note that you need to ensure the N0 condition is the same width as the True/False operands, otherwise you might affect targets with predicate mask types (AVX512 etc).
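
The sub-to-add rewrite is visible in the masked_select_const output above, where the constant 24 appears as 4294967272, i.e. -24 in 32-bit two's complement. The same masking identity holds for the ADD form, as this scalar sketch suggests; negate_u32 and folded_add are hypothetical helper names used only for illustration, and the mask is assumed to be all-ones or all-zeros with the same width as the value.

```cpp
#include <cstdint>

// Two's complement negation of a 32-bit constant, modelling the
// (sub x, c) -> (add x, -c) rewrite on one lane.
inline uint32_t negate_u32(uint32_t c) {
    return 0u - c;
}

// ADD form of the fold: add(x, and(neg_c, mask)).
// Assumption: mask is 0xFFFFFFFF or 0, i.e. a sign-extended condition
// that is exactly as wide as x.
inline uint32_t folded_add(uint32_t mask, uint32_t x, uint32_t neg_c) {
    return x + (neg_c & mask);
}
```

For example, folded_add(0xFFFFFFFF, 100, negate_u32(24)) yields 76, while a zero mask leaves 100 unchanged. The AND is only meaningful when the mask has one full-width element per data element, which is what the size check in the snippet above guards.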

RKSimon commented 8 months ago

@elhewaty Do you have a PR (draft or active) anywhere with your work so far?