[Isel x86] i320+ generates useless extra instruction when chaining icmp and select based on underflow #103841

Open mratsim opened 4 weeks ago

mratsim commented 4 weeks ago

This is an alternative implementation of LLVM modular addition from https://github.com/llvm/llvm-project/issues/103717 that uses raw LLVM IR instead of the builtin `llvm.usub.with.overflow.iXXX`.

The code for i256 is optimal but not for i320 or i384 (similar to the previous issue, there seems to be a size threshold after which LLVM gives up removing redundant instructions).

https://alive2.llvm.org/ce/z/g_nP8g


Full code

Original IR

; ModuleID = 'x86_poc'
; target triple = "arm64"
target triple = "x86_64"

; Field moduli, one constant per integer width (i256, i384, i320). Each one is
; passed as the modulus pointer (%3) to the @_modadd_noo.u64xN helper of the
; same width by the corresponding public *_fp_add wrapper.
@bn254_snarks_fp_mod = constant i256 21888242871839275222246405745257275088696311157297823662689037894645226208583, section "ctt.bn254_snarks_fp.constants", align 64
@bls12_381_fp_mod = constant i384 4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559787, section "ctt.bls12_381_fp.constants", align 64
@bls24_317_fp_mod = constant i320 136393071104295911515099765908274057061945112121419593977210139303905973197232025618026156731051, section "ctt.bls24_317_fp.constants", align 64

; Function Attrs: hot
; Modular addition on i256 values: stores (a + b) - M to %0 when bit 255 of
; that difference is clear, and a + b otherwise.
;   %0: destination pointer (written)
;   %1: pointer to operand a (read)
;   %2: pointer to operand b (read)
;   %3: pointer to the modulus M (read)
; NOTE(review): correctness presumably relies on a + b not wrapping i256 and on
; the sign bit of (a + b) - M being set exactly when a + b < M (reduced inputs,
; modulus with spare top bits) -- confirm against the callers.
define internal fastcc void @_modadd_noo.u64x4(ptr %0, ptr %1, ptr %2, ptr %3) #2 section "ctt.fields" {
  %a = load i256, ptr %1, align 4
  %b = load i256, ptr %2, align 4
  %M = load i256, ptr %3, align 4
  %a_plus_b = add i256 %a, %b
  %5 = sub i256 %a_plus_b, %M
  ; Extract the top (sign) bit of the difference as an i1.
  %6 = lshr i256 %5, 255
  %7 = trunc i256 %6 to i1
  ; Sign bit set -> keep the unreduced sum; otherwise keep the difference.
  %8 = select i1 %7, i256 %a_plus_b, i256 %5
  store i256 %8, ptr %0, align 4
  ret void
}

; Function Attrs: hot
; Modular addition on i320 values: stores (a + b) - M to %0 when bit 319 of
; that difference is clear, and a + b otherwise.
;   %0: destination pointer (written)
;   %1: pointer to operand a (read)
;   %2: pointer to operand b (read)
;   %3: pointer to the modulus M (read)
; NOTE(review): correctness presumably relies on a + b not wrapping i320 and on
; the sign bit of (a + b) - M being set exactly when a + b < M (reduced inputs,
; modulus with spare top bits) -- confirm against the callers.
define internal fastcc void @_modadd_noo.u64x5(ptr %0, ptr %1, ptr %2, ptr %3) #2 section "ctt.fields" {
  %a = load i320, ptr %1, align 4
  %b = load i320, ptr %2, align 4
  %M = load i320, ptr %3, align 4
  %a_plus_b = add i320 %a, %b
  %5 = sub i320 %a_plus_b, %M
  ; Extract the top (sign) bit of the difference as an i1.
  %6 = lshr i320 %5, 319
  %7 = trunc i320 %6 to i1
  ; Sign bit set -> keep the unreduced sum; otherwise keep the difference.
  %8 = select i1 %7, i320 %a_plus_b, i320 %5
  store i320 %8, ptr %0, align 4
  ret void
}

; Function Attrs: hot
; Modular addition on i384 values: stores (a + b) - M to %0 when bit 383 of
; that difference is clear, and a + b otherwise.
;   %0: destination pointer (written)
;   %1: pointer to operand a (read)
;   %2: pointer to operand b (read)
;   %3: pointer to the modulus M (read)
; NOTE(review): correctness presumably relies on a + b not wrapping i384 and on
; the sign bit of (a + b) - M being set exactly when a + b < M (reduced inputs,
; modulus with spare top bits) -- confirm against the callers.
define internal fastcc void @_modadd_noo.u64x6(ptr %0, ptr %1, ptr %2, ptr %3) #2 section "ctt.fields" {
  %a = load i384, ptr %1, align 4
  %b = load i384, ptr %2, align 4
  %M = load i384, ptr %3, align 4
  %a_plus_b = add i384 %a, %b
  %5 = sub i384 %a_plus_b, %M
  ; Extract the top (sign) bit of the difference as an i1.
  %6 = lshr i384 %5, 383
  %7 = trunc i384 %6 to i1
  ; Sign bit set -> keep the unreduced sum; otherwise keep the difference.
  %8 = select i1 %7, i384 %a_plus_b, i384 %5
  store i384 %8, ptr %0, align 4
  ret void
}

; Function Attrs: hot
; Public entry point: forwards %0 (destination), %1 and %2 (operands) to the
; i256 helper, supplying @bn254_snarks_fp_mod as the modulus.
define void @bn254_snarks_fp_add(ptr %0, ptr %1, ptr %2) #2 section "ctt.bn254_snarks_fp" {
  call fastcc void @_modadd_noo.u64x4(ptr %0, ptr %1, ptr %2, ptr @bn254_snarks_fp_mod)
  ret void
}

; Function Attrs: hot
; Public entry point: forwards %0 (destination), %1 and %2 (operands) to the
; i320 helper, supplying @bls24_317_fp_mod as the modulus.
define void @bls24_317_fp_add(ptr %0, ptr %1, ptr %2) #2 section "ctt.bls24_317_fp" {
  call fastcc void @_modadd_noo.u64x5(ptr %0, ptr %1, ptr %2, ptr @bls24_317_fp_mod)
  ret void
}

; Function Attrs: hot
; Public entry point: forwards %0 (destination), %1 and %2 (operands) to the
; i384 helper, supplying @bls12_381_fp_mod as the modulus.
define void @bls12_381_fp_add(ptr %0, ptr %1, ptr %2) #2 section "ctt.bls12_381_fp" {
  call fastcc void @_modadd_noo.u64x6(ptr %0, ptr %1, ptr %2, ptr @bls12_381_fp_mod)
  ret void
}

; Attribute group #2: referenced by every function defined in this module.
attributes #2 = { hot }

After opt -O3

; target triple = "arm64"
; Active target: x86_64 (the arm64 triple above is left commented out).
target triple = "x86_64"

; bn254_snarks_fp_add after opt -O3: the i256 helper no longer appears as a
; call, and the modulus load has become an add of the constant -M.
define void @bn254_snarks_fp_add(ptr nocapture writeonly %0, ptr nocapture readonly %1, ptr nocapture readonly %2) local_unnamed_addr #0 section "ctt.bn254_snarks_fp" {
  %.val = load i256, ptr %1, align 4
  %.val1 = load i256, ptr %2, align 4
  %a_plus_b.i = add i256 %.val1, %.val
  ; (a + b) - M, written as an add of the negated modulus constant.
  %4 = add i256 %a_plus_b.i, -21888242871839275222246405745257275088696311157297823662689037894645226208583
  ; The lshr/trunc sign-bit extraction of the original IR is now a signed
  ; compare against zero.
  %.not1.i = icmp slt i256 %4, 0
  %5 = select i1 %.not1.i, i256 %a_plus_b.i, i256 %4
  store i256 %5, ptr %0, align 4
  ret void
}

; bls24_317_fp_add after opt -O3: the i320 helper no longer appears as a call,
; and the modulus load has become an add of the constant -M.
define void @bls24_317_fp_add(ptr nocapture writeonly %0, ptr nocapture readonly %1, ptr nocapture readonly %2) local_unnamed_addr #0 section "ctt.bls24_317_fp" {
  %.val = load i320, ptr %1, align 4
  %.val1 = load i320, ptr %2, align 4
  %a_plus_b.i = add i320 %.val1, %.val
  ; (a + b) - M, written as an add of the negated modulus constant.
  %4 = add i320 %a_plus_b.i, -136393071104295911515099765908274057061945112121419593977210139303905973197232025618026156731051
  ; The lshr/trunc sign-bit extraction of the original IR is now a signed
  ; compare against zero.
  %.not1.i = icmp slt i320 %4, 0
  %5 = select i1 %.not1.i, i320 %a_plus_b.i, i320 %4
  store i320 %5, ptr %0, align 4
  ret void
}

; bls12_381_fp_add after opt -O3: the i384 helper no longer appears as a call,
; and the modulus load has become an add of the constant -M.
define void @bls12_381_fp_add(ptr nocapture writeonly %0, ptr nocapture readonly %1, ptr nocapture readonly %2) local_unnamed_addr #0 section "ctt.bls12_381_fp" {
  %.val = load i384, ptr %1, align 4
  %.val1 = load i384, ptr %2, align 4
  %a_plus_b.i = add i384 %.val1, %.val
  ; (a + b) - M, written as an add of the negated modulus constant.
  %4 = add i384 %a_plus_b.i, -4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559787
  ; The lshr/trunc sign-bit extraction of the original IR is now a signed
  ; compare against zero.
  %.not1.i = icmp slt i384 %4, 0
  %5 = select i1 %.not1.i, i384 %a_plus_b.i, i384 %4
  store i384 %5, ptr %0, align 4
  ret void
}

attributes #0 = { hot mustprogress nofree norecurse nosync nounwind willreturn memory(argmem: readwrite) }

Assembly

#-----------------------------------------------------------------------
# bn254_snarks_fp_add -- compiler output for the i256 case (Intel syntax).
# In:    rdi = destination (4 qwords written)
#        rsi, rdx = source operands (4 qwords read from each)
# Clobb: rax, rcx, rdx, rsi, r8-r11, flags
# This is the form the issue calls optimal: the cmovs sequence consumes the
# sign flag left by the final adc directly.
#-----------------------------------------------------------------------
bn254_snarks_fp_add:                    # @bn254_snarks_fp_add
        # Load one operand, then add the other limb by limb with carry.
        mov     rax, qword ptr [rdx + 24]
        mov     rcx, qword ptr [rdx + 16]
        mov     r8, qword ptr [rdx]
        mov     rdx, qword ptr [rdx + 8]
        add     r8, qword ptr [rsi]
        adc     rdx, qword ptr [rsi + 8]
        adc     rcx, qword ptr [rsi + 16]
        adc     rax, qword ptr [rsi + 24]
        # Add a 4-limb constant to the sum; presumably the limbs of -M from
        # the optimized IR (immediates not re-derived here).
        movabs  rsi, -4332616871279656263
        add     rsi, r8
        movabs  r9, 7529619929231668594
        adc     r9, rdx
        movabs  r10, 5165552122434856866
        adc     r10, rcx
        movabs  r11, -3486998266802970666
        adc     r11, rax
        # SF comes from the top-limb adc: if negative, restore the plain sum.
        cmovs   r9, rdx
        cmovs   r11, rax
        cmovs   r10, rcx
        cmovs   rsi, r8
        mov     qword ptr [rdi + 16], r10
        mov     qword ptr [rdi + 24], r11
        mov     qword ptr [rdi], rsi
        mov     qword ptr [rdi + 8], r9
        ret
#-----------------------------------------------------------------------
# bls24_317_fp_add -- compiler output for the i320 case (Intel syntax).
# In:    rdi = destination (5 qwords written)
#        rsi, rdx = source operands (5 qwords read from each)
# Saved: rbx, r14, r15 (pushed on entry, popped before return)
# Clobb: rax, rcx, rdx, rsi, r8-r11, flags
# Contains the redundant mov/sar pair this issue reports; it is kept verbatim
# because it is the evidence.
#-----------------------------------------------------------------------
bls24_317_fp_add:                       # @bls24_317_fp_add
        push    r15
        push    r14
        push    rbx
        # Load one operand, then add the other limb by limb with carry.
        mov     rax, qword ptr [rdx + 32]
        mov     r9, qword ptr [rdx + 24]
        mov     r8, qword ptr [rdx + 16]
        mov     rcx, qword ptr [rdx]
        mov     rdx, qword ptr [rdx + 8]
        add     rcx, qword ptr [rsi]
        adc     rdx, qword ptr [rsi + 8]
        adc     r8, qword ptr [rsi + 16]
        adc     r9, qword ptr [rsi + 24]
        adc     rax, qword ptr [rsi + 32]
        # Add a 5-limb constant to the sum; presumably the limbs of -M from
        # the optimized IR (immediates not re-derived here).
        movabs  rsi, 8263772892774585685
        add     rsi, rcx
        movabs  r10, 2957956877962133633
        adc     r10, rdx
        movabs  r11, -1628721857945875527
        adc     r11, r8
        movabs  rbx, 968338100789325766
        adc     rbx, r9
        movabs  r14, -1177913551803681069
        adc     r14, rax
        # Reported inefficiency: the adc above already set SF from r14's top
        # bit; this copy + arithmetic shift only recomputes the same flag.
        mov     r15, r14
        sar     r15, 63
        # If negative, restore the plain sum.
        cmovs   r10, rdx
        cmovs   rbx, r9
        cmovs   r11, r8
        cmovs   r14, rax
        cmovs   rsi, rcx
        mov     qword ptr [rdi + 32], r14
        mov     qword ptr [rdi + 16], r11
        mov     qword ptr [rdi + 24], rbx
        mov     qword ptr [rdi], rsi
        mov     qword ptr [rdi + 8], r10
        pop     rbx
        pop     r14
        pop     r15
        ret
#-----------------------------------------------------------------------
# bls12_381_fp_add -- compiler output for the i384 case (Intel syntax).
# In:    rdi = destination (6 qwords written)
#        rsi, rdx = source operands (6 qwords read from each)
# Saved: rbx, r12-r15 (pushed on entry, popped before return)
# Clobb: rax, rcx, rdx, rsi, r8-r11, flags
# Contains the redundant mov/sar pair this issue reports; it is kept verbatim
# because it is the evidence.
#-----------------------------------------------------------------------
bls12_381_fp_add:                       # @bls12_381_fp_add
        push    r15
        push    r14
        push    r13
        push    r12
        push    rbx
        # Load one operand, then add the other limb by limb with carry.
        mov     r8, qword ptr [rdx + 40]
        mov     rax, qword ptr [rdx + 32]
        mov     r10, qword ptr [rdx + 24]
        mov     r9, qword ptr [rdx + 16]
        mov     rcx, qword ptr [rdx]
        mov     r11, qword ptr [rdx + 8]
        add     rcx, qword ptr [rsi]
        adc     r11, qword ptr [rsi + 8]
        adc     r9, qword ptr [rsi + 16]
        adc     r10, qword ptr [rsi + 24]
        adc     rax, qword ptr [rsi + 32]
        adc     r8, qword ptr [rsi + 40]
        # Add a 6-limb constant to the sum; presumably the limbs of -M from
        # the optimized IR (immediates not re-derived here).
        movabs  rdx, 5044313057631688021
        add     rdx, rcx
        movabs  rsi, -2210141511517208576
        adc     rsi, r11
        movabs  rbx, -7435674573564081701
        adc     rbx, r9
        movabs  r14, -7239337960414712512
        adc     r14, r10
        movabs  r15, -5412103778470702296
        adc     r15, rax
        movabs  r12, -1873798617647539867
        adc     r12, r8
        # Reported inefficiency: the adc above already set SF from r12's top
        # bit; this copy + arithmetic shift only recomputes the same flag.
        mov     r13, r12
        sar     r13, 63
        # If negative, restore the plain sum.
        cmovs   rsi, r11
        cmovs   r14, r10
        cmovs   rbx, r9
        cmovs   r12, r8
        cmovs   r15, rax
        cmovs   rdx, rcx
        mov     qword ptr [rdi + 32], r15
        mov     qword ptr [rdi + 40], r12
        mov     qword ptr [rdi + 16], rbx
        mov     qword ptr [rdi + 24], r14
        mov     qword ptr [rdi], rdx
        mov     qword ptr [rdi + 8], rsi
        pop     rbx
        pop     r12
        pop     r13
        pop     r14
        pop     r15
        ret


Analysis

For i256, the code is optimal: after 4 sub/sbb (or 4 add/adc with negated inputs), the conditional-move sequence follows directly:

        movabs  rsi, -4332616871279656263
        add     rsi, r8
        movabs  r9, 7529619929231668594
        adc     r9, rdx
        movabs  r10, 5165552122434856866
        adc     r10, rcx
        movabs  r11, -3486998266802970666
        adc     r11, rax
        cmovs   r9, rdx
        cmovs   r11, rax
        cmovs   r10, rcx
        cmovs   rsi, r8

However, for i320 or i384, an unnecessary `SAR` instruction gets added:

        add     rsi, rcx
        movabs  r10, 2957956877962133633
        adc     r10, rdx
        movabs  r11, -1628721857945875527
        adc     r11, r8
        movabs  rbx, 968338100789325766
        adc     rbx, r9
        movabs  r14, -1177913551803681069
        adc     r14, rax
        mov     r15, r14
        sar     r15, 63
        cmovs   r10, rdx
        cmovs   rbx, r9
        cmovs   r11, r8
        cmovs   r14, rax
        cmovs   rsi, rcx
This is an alternative implementation of LLVM modular addition from https://github.com/llvm/llvm-project/issues/103717 that uses raw LLVM IR instead of the builtin `llvm.usub.with.overflow.iXXX` The code for i256 is optimal but not for i320 or i384 (similar to the previous issue, there seems to be a size threshold after which LLVM gives up removing redundant instructions). https://alive2.llvm.org/ce/z/g_nP8g ## Full code ### Original IR ```llvm ; ModuleID = 'x86_poc' ; target triple = "arm64" target triple = "x86_64" @bn254_snarks_fp_mod = constant i256 21888242871839275222246405745257275088696311157297823662689037894645226208583, section "ctt.bn254_snarks_fp.constants", align 64 @bls12_381_fp_mod = constant i384 4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559787, section "ctt.bls12_381_fp.constants", align 64 @bls24_317_fp_mod = constant i320 136393071104295911515099765908274057061945112121419593977210139303905973197232025618026156731051, section "ctt.bls24_317_fp.constants", align 64 ; Function Attrs: hot define internal fastcc void @_modadd_noo.u64x4(ptr %0, ptr %1, ptr %2, ptr %3) #2 section "ctt.fields" { %a = load i256, ptr %1, align 4 %b = load i256, ptr %2, align 4 %M = load i256, ptr %3, align 4 %a_plus_b = add i256 %a, %b %5 = sub i256 %a_plus_b, %M %6 = lshr i256 %5, 255 %7 = trunc i256 %6 to i1 %8 = select i1 %7, i256 %a_plus_b, i256 %5 store i256 %8, ptr %0, align 4 ret void } ; Function Attrs: hot define internal fastcc void @_modadd_noo.u64x5(ptr %0, ptr %1, ptr %2, ptr %3) #2 section "ctt.fields" { %a = load i320, ptr %1, align 4 %b = load i320, ptr %2, align 4 %M = load i320, ptr %3, align 4 %a_plus_b = add i320 %a, %b %5 = sub i320 %a_plus_b, %M %6 = lshr i320 %5, 319 %7 = trunc i320 %6 to i1 %8 = select i1 %7, i320 %a_plus_b, i320 %5 store i320 %8, ptr %0, align 4 ret void } ; Function Attrs: hot define internal fastcc void @_modadd_noo.u64x6(ptr %0, ptr %1, ptr %2, ptr %3) #2 
section "ctt.fields" { %a = load i384, ptr %1, align 4 %b = load i384, ptr %2, align 4 %M = load i384, ptr %3, align 4 %a_plus_b = add i384 %a, %b %5 = sub i384 %a_plus_b, %M %6 = lshr i384 %5, 383 %7 = trunc i384 %6 to i1 %8 = select i1 %7, i384 %a_plus_b, i384 %5 store i384 %8, ptr %0, align 4 ret void } ; Function Attrs: hot define void @bn254_snarks_fp_add(ptr %0, ptr %1, ptr %2) #2 section "ctt.bn254_snarks_fp" { call fastcc void @_modadd_noo.u64x4(ptr %0, ptr %1, ptr %2, ptr @bn254_snarks_fp_mod) ret void } ; Function Attrs: hot define void @bls24_317_fp_add(ptr %0, ptr %1, ptr %2) #2 section "ctt.bls24_317_fp" { call fastcc void @_modadd_noo.u64x5(ptr %0, ptr %1, ptr %2, ptr @bls24_317_fp_mod) ret void } ; Function Attrs: hot define void @bls12_381_fp_add(ptr %0, ptr %1, ptr %2) #2 section "ctt.bls12_381_fp" { call fastcc void @_modadd_noo.u64x6(ptr %0, ptr %1, ptr %2, ptr @bls12_381_fp_mod) ret void } attributes #2 = { hot } ``` ### After opt -O3 ```llvm ; target triple = "arm64" target triple = "x86_64" define void @bn254_snarks_fp_add(ptr nocapture writeonly %0, ptr nocapture readonly %1, ptr nocapture readonly %2) local_unnamed_addr #0 section "ctt.bn254_snarks_fp" { %.val = load i256, ptr %1, align 4 %.val1 = load i256, ptr %2, align 4 %a_plus_b.i = add i256 %.val1, %.val %4 = add i256 %a_plus_b.i, -21888242871839275222246405745257275088696311157297823662689037894645226208583 %.not1.i = icmp slt i256 %4, 0 %5 = select i1 %.not1.i, i256 %a_plus_b.i, i256 %4 store i256 %5, ptr %0, align 4 ret void } define void @bls24_317_fp_add(ptr nocapture writeonly %0, ptr nocapture readonly %1, ptr nocapture readonly %2) local_unnamed_addr #0 section "ctt.bls24_317_fp" { %.val = load i320, ptr %1, align 4 %.val1 = load i320, ptr %2, align 4 %a_plus_b.i = add i320 %.val1, %.val %4 = add i320 %a_plus_b.i, -136393071104295911515099765908274057061945112121419593977210139303905973197232025618026156731051 %.not1.i = icmp slt i320 %4, 0 %5 = select i1 %.not1.i, i320 
%a_plus_b.i, i320 %4 store i320 %5, ptr %0, align 4 ret void } define void @bls12_381_fp_add(ptr nocapture writeonly %0, ptr nocapture readonly %1, ptr nocapture readonly %2) local_unnamed_addr #0 section "ctt.bls12_381_fp" { %.val = load i384, ptr %1, align 4 %.val1 = load i384, ptr %2, align 4 %a_plus_b.i = add i384 %.val1, %.val %4 = add i384 %a_plus_b.i, -4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559787 %.not1.i = icmp slt i384 %4, 0 %5 = select i1 %.not1.i, i384 %a_plus_b.i, i384 %4 store i384 %5, ptr %0, align 4 ret void } attributes #0 = { hot mustprogress nofree norecurse nosync nounwind willreturn memory(argmem: readwrite) } ``` ### Assembly ```asm bn254_snarks_fp_add: # @bn254_snarks_fp_add mov rax, qword ptr [rdx + 24] mov rcx, qword ptr [rdx + 16] mov r8, qword ptr [rdx] mov rdx, qword ptr [rdx + 8] add r8, qword ptr [rsi] adc rdx, qword ptr [rsi + 8] adc rcx, qword ptr [rsi + 16] adc rax, qword ptr [rsi + 24] movabs rsi, -4332616871279656263 add rsi, r8 movabs r9, 7529619929231668594 adc r9, rdx movabs r10, 5165552122434856866 adc r10, rcx movabs r11, -3486998266802970666 adc r11, rax cmovs r9, rdx cmovs r11, rax cmovs r10, rcx cmovs rsi, r8 mov qword ptr [rdi + 16], r10 mov qword ptr [rdi + 24], r11 mov qword ptr [rdi], rsi mov qword ptr [rdi + 8], r9 ret bls24_317_fp_add: # @bls24_317_fp_add push r15 push r14 push rbx mov rax, qword ptr [rdx + 32] mov r9, qword ptr [rdx + 24] mov r8, qword ptr [rdx + 16] mov rcx, qword ptr [rdx] mov rdx, qword ptr [rdx + 8] add rcx, qword ptr [rsi] adc rdx, qword ptr [rsi + 8] adc r8, qword ptr [rsi + 16] adc r9, qword ptr [rsi + 24] adc rax, qword ptr [rsi + 32] movabs rsi, 8263772892774585685 add rsi, rcx movabs r10, 2957956877962133633 adc r10, rdx movabs r11, -1628721857945875527 adc r11, r8 movabs rbx, 968338100789325766 adc rbx, r9 movabs r14, -1177913551803681069 adc r14, rax mov r15, r14 sar r15, 63 cmovs r10, rdx cmovs rbx, r9 cmovs r11, r8 
cmovs r14, rax cmovs rsi, rcx mov qword ptr [rdi + 32], r14 mov qword ptr [rdi + 16], r11 mov qword ptr [rdi + 24], rbx mov qword ptr [rdi], rsi mov qword ptr [rdi + 8], r10 pop rbx pop r14 pop r15 ret bls12_381_fp_add: # @bls12_381_fp_add push r15 push r14 push r13 push r12 push rbx mov r8, qword ptr [rdx + 40] mov rax, qword ptr [rdx + 32] mov r10, qword ptr [rdx + 24] mov r9, qword ptr [rdx + 16] mov rcx, qword ptr [rdx] mov r11, qword ptr [rdx + 8] add rcx, qword ptr [rsi] adc r11, qword ptr [rsi + 8] adc r9, qword ptr [rsi + 16] adc r10, qword ptr [rsi + 24] adc rax, qword ptr [rsi + 32] adc r8, qword ptr [rsi + 40] movabs rdx, 5044313057631688021 add rdx, rcx movabs rsi, -2210141511517208576 adc rsi, r11 movabs rbx, -7435674573564081701 adc rbx, r9 movabs r14, -7239337960414712512 adc r14, r10 movabs r15, -5412103778470702296 adc r15, rax movabs r12, -1873798617647539867 adc r12, r8 mov r13, r12 sar r13, 63 cmovs rsi, r11 cmovs r14, r10 cmovs rbx, r9 cmovs r12, r8 cmovs r15, rax cmovs rdx, rcx mov qword ptr [rdi + 32], r15 mov qword ptr [rdi + 40], r12 mov qword ptr [rdi + 16], rbx mov qword ptr [rdi + 24], r14 mov qword ptr [rdi], rdx mov qword ptr [rdi + 8], rsi pop rbx pop r12 pop r13 pop r14 pop r15 ret ``` ## Analysis For i256, the code is optimal and after 4 sub/sbb (or 4 add/adc with negated inputs) we directly have a conditional move sequence: ```asm movabs rsi, -4332616871279656263 add rsi, r8 movabs r9, 7529619929231668594 adc r9, rdx movabs r10, 5165552122434856866 adc r10, rcx movabs r11, -3486998266802970666 adc r11, rax cmovs r9, rdx cmovs r11, rax cmovs r10, rcx cmovs rsi, r8 ``` However for i320 or i384, an unnecessary `SAR` instruction get added ```asm add rsi, rcx movabs r10, 2957956877962133633 adc r10, rdx movabs r11, -1628721857945875527 adc r11, r8 movabs rbx, 968338100789325766 adc rbx, r9 movabs r14, -1177913551803681069 adc r14, rax mov r15, r14 sar r15, 63 cmovs r10, rdx cmovs rbx, r9 cmovs r11, r8 cmovs r14, rax cmovs rsi, rcx ```
RKSimon commented 4 weeks ago
  t169: i64,i8 = usubo_carry t85, t117, t168:1
      t171: i64 = srl t169, Constant:i8<63>
    t172: i8 = truncate t171
  t184: i8 = and t172, Constant:i8<1>
            t137: i64 = select t184, t78, t135

I think that if we hadn't ended up with this truncate, we'd have detected the sign-bit test.