llvm / llvm-project

The LLVM Project is a collection of modular and reusable compiler and toolchain technologies.
http://llvm.org

[X86] Attempt to perform AVX1 v8i32 integer comparisons as v8f32 #82242

Closed: RKSimon closed this 3 months ago

RKSimon commented 8 months ago
// K is an unspecified integer constant (the IR below uses K = 3).
__v8si vcmpi(__v8si x) {
    __v8si M = (__v8si) {K,K,K,K,K,K,K,K};
    x &= M;
    return x == M;
}

__v8si vcmpf(__v8si x) {
    __v8si M = (__v8si) {K,K,K,K,K,K,K,K};
    x &= M;
    return __builtin_convertvector(x, __v8sf) == __builtin_convertvector(M, __v8sf);
}

define <8 x i32> @vcmpi(<8 x i32> %x) {
  %and = and <8 x i32> %x, <i32 3, i32 3, i32 3, i32 3, i32 3, i32 3, i32 3, i32 3>
  %cmp = icmp eq <8 x i32> %and, <i32 3, i32 3, i32 3, i32 3, i32 3, i32 3, i32 3, i32 3>
  %sext = sext <8 x i1> %cmp to <8 x i32>
  ret <8 x i32> %sext
}

define <8 x i32> @vcmpf(<8 x i32> %x) {
  %and = and <8 x i32> %x, <i32 3, i32 3, i32 3, i32 3, i32 3, i32 3, i32 3, i32 3>
  %conv = sitofp <8 x i32> %and to <8 x float>
  %cmp = fcmp oeq <8 x float> %conv, <float 3.000000e+00, float 3.000000e+00, float 3.000000e+00, float 3.000000e+00, float 3.000000e+00, float 3.000000e+00, float 3.000000e+00, float 3.000000e+00>
  %sext = sext <8 x i1> %cmp to <8 x i32>
  ret <8 x i32> %sext
}

AVX1 targets have no 256-bit integer compare, so they have to split the vectors, compare the 128-bit halves and concatenate the results. If we can confirm that the integer values are exactly representable as floats, it can be worth converting to float and performing a single 256-bit fp comparison instead (see the sketch below).
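Roughly, in intrinsics terms, the difference looks like the sketch below (illustrative only, not actual compiler output; the function names are made up):

#include <immintrin.h>

// Sketch only: AVX1 has no 256-bit vpcmpeqd, so an <8 x i32> equality compare
// has to be done as two 128-bit halves that are recombined afterwards ...
static __m256i cmp_eq_v8i32_avx1(__m256i a, __m256i b) {
    __m128i lo = _mm_cmpeq_epi32(_mm256_castsi256_si128(a),
                                 _mm256_castsi256_si128(b));
    __m128i hi = _mm_cmpeq_epi32(_mm256_extractf128_si256(a, 1),
                                 _mm256_extractf128_si256(b, 1));
    return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), hi, 1);
}

// ... whereas the <8 x float> equality compare stays a single 256-bit vcmpps.
static __m256 cmp_eq_v8f32_avx1(__m256 a, __m256 b) {
    return _mm256_cmp_ps(a, b, _CMP_EQ_OQ);
}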

llvm-mca: https://llvm.godbolt.org/z/5Tns64zGP

This came up in some sincos implementations where most of the code is 256-bit fp, but we were having to convert to integer to compute the quadrant (0-3) special cases.

Inverse of #82241

goldsteinn commented 8 months ago

I think the constraint is a bit stronger than "the integer values are exactly representable...":

define i1 @src(i32 %x, i32 %C) {
;  %Cfp = sitof
;  %C_lemma = call i1 @llvm.is.fpclass.f32(float %C, i32 328)
;  call void @llvm.assume(i1 %C_lemma)
;  %CI32 = bitcast float %C to i32
  %C_lemma = icmp ult i32 %C, 2147483648
  call void @llvm.assume(i1 %C_lemma)

  %CFp = sitofp i32 %C to float
  %CI32 = bitcast float %CFp to i32
  %eq = icmp eq i32 %C, %CI32
  call void @llvm.assume(i1 %eq)

  %cmp = icmp eq i32 %x, %C
  ret i1 %cmp
}

define i1 @tgt(i32 %x, i32 %C) {
;  %C_lemma = call i1 @llvm.is.fpclass.f32(float %C, i32 328)
;  call void @llvm.assume(i1 %C_lemma)
  %C_lemma = icmp ult i32 %C, 2147483648
  call void @llvm.assume(i1 %C_lemma)

  %CFp = sitofp i32 %C to float
  %CI32 = bitcast float %CFp to i32
  %eq = icmp eq i32 %C, %CI32
  call void @llvm.assume(i1 %eq)

  %conv = sitofp i32 %x to float
  %cmp = fcmp oeq float %conv, %CFp
  ret i1 %cmp
}

Fails (timeout online).

OTOH, replacing the 2147483648 bound with 1073741824 passes. Trying to work out the exact constraint.

goldsteinn commented 8 months ago

The counterexample is:

i32 %x = #x4e9d3aa8 (1318927016)
i32 %C = #x4e9d3a75 (1318926965)

but

define i1 @foo() {
  %lhs = bitcast i32 u0x4e9d3aa8 to float
  %rhs = bitcast i32 u0x4e9d3a75 to float
  %cmp = fcmp oeq float %lhs, %rhs
  ret i1 %cmp
}

evaluates to false. Maybe an alive2 bug?
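FWIW, converting (rather than bitcasting) those two values does land on the same float, which matches the counterexample. A quick standalone check (not from the thread):

#include <stdio.h>
#include <stdint.h>

int main(void) {
    int32_t x = 1318927016;   /* 0x4e9d3aa8 */
    int32_t C = 1318926965;   /* 0x4e9d3a75 */
    /* Near 2^30 the spacing between adjacent floats is 128, so two i32
       values only 51 apart can round to the same float. */
    printf("icmp eq:  %d\n", x == C);                 /* 0 */
    printf("fcmp oeq: %d\n", (float)x == (float)C);   /* 1: both round to 1318926976.0f */
    return 0;
}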

goldsteinn commented 8 months ago

The following seem to be the best bounds I can find for i32 -> float (in each src/tgt pair, the assumes require at least one operand to have absolute value below 16777216 = 2^24; the unsigned predicates additionally require both operands to be non-negative):

define i1 @src_eq(i32 %x, i32 %C) {
  %C_abs = call i32 @llvm.abs.i32(i32 %C, i1 false)
  %X_abs = call i32 @llvm.abs.i32(i32 %x, i1 false)
  %C_lemma = icmp ult i32 %C_abs, 16777216
  %X_lemma = icmp ult i32 %X_abs, 16777216
  %lemma = or i1 %C_lemma, %X_lemma
  call void @llvm.assume(i1 %lemma)

  %cmp = icmp eq i32 %x, %C
  ret i1 %cmp
}

define i1 @tgt_eq(i32 %x, i32 %C) {
  %C_abs = call i32 @llvm.abs.i32(i32 %C, i1 false)
  %X_abs = call i32 @llvm.abs.i32(i32 %x, i1 false)
  %C_lemma = icmp ult i32 %C_abs, 16777216
  %X_lemma = icmp ult i32 %X_abs, 16777216
  %lemma = or i1 %C_lemma, %X_lemma
  call void @llvm.assume(i1 %lemma)

  %CFp = sitofp i32 %C to float
  %conv = sitofp i32 %x to float
  %cmp = fcmp oeq float %conv, %CFp
  ret i1 %cmp
}

define i1 @src_ne(i32 %x, i32 %C) {
  %C_abs = call i32 @llvm.abs.i32(i32 %C, i1 false)
  %X_abs = call i32 @llvm.abs.i32(i32 %x, i1 false)
  %C_lemma = icmp ult i32 %C_abs, 16777216
  %X_lemma = icmp ult i32 %X_abs, 16777216
  %lemma = or i1 %C_lemma, %X_lemma
  call void @llvm.assume(i1 %lemma)

  %cmp = icmp ne i32 %x, %C
  ret i1 %cmp
}

define i1 @tgt_ne(i32 %x, i32 %C) {
  %C_abs = call i32 @llvm.abs.i32(i32 %C, i1 false)
  %X_abs = call i32 @llvm.abs.i32(i32 %x, i1 false)
  %C_lemma = icmp ult i32 %C_abs, 16777216
  %X_lemma = icmp ult i32 %X_abs, 16777216
  %lemma = or i1 %C_lemma, %X_lemma
  call void @llvm.assume(i1 %lemma)

  %CFp = sitofp i32 %C to float
  %conv = sitofp i32 %x to float
  %cmp = fcmp one float %conv, %CFp
  ret i1 %cmp
}

define i1 @src_slt(i32 %x, i32 %C) {
  %C_abs = call i32 @llvm.abs.i32(i32 %C, i1 false)
  %X_abs = call i32 @llvm.abs.i32(i32 %x, i1 false)
  %C_lemma = icmp ult i32 %C_abs, 16777216
  %X_lemma = icmp ult i32 %X_abs, 16777216
  %lemma = or i1 %C_lemma, %X_lemma
  call void @llvm.assume(i1 %lemma)

  %cmp = icmp slt i32 %x, %C
  ret i1 %cmp
}

define i1 @tgt_slt(i32 %x, i32 %C) {
  %C_abs = call i32 @llvm.abs.i32(i32 %C, i1 false)
  %X_abs = call i32 @llvm.abs.i32(i32 %x, i1 false)
  %C_lemma = icmp ult i32 %C_abs, 16777216
  %X_lemma = icmp ult i32 %X_abs, 16777216
  %lemma = or i1 %C_lemma, %X_lemma
  call void @llvm.assume(i1 %lemma)

  %CFp = sitofp i32 %C to float
  %conv = sitofp i32 %x to float
  %cmp = fcmp olt float %conv, %CFp
  ret i1 %cmp
}

define i1 @src_sgt(i32 %x, i32 %C) {
  %C_abs = call i32 @llvm.abs.i32(i32 %C, i1 false)
  %X_abs = call i32 @llvm.abs.i32(i32 %x, i1 false)
  %C_lemma = icmp ult i32 %C_abs, 16777216
  %X_lemma = icmp ult i32 %X_abs, 16777216
  %lemma = or i1 %C_lemma, %X_lemma
  call void @llvm.assume(i1 %lemma)

  %cmp = icmp sgt i32 %x, %C
  ret i1 %cmp
}

define i1 @tgt_sgt(i32 %x, i32 %C) {
  %C_abs = call i32 @llvm.abs.i32(i32 %C, i1 false)
  %X_abs = call i32 @llvm.abs.i32(i32 %x, i1 false)
  %C_lemma = icmp ult i32 %C_abs, 16777216
  %X_lemma = icmp ult i32 %X_abs, 16777216
  %lemma = or i1 %C_lemma, %X_lemma
  call void @llvm.assume(i1 %lemma)

  %CFp = sitofp i32 %C to float
  %conv = sitofp i32 %x to float
  %cmp = fcmp ogt float %conv, %CFp
  ret i1 %cmp
}

define i1 @src_sle(i32 %x, i32 %C) {
  %C_abs = call i32 @llvm.abs.i32(i32 %C, i1 false)
  %X_abs = call i32 @llvm.abs.i32(i32 %x, i1 false)
  %C_lemma = icmp ult i32 %C_abs, 16777216
  %X_lemma = icmp ult i32 %X_abs, 16777216
  %lemma = or i1 %C_lemma, %X_lemma
  call void @llvm.assume(i1 %lemma)

  %cmp = icmp sle i32 %x, %C
  ret i1 %cmp
}

define i1 @tgt_sle(i32 %x, i32 %C) {
  %C_abs = call i32 @llvm.abs.i32(i32 %C, i1 false)
  %X_abs = call i32 @llvm.abs.i32(i32 %x, i1 false)
  %C_lemma = icmp ult i32 %C_abs, 16777216
  %X_lemma = icmp ult i32 %X_abs, 16777216
  %lemma = or i1 %C_lemma, %X_lemma
  call void @llvm.assume(i1 %lemma)

  %CFp = sitofp i32 %C to float
  %conv = sitofp i32 %x to float
  %cmp = fcmp ole float %conv, %CFp
  ret i1 %cmp
}

define i1 @src_sge(i32 %x, i32 %C) {
  %C_abs = call i32 @llvm.abs.i32(i32 %C, i1 false)
  %X_abs = call i32 @llvm.abs.i32(i32 %x, i1 false)
  %C_lemma = icmp ult i32 %C_abs, 16777216
  %X_lemma = icmp ult i32 %X_abs, 16777216
  %lemma = or i1 %C_lemma, %X_lemma
  call void @llvm.assume(i1 %lemma)

  %cmp = icmp sge i32 %x, %C
  ret i1 %cmp
}

define i1 @tgt_sge(i32 %x, i32 %C) {
  %C_abs = call i32 @llvm.abs.i32(i32 %C, i1 false)
  %X_abs = call i32 @llvm.abs.i32(i32 %x, i1 false)
  %C_lemma = icmp ult i32 %C_abs, 16777216
  %X_lemma = icmp ult i32 %X_abs, 16777216
  %lemma = or i1 %C_lemma, %X_lemma
  call void @llvm.assume(i1 %lemma)

  %CFp = sitofp i32 %C to float
  %conv = sitofp i32 %x to float
  %cmp = fcmp oge float %conv, %CFp
  ret i1 %cmp
}

define i1 @src_ult(i32 %x, i32 %C) {
  %C_lemma0 = icmp ult i32 %C, 16777216
  %X_lemma0 = icmp ult i32 %x, 16777216
  %lemma0 = or i1 %C_lemma0, %X_lemma0
  call void @llvm.assume(i1 %lemma0)

  %C_lemma1 = icmp sge i32 %C, 0
  %X_lemma1 = icmp sge i32 %x, 0
  %lemma1 = and i1 %C_lemma1, %X_lemma1
  call void @llvm.assume(i1 %lemma1)

  %cmp = icmp ult i32 %x, %C
  ret i1 %cmp
}

define i1 @tgt_ult(i32 %x, i32 %C) {
  %C_lemma0 = icmp ult i32 %C, 16777216
  %X_lemma0 = icmp ult i32 %x, 16777216
  %lemma0 = or i1 %C_lemma0, %X_lemma0
  call void @llvm.assume(i1 %lemma0)

  %C_lemma1 = icmp sge i32 %C, 0
  %X_lemma1 = icmp sge i32 %x, 0
  %lemma1 = and i1 %C_lemma1, %X_lemma1
  call void @llvm.assume(i1 %lemma1)

  %CFp = sitofp i32 %C to float
  %conv = sitofp i32 %x to float
  %cmp = fcmp olt float %conv, %CFp
  ret i1 %cmp
}

define i1 @src_ule(i32 %x, i32 %C) {
  %C_lemma0 = icmp ult i32 %C, 16777216
  %X_lemma0 = icmp ult i32 %x, 16777216
  %lemma0 = or i1 %C_lemma0, %X_lemma0
  call void @llvm.assume(i1 %lemma0)

  %C_lemma1 = icmp sge i32 %C, 0
  %X_lemma1 = icmp sge i32 %x, 0
  %lemma1 = and i1 %C_lemma1, %X_lemma1
  call void @llvm.assume(i1 %lemma1)

  %cmp = icmp ule i32 %x, %C
  ret i1 %cmp
}

define i1 @tgt_ule(i32 %x, i32 %C) {
  %C_lemma0 = icmp ult i32 %C, 16777216
  %X_lemma0 = icmp ult i32 %x, 16777216
  %lemma0 = or i1 %C_lemma0, %X_lemma0
  call void @llvm.assume(i1 %lemma0)

  %C_lemma1 = icmp sge i32 %C, 0
  %X_lemma1 = icmp sge i32 %x, 0
  %lemma1 = and i1 %C_lemma1, %X_lemma1
  call void @llvm.assume(i1 %lemma1)

  %CFp = sitofp i32 %C to float
  %conv = sitofp i32 %x to float
  %cmp = fcmp ole float %conv, %CFp
  ret i1 %cmp
}

define i1 @src_ugt(i32 %x, i32 %C) {
  %C_lemma0 = icmp ult i32 %C, 16777216
  %X_lemma0 = icmp ult i32 %x, 16777216
  %lemma0 = or i1 %C_lemma0, %X_lemma0
  call void @llvm.assume(i1 %lemma0)

  %C_lemma1 = icmp sge i32 %C, 0
  %X_lemma1 = icmp sge i32 %x, 0
  %lemma1 = and i1 %C_lemma1, %X_lemma1
  call void @llvm.assume(i1 %lemma1)

  %cmp = icmp ugt i32 %x, %C
  ret i1 %cmp
}

define i1 @tgt_ugt(i32 %x, i32 %C) {
  %C_lemma0 = icmp ult i32 %C, 16777216
  %X_lemma0 = icmp ult i32 %x, 16777216
  %lemma0 = or i1 %C_lemma0, %X_lemma0
  call void @llvm.assume(i1 %lemma0)

  %C_lemma1 = icmp sge i32 %C, 0
  %X_lemma1 = icmp sge i32 %x, 0
  %lemma1 = and i1 %C_lemma1, %X_lemma1
  call void @llvm.assume(i1 %lemma1)

  %CFp = sitofp i32 %C to float
  %conv = sitofp i32 %x to float
  %cmp = fcmp ogt float %conv, %CFp
  ret i1 %cmp
}

define i1 @src_uge(i32 %x, i32 %C) {
  %C_lemma0 = icmp ult i32 %C, 16777216
  %X_lemma0 = icmp ult i32 %x, 16777216
  %lemma0 = or i1 %C_lemma0, %X_lemma0
  call void @llvm.assume(i1 %lemma0)

  %C_lemma1 = icmp sge i32 %C, 0
  %X_lemma1 = icmp sge i32 %x, 0
  %lemma1 = and i1 %C_lemma1, %X_lemma1
  call void @llvm.assume(i1 %lemma1)

  %cmp = icmp uge i32 %x, %C
  ret i1 %cmp
}

define i1 @tgt_uge(i32 %x, i32 %C) {
  %C_lemma0 = icmp ult i32 %C, 16777216
  %X_lemma0 = icmp ult i32 %x, 16777216
  %lemma0 = or i1 %C_lemma0, %X_lemma0
  call void @llvm.assume(i1 %lemma0)

  %C_lemma1 = icmp sge i32 %C, 0
  %X_lemma1 = icmp sge i32 %x, 0
  %lemma1 = and i1 %C_lemma1, %X_lemma1
  call void @llvm.assume(i1 %lemma1)

  %CFp = sitofp i32 %C to float
  %conv = sitofp i32 %x to float
  %cmp = fcmp oge float %conv, %CFp
  ret i1 %cmp
}
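
Two quick standalone checks (not part of the proofs above) of why the assumes take this shape: 16777216 is 2^24, the point past which distinct i32 values can round to the same float, and the unsigned predicates need non-negative operands because sitofp treats its input as signed:

#include <stdio.h>
#include <stdint.h>

int main(void) {
    /* float has a 24-bit significand, so every integer with magnitude up to
       2^24 converts exactly, while above that neighbours start to collide. */
    int32_t a = 16777217, b = 16777216, c = 16777215;
    printf("%d\n", (float)a == (float)b);   /* 1: 2^24 + 1 rounds down to 2^24 */
    printf("%d\n", (float)c == (float)b);   /* 0: 2^24 - 1 is still exact */

    /* For the unsigned predicates: as an unsigned value -1 is 0xFFFFFFFF and
       is not below 5, but sitofp turns it into -1.0f, which is below 5.0f. */
    int32_t x = -1, C = 5;
    printf("%d\n", (uint32_t)x < (uint32_t)C);   /* 0 */
    printf("%d\n", (float)x < (float)C);         /* 1 */
    return 0;
}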
goldsteinn commented 8 months ago

Have a preliminary patch, but I noticed an issue with constant folding of casted floats, so I'm going to post a patch for that first.

anematode commented 8 months ago

@goldsteinn I might be wrong, but I think the bitcast method should only be enabled if you can guarantee the original integers are between 1 << 23 (the smallest normal number's bit pattern) and 0x7f800000. If the DAZ bit is set (say, due to -ffast-math), you'll get (int)1 == (int)2, since you're comparing two denormals and they are treated as zero.

In the case of checking for inequality, I think you can do this for f32 (Intel syntax):

vxorps ymm0, a, b ; a 32-bit word is zero iff the elements are bitwise equal
vcvtdq2ps ymm1, ymm0 ; treat as int and convert to float
vcmpps ymm0, ymm1, zero, _CMP_EQ_OQ ; will only equal zero if the original elements are bitwise equal

No denormals there. I don't think you can do anything similar with f64 though.
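
A rough intrinsics rendering of that sequence (my reading of it, untested; the function name is made up):

#include <immintrin.h>

/* The xor is zero exactly when a lane's elements are bitwise equal, and
   cvtdq2ps maps zero (and only zero) to 0.0f, so comparing the converted
   value against zero recovers the per-lane equality mask. */
static __m256 eq_mask_avx1(__m256i a, __m256i b) {
    __m256 x = _mm256_xor_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b));
    __m256 f = _mm256_cvtepi32_ps(_mm256_castps_si256(x));
    return _mm256_cmp_ps(f, _mm256_setzero_ps(), _CMP_EQ_OQ);
}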

E.g., this program has different comparison outputs depending on whether DAZ is set: https://godbolt.org/z/an9rYbj45
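
A minimal standalone sketch of that kind of test (not the linked program), assuming an SSE build (e.g. gcc -msse3) so the compare respects MXCSR:

#include <stdio.h>
#include <string.h>
#include <stdint.h>
#include <pmmintrin.h>   /* _MM_SET_DENORMALS_ZERO_MODE */

/* The bit patterns of the integers 1 and 2, reinterpreted as floats, are two
   distinct denormals; with DAZ enabled both are treated as zero, so a
   bitcast-based compare reports them equal. */
static float bitcast_to_float(int32_t i) {
    float f;
    memcpy(&f, &i, sizeof f);
    return f;
}

int main(void) {
    volatile float a = bitcast_to_float(1);
    volatile float b = bitcast_to_float(2);
    printf("DAZ off: %d\n", a == b);   /* 0 */
    _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);
    printf("DAZ on:  %d\n", a == b);   /* 1 */
    return 0;
}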

RKSimon commented 4 months ago

It's proving tricky to handle the regressions on #82290 - I'll see if I can devise a more moderate alternative / first step.