llvm / llvm-project

The LLVM Project is a collection of modular and reusable compiler and toolchain technologies.
http://llvm.org
Other
27.92k stars 11.52k forks source link

[x86 / Clang] Pessimization: missed fusion of subtraction/compare/cmov at -O1 or with an extra caller. #102868

Open mratsim opened 1 month ago

mratsim commented 1 month ago

This is a follow-up to my quest for efficient cross-ISA modular arithmetic for cryptography, and to finding a workaround for #102062.

Ignoring the loads/stores, the following LLVM IR is optimized into an optimal adc, sbb, cmov sequence when compiled with Clang:

https://alive2.llvm.org/ce/z/hnQycG

but with the following IR it becomes adc, sbb, cmov, adc, i.e. 33% more compute instructions (four carry chains instead of three).

image

LLVM IR

; ModuleID = 'x86_poc'
; Reproducer module for the missed sub/sbb/cmov fusion on x86-64.
source_filename = "x86_poc"
target triple = "x86_64-pc-linux-gnu"

; 256-bit modulus constants, named after the BN254 "fp" and "fr" fields.
; Each public entry point passes one of them by pointer to @_modadd_noo_u64x4.
@bn254_snarks_fp_mod = constant i256 21888242871839275222246405745257275088696311157297823662689037894645226208583, section "ctt.bn254_snarks_fp.constants", align 64
@bn254_snarks_fr_mod = constant i256 21888242871839275222246405745257275088548364400416034343698204186575808495617, section "ctt.bn254_snarks_fr.constants", align 64

; fp entry point: forwards (dst = %0, a = %1, b = %2) to the shared modular
; addition helper, with @bn254_snarks_fp_mod as the modulus pointer.
define void @bn254_snarks_fp_add(ptr %0, ptr %1, ptr %2) section "bn254_snarks_fp" {
  call fastcc void @_modadd_noo_u64x4(ptr %0, ptr %1, ptr %2, ptr @bn254_snarks_fp_mod)
  ret void
}

; Shared modular-addition helper: *%0 = finalsub(*%1 + *%2, *%3).
; The 256-bit `add` wraps modulo 2^256 and no carry-out is tracked. The "noo"
; suffix presumably means "no overflow", i.e. callers must pass operands whose
; sum fits in 256 bits -- NOTE(review): confirm against the library.
define internal fastcc void @_modadd_noo_u64x4(ptr %0, ptr %1, ptr %2, ptr %3) section "ctt.fields" {
  %a = load i256, ptr %1, align 4
  %b = load i256, ptr %2, align 4
  %a_plus_b = add i256 %a, %b
  ; The sum goes through a stack temporary so it can be passed by pointer to
  ; the final-subtraction helper. (%4 is the implicit entry-block label, hence %5.)
  %5 = alloca [4 x i64], align 8
  store i256 %a_plus_b, ptr %5, align 4
  call fastcc void @_finalsub_noo_u64x4(ptr %0, ptr %5, ptr %3)
  ret void
}

; Conditional final subtraction: *%0 = (a < M) ? a : a - M, where a = *%1 and
; M = *%2. The choice is expressed with `select` rather than a branch. When this
; shape survives to the backend, it lowers to sub/sbb followed by cmov (see the
; -O0 listing of this function).
define internal fastcc void @_finalsub_noo_u64x4(ptr %0, ptr %1, ptr %2) section "ctt.fields" {
  %M = load i256, ptr %2, align 4
  %a = load i256, ptr %1, align 4
  %a_minus_M = sub i256 %a, %M
  ; Unsigned a < M is exactly the borrow-out of the subtraction above.
  %borrow = icmp ult i256 %a, %M
  %4 = select i1 %borrow, i256 %a, i256 %a_minus_M
  store i256 %4, ptr %0, align 4
  ret void
}

; Comment this out for good codegen (or use Clang -O0)
; fr entry point: same as @bn254_snarks_fp_add but with @bn254_snarks_fr_mod.
; It gives @_modadd_noo_u64x4 a second call site. Per the report, with both
; callers present, -O1 inlines the helpers and emits the extra add/adc chain.
define void @bn254_snarks_fr_add(ptr %0, ptr %1, ptr %2) section "bn254_snarks_fr" {
  call fastcc void @_modadd_noo_u64x4(ptr %0, ptr %1, ptr %2, ptr @bn254_snarks_fr_mod)
  ret void
}

Clang -O1

bn254_snarks_fp_add:                    # @bn254_snarks_fp_add
# Clang -O1 output with both entry points present; the helpers are inlined.
# SysV args: rdi = dst, rsi = a, rdx = b (four 64-bit limbs each, lowest first).
        push    rbx                     # rbx is callee-saved; used as scratch below
# sum = a + b; limbs in rax (lowest), rcx, r8, rdx (highest)
        mov     rax, qword ptr [rdx]
        add     rax, qword ptr [rsi]
        mov     rcx, qword ptr [rdx + 8]
        adc     rcx, qword ptr [rsi + 8]
        mov     r8, qword ptr [rdx + 16]
        adc     r8, qword ptr [rsi + 16]
        mov     rdx, qword ptr [rdx + 24]
        adc     rdx, qword ptr [rsi + 24]
        xor     esi, esi                # rsi = 0: cmov source for the "sum < M" case
# Compare sum against M (limbs inlined as immediates). Only the final CF
# (borrow, i.e. sum < M) is used; each difference written to r10 is dead.
        movabs  r9, 4332616871279656263
        cmp     rax, r9
        movabs  r9, -7529619929231668595
        mov     r10, rcx
        sbb     r10, r9
        movabs  r9, -5165552122434856867
        mov     r10, r8
        sbb     r10, r9
        movabs  r9, 3486998266802970665
        mov     r10, rdx
        sbb     r10, r9
# addend = (sum < M) ? 0 : -M, limb by limb (r9 highest .. rbx lowest).
# movabs and cmov leave flags untouched, so the CF from the sbb chain survives.
        movabs  r9, -3486998266802970666
        cmovb   r9, rsi
        movabs  r10, 5165552122434856866
        cmovb   r10, rsi
        movabs  r11, 7529619929231668594
        cmovb   r11, rsi
        movabs  rbx, -4332616871279656263
        cmovb   rbx, rsi
# result = sum + addend. The -O0 _finalsub_noo_u64x4 listing keeps the sbb
# results and cmov's between them and the sum, which needs no extra add/adc.
        add     rbx, rax    # This sequence shouldn't exist
        adc     r11, rcx
        adc     r10, r8
        adc     r9, rdx
        mov     qword ptr [rdi], rbx
        mov     qword ptr [rdi + 8], r11
        mov     qword ptr [rdi + 16], r10
        mov     qword ptr [rdi + 24], r9
        pop     rbx
        ret
bn254_snarks_fr_add:                    # @bn254_snarks_fr_add
# Same shape as bn254_snarks_fp_add; only the fr modulus immediates differ.
# SysV args: rdi = dst, rsi = a, rdx = b.
        push    rbx                     # rbx is callee-saved; used as scratch below
# sum = a + b; limbs in rax (lowest), rcx, r8, rdx (highest)
        mov     rax, qword ptr [rdx]
        add     rax, qword ptr [rsi]
        mov     rcx, qword ptr [rdx + 8]
        adc     rcx, qword ptr [rsi + 8]
        mov     r8, qword ptr [rdx + 16]
        adc     r8, qword ptr [rsi + 16]
        mov     rdx, qword ptr [rdx + 24]
        adc     rdx, qword ptr [rsi + 24]
        xor     esi, esi                # rsi = 0: cmov source for the "sum < M" case
# Borrow-only comparison of sum against M; the r10 differences are dead.
        movabs  r9, 4891460686036598785
        cmp     rax, r9
        movabs  r9, 2896914383306846353
        mov     r10, rcx
        sbb     r10, r9
        movabs  r9, -5165552122434856867
        mov     r10, r8
        sbb     r10, r9
        movabs  r9, 3486998266802970665
        mov     r10, rdx
        sbb     r10, r9
# addend = (sum < M) ? 0 : -M, limb by limb (r9 highest .. rbx lowest)
        movabs  r9, -3486998266802970666
        cmovb   r9, rsi
        movabs  r10, 5165552122434856866
        cmovb   r10, rsi
        movabs  r11, -2896914383306846354
        cmovb   r11, rsi
        movabs  rbx, -4891460686036598785
        cmovb   rbx, rsi
# result = sum + addend (the redundant fourth carry chain)
        add     rbx, rax
        adc     r11, rcx
        adc     r10, r8
        adc     r9, rdx
        mov     qword ptr [rdi], rbx
        mov     qword ptr [rdi + 8], r11
        mov     qword ptr [rdi + 16], r10
        mov     qword ptr [rdi + 24], r9
        pop     rbx
        ret
# Modulus data: four 64-bit limbs per modulus, least significant first (the
# first .quad matches the immediate used in the leading `cmp`). Values print as
# signed 64-bit integers, so negative entries are limbs >= 2^63. The fp and fr
# moduli share their two upper limbs.
bn254_snarks_fp_mod:
        .quad   4332616871279656263
        .quad   -7529619929231668595
        .quad   -5165552122434856867
        .quad   3486998266802970665

bn254_snarks_fr_mod:
        .quad   4891460686036598785
        .quad   2896914383306846353
        .quad   -5165552122434856867
        .quad   3486998266802970665

Clang -O0

bn254_snarks_fp_add:                    # @bn254_snarks_fp_add
# Clang -O0: no inlining. rdi/rsi/rdx (dst, a, b) pass through unchanged; the
# modulus address is loaded via the GOT into rcx as the 4th argument.
        push    rax                     # keeps rsp 16-byte aligned at the call
        mov     rcx, qword ptr [rip + bn254_snarks_fp_mod@GOTPCREL]
        call    _modadd_noo_u64x4
        pop     rax
        ret
_modadd_noo_u64x4:                      # @_modadd_noo_u64x4
# In: rdi = dst, rsi = a, rdx = b, rcx = M (all pointers).
# Computes a + b into a 32-byte stack temporary at [rsp + 16], then calls
# _finalsub_noo_u64x4(dst, &tmp, M).
        push    rbx                     # rbx is callee-saved; holds a limb of b
        sub     rsp, 48                 # spill slot + temp; rsp stays 16-aligned at the call
        mov     qword ptr [rsp + 8], rcx        # 8-byte Spill
        mov     r11, rdx                # r11 = b pointer
        mov     rdx, qword ptr [rsp + 8]        # 8-byte Reload: rdx = M, 3rd arg of the call
# a limbs: r8 (lowest), rsi, rcx, rax (highest); b limbs: rbx, r11, r10, r9
        mov     rax, qword ptr [rsi + 24]
        mov     rcx, qword ptr [rsi + 16]
        mov     r8, qword ptr [rsi]
        mov     rsi, qword ptr [rsi + 8]
        mov     r9, qword ptr [r11 + 24]
        mov     r10, qword ptr [r11 + 16]
        mov     rbx, qword ptr [r11]
        mov     r11, qword ptr [r11 + 8]
        add     r8, rbx
        adc     rsi, r11
        adc     rcx, r10
        adc     rax, r9
# store the 256-bit sum to the temporary and pass its address in rsi
        mov     qword ptr [rsp + 16], r8
        mov     qword ptr [rsp + 24], rsi
        mov     qword ptr [rsp + 32], rcx
        mov     qword ptr [rsp + 40], rax
        lea     rsi, [rsp + 16]
        call    _finalsub_noo_u64x4
        add     rsp, 48
        pop     rbx
        ret
_finalsub_noo_u64x4:                    # @_finalsub_noo_u64x4
# In: rdi = dst, rsi = a, rdx = M (pointers). Stores (a < M) ? a : a - M.
# This is the desired shape: a - M via sub/sbb, then cmovb (CF = borrow,
# i.e. a < M) restores the original a limbs. There is no extra add/adc chain.
        push    rbx                     # rbx is callee-saved; holds M's top limb
        mov     rax, rsi                # rax = a pointer (rsi is reused for an M limb)
# M limbs: rcx (lowest), rdx, rsi, rbx (highest)
        mov     rbx, qword ptr [rdx + 24]
        mov     rsi, qword ptr [rdx + 16]
        mov     rcx, qword ptr [rdx]
        mov     rdx, qword ptr [rdx + 8]
# a limbs: r11 (lowest), r10, r9, r8 (highest)
        mov     r8, qword ptr [rax + 24]
        mov     r9, qword ptr [rax + 16]
        mov     r11, qword ptr [rax]
        mov     r10, qword ptr [rax + 8]
# a - M; difference limbs in rax, rcx, rdx, rsi
        mov     rax, r11
        sub     rax, rcx
        mov     rcx, r10
        sbb     rcx, rdx
        mov     rdx, r9
        sbb     rdx, rsi
        mov     rsi, r8
        sbb     rsi, rbx
# borrow set (a < M): keep a instead of a - M
        cmovb   rax, r11
        cmovb   rcx, r10
        cmovb   rdx, r9
        cmovb   rsi, r8
        mov     qword ptr [rdi + 24], rsi
        mov     qword ptr [rdi + 16], rdx
        mov     qword ptr [rdi + 8], rcx
        mov     qword ptr [rdi], rax
        pop     rbx
        ret
bn254_snarks_fr_add:                    # @bn254_snarks_fr_add
# Clang -O0: same as bn254_snarks_fp_add but passes the fr modulus in rcx.
        push    rax                     # keeps rsp 16-byte aligned at the call
        mov     rcx, qword ptr [rip + bn254_snarks_fr_mod@GOTPCREL]
        call    _modadd_noo_u64x4
        pop     rax
        ret
# Modulus data: four 64-bit limbs per modulus, least significant first;
# negative entries are limbs >= 2^63 printed as signed 64-bit integers.
bn254_snarks_fp_mod:
        .quad   4332616871279656263
        .quad   -7529619929231668595
        .quad   -5165552122434856867
        .quad   3486998266802970665

bn254_snarks_fr_mod:
        .quad   4891460686036598785
        .quad   2896914383306846353
        .quad   -5165552122434856867
        .quad   3486998266802970665

Clang -O1, but with only a single entry point (bn254_snarks_fr_add removed)

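Edit: actually this is also suboptimal. The subtraction of the modulus is done twice: once as an add/adc chain with -M, and again as a cmp/sbb chain that only produces the borrow flag.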

bn254_snarks_fp_add:                    # @bn254_snarks_fp_add
# Clang -O1 with only this entry point in the module.
# SysV args: rdi = dst, rsi = a, rdx = b (four 64-bit limbs each, lowest first).
        push    r14                     # r14/rbx are callee-saved; used as scratch
        push    rbx
# sum = a + b; limbs in rax (lowest), r8, rdx, rcx (highest)
        mov     rcx, qword ptr [rdx + 24]
        mov     rax, qword ptr [rdx]
        add     rax, qword ptr [rsi]
        mov     r8, qword ptr [rdx + 8]
        adc     r8, qword ptr [rsi + 8]
        mov     rdx, qword ptr [rdx + 16]
        adc     rdx, qword ptr [rsi + 16]
        adc     rcx, qword ptr [rsi + 24]
# diff = sum + (-M) = sum - M; limbs in rsi, r9, r10, r11 (kept for the cmov)
        movabs  rsi, -4332616871279656263
        add     rsi, rax    # edit: actually this sequence is unnecessary
        movabs  r9, 7529619929231668594
        adc     r9, r8
        movabs  r10, 5165552122434856866
        adc     r10, rdx
        movabs  r11, -3486998266802970666
        adc     r11, rcx
# sum - M is computed again with cmp/sbb only to get the borrow flag; the r14
# results are dead. A single sub/sbb chain would provide both diff and CF.
        movabs  rbx, 4332616871279656263
        cmp     rax, rbx
        movabs  rbx, -7529619929231668595
        mov     r14, r8
        sbb     r14, rbx
        movabs  rbx, -5165552122434856867
        mov     r14, rdx
        sbb     r14, rbx
        movabs  rbx, 3486998266802970665
        mov     r14, rcx
        sbb     r14, rbx
# borrow set (sum < M): take the sum limbs; otherwise keep diff
        cmovb   r9, r8
        cmovb   r11, rcx
        cmovb   r10, rdx
        cmovb   rsi, rax
        mov     qword ptr [rdi + 16], r10
        mov     qword ptr [rdi + 24], r11
        mov     qword ptr [rdi], rsi
        mov     qword ptr [rdi + 8], r9
        pop     rbx
        pop     r14
        ret
# fp modulus: four 64-bit limbs, least significant first; negative entries
# are limbs >= 2^63 printed as signed 64-bit integers.
bn254_snarks_fp_mod:
        .quad   4332616871279656263
        .quad   -7529619929231668595
        .quad   -5165552122434856867
        .quad   3486998266802970665

Extra context

Optimizing modular addition is critical for cryptography, which is used everywhere (HTTPS, website authentication). Currently, state-of-the-art libraries must use assembly for both speed and correctness (constant-time) reasons. Better compiler support is important for more robust software, and also for use on a wider range of hardware (WASM, GPUs, FPGAs, ...).

Reproduction

I used to be able to reproduce it with llc and opt: opt moved the select one instruction away from the borrow and reordered instructions so that either 0 or the modulus was added. Unfortunately, I lost the exact pass sequence that produced that output.

mratsim commented 1 month ago

From experimenting with LLVMRunPasses, it seems that the scc-oz-module-inliner pass reliably triggers this behavior.

image

https://alive2.llvm.org/ce/z/htetw5

Original LLVM IR

; ModuleID = 'x86_poc'
; Reproducer module for the missed sub/sbb/cmov fusion on x86-64.
source_filename = "x86_poc"
target triple = "x86_64-pc-linux-gnu"

; 256-bit modulus constants, named after the BN254 "fp" and "fr" fields.
; Each public entry point passes one of them by pointer to @_modadd_noo_u64x4.
@bn254_snarks_fp_mod = constant i256 21888242871839275222246405745257275088696311157297823662689037894645226208583, section "ctt.bn254_snarks_fp.constants", align 64
@bn254_snarks_fr_mod = constant i256 21888242871839275222246405745257275088548364400416034343698204186575808495617, section "ctt.bn254_snarks_fr.constants", align 64

; fp entry point: forwards (dst = %0, a = %1, b = %2) to the shared modular
; addition helper, with @bn254_snarks_fp_mod as the modulus pointer.
define void @bn254_snarks_fp_add(ptr %0, ptr %1, ptr %2) section "bn254_snarks_fp" {
  call fastcc void @_modadd_noo_u64x4(ptr %0, ptr %1, ptr %2, ptr @bn254_snarks_fp_mod)
  ret void
}

; Shared modular-addition helper: *%0 = finalsub(*%1 + *%2, *%3).
; The 256-bit `add` wraps modulo 2^256 and no carry-out is tracked. The "noo"
; suffix presumably means "no overflow", i.e. callers must pass operands whose
; sum fits in 256 bits -- NOTE(review): confirm against the library.
define internal fastcc void @_modadd_noo_u64x4(ptr %0, ptr %1, ptr %2, ptr %3) section "ctt.fields" {
  %a = load i256, ptr %1, align 4
  %b = load i256, ptr %2, align 4
  %a_plus_b = add i256 %a, %b
  ; The sum goes through a stack temporary so it can be passed by pointer to
  ; the final-subtraction helper. (%4 is the implicit entry-block label, hence %5.)
  %5 = alloca [4 x i64], align 8
  store i256 %a_plus_b, ptr %5, align 4
  call fastcc void @_finalsub_noo_u64x4(ptr %0, ptr %5, ptr %3)
  ret void
}

; Conditional final subtraction: *%0 = (a < M) ? a : a - M, where a = *%1 and
; M = *%2. Inlining into both callers is what turns this select into the
; "add 0 or -M" form shown in the pessimized output.
define internal fastcc void @_finalsub_noo_u64x4(ptr %0, ptr %1, ptr %2) section "ctt.fields" {
  %M = load i256, ptr %2, align 4
  %a = load i256, ptr %1, align 4
  %a_minus_M = sub i256 %a, %M
  ; Unsigned a < M is exactly the borrow-out of the subtraction above.
  %borrow = icmp ult i256 %a, %M
  %4 = select i1 %borrow, i256 %a, i256 %a_minus_M
  store i256 %4, ptr %0, align 4
  ret void
}

; Comment this out for good codegen (or use Clang -O0)
; fr entry point: same as @bn254_snarks_fp_add but with @bn254_snarks_fr_mod.
; It gives @_modadd_noo_u64x4 a second call site.
define void @bn254_snarks_fr_add(ptr %0, ptr %1, ptr %2) section "bn254_snarks_fr" {
  call fastcc void @_modadd_noo_u64x4(ptr %0, ptr %1, ptr %2, ptr @bn254_snarks_fr_mod)
  ret void
}

Pessimized

; Pessimized module as reported after optimization: only the two public entry
; points remain, and each modulus is folded in as an i256 immediate.
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"
target triple = "x86_64-pc-linux-gnu"

; fp entry point after inlining. `select(borrow, sum, sum - M)` has become
; `sum + select(borrow, 0, -M)`: the select now chooses the addend rather than
; the result. In the -O1 asm this shows up as cmovb on the -M limbs followed by
; an extra add/adc chain.
define void @bn254_snarks_fp_add(ptr nocapture writeonly %0, ptr nocapture readonly %1, ptr nocapture readonly %2) #0 section "bn254_snarks_fp" {
  %a.i = load i256, ptr %1, align 4
  %b.i = load i256, ptr %2, align 4
  %a_plus_b.i = add i256 %b.i, %a.i
  %borrow.i.i = icmp ult i256 %a_plus_b.i, 21888242871839275222246405745257275088696311157297823662689037894645226208583
  %a_minus_M.i.i.neg = select i1 %borrow.i.i, i256 0, i256 -21888242871839275222246405745257275088696311157297823662689037894645226208583
  %4 = add i256 %a_minus_M.i.i.neg, %a_plus_b.i
  store i256 %4, ptr %0, align 4
  ret void
}

; fr entry point after inlining: the same `sum + select(borrow, 0, -M)` form,
; with the fr modulus folded in.
define void @bn254_snarks_fr_add(ptr nocapture writeonly %0, ptr nocapture readonly %1, ptr nocapture readonly %2) #0 section "bn254_snarks_fr" {
  %a.i = load i256, ptr %1, align 4
  %b.i = load i256, ptr %2, align 4
  %a_plus_b.i = add i256 %b.i, %a.i
  %borrow.i.i = icmp ult i256 %a_plus_b.i, 21888242871839275222246405745257275088548364400416034343698204186575808495617
  %a_minus_M.i.i.neg = select i1 %borrow.i.i, i256 0, i256 -21888242871839275222246405745257275088548364400416034343698204186575808495617
  %4 = add i256 %a_minus_M.i.i.neg, %a_plus_b.i
  store i256 %4, ptr %0, align 4
  ret void
}

; Attribute group #0, attached to both optimized entry points.
attributes #0 = { mustprogress nofree norecurse nosync nounwind willreturn memory(argmem: readwrite) }