llvm / llvm-project

The LLVM Project is a collection of modular and reusable compiler and toolchain technologies.
http://llvm.org
Other
27.92k stars 11.52k forks source link

[x86] Chained `blsr`s can be optimized to a `pdep` (since Haswell on Intel, since Zen 3 on AMD) #101915

Open Validark opened 1 month ago

Validark commented 1 month ago

Here is how we could compile a chain of `blsr` instructions whose length is known at compile time:

const NUM_BLSR_OPS = 8;

/// Clears the `NUM_BLSR_OPS` lowest set bits of `x`, one bit per step.
/// Once `r` reaches 0 the remaining steps leave it at 0.
export fn chained_blsr_ops_1(x: u64) u64 {
    var r = x;
    // `inline for` unrolls the loop at compile time: one clear-lowest-bit step per iteration.
    inline for (0..NUM_BLSR_OPS) |_| {
        // `r & (r - 1)` clears the lowest set bit; `-%` is wrapping subtraction,
        // so `r == 0` does not trap on underflow.
        r &= r -% 1;
    }
    return r;
}

/// Clears the `NUM_BLSR_OPS` lowest set bits of `x` with a single `pdep`
/// instead of one clear-lowest-bit step per bit.
export fn chained_blsr_ops_2(x: u64) u64 {
    // A u64 has at most 64 set bits, so clearing 64 or more always yields 0.
    // This guard also keeps the shift amount below within 0..63.
    if (NUM_BLSR_OPS >= 64) return 0;
    // Source operand: all ones except the low NUM_BLSR_OPS bits.
    // `pdep` scatters the low bits of the source into the set-bit positions of the
    // mask `x`, lowest first, so the lowest NUM_BLSR_OPS set bits of `x` receive 0
    // and every other set bit receives 1.
    return pdep(~@as(u64, 0) << NUM_BLSR_OPS, x);
}

/// Thin wrapper around the LLVM x86 intrinsic `llvm.x86.bmi.pdep.64` (parallel bit deposit).
/// `a` is the source value whose low bits are deposited; `b` is the mask selecting
/// the destination bit positions.
fn pdep(a: u64, b: u64) u64 {
    return struct {
        // Declare the intrinsic as an extern function under its LLVM name so it can
        // be called directly; the anonymous struct only provides a namespace for it.
        extern fn @"llvm.x86.bmi.pdep.64"(u64, u64) u64;
    }.@"llvm.x86.bmi.pdep.64"(a, b);
}
chained_blsr_ops_1:  ; bad version
        ; rdi = x; the result is built up in rax.
        ; Eight blsr instructions, each clearing the lowest set bit of its source.
        ; Every blsr after the first reads the rax written by the previous one,
        ; so the chain is fully serial.
        blsr    rax, rdi
        blsr    rax, rax
        blsr    rax, rax
        blsr    rax, rax
        blsr    rax, rax
        blsr    rax, rax
        blsr    rax, rax
        blsr    rax, rax
        ret

chained_blsr_ops_2:  ; good version
        ; -256 = ~0 << 8: all ones except the low 8 bits.
        mov     rax, -256
        ; Deposit the low bits of rax into the set-bit positions of x (rdi, the mask):
        ; the 8 lowest set bits of x receive 0, the rest receive 1.
        pdep    rax, rax, rdi
        ret

Here is how we could compile a loop of `blsr` instructions whose iteration count is only known at run time:

/// Clears the `num_blsr_ops` lowest set bits of `x` using a loop with a runtime trip count.
/// Once `r` reaches 0 the remaining iterations leave it at 0.
export fn chained_blsr_ops_1(x: u64, num_blsr_ops: u64) u64 {
    var r = x;
    for (0..num_blsr_ops) |_| {
        // `r & (r - 1)` clears the lowest set bit; `-%` is wrapping subtraction,
        // so `r == 0` does not trap on underflow.
        r &= r -% 1;
    }
    return r;
}

/// Clears the `num_blsr_ops` lowest set bits of `x` without a loop:
/// one shift, one `pdep`, and a select for counts of 64 or more.
export fn chained_blsr_ops_2(x: u64, num_blsr_ops: u64) u64 {
    // `@truncate` narrows the count to the shift-amount type of a u64 shift,
    // i.e. only the low 6 bits are used; counts >= 64 are handled on the next line.
    // The source operand is all ones except its low (count mod 64) bits, so `pdep`
    // writes 0 into that many of the lowest set-bit positions of the mask `x`.
    const r = pdep(~@as(u64, 0) << @truncate(num_blsr_ops), x);
    // A u64 has at most 64 set bits, so 64 or more clears always yield 0,
    // regardless of what the truncated shift produced.
    return if (num_blsr_ops >= 64) 0 else r;
}

/// Thin wrapper around the LLVM x86 intrinsic `llvm.x86.bmi.pdep.64` (parallel bit deposit).
/// `a` is the source value whose low bits are deposited; `b` is the mask selecting
/// the destination bit positions.
fn pdep(a: u64, b: u64) u64 {
    return struct {
        // Declare the intrinsic as an extern function under its LLVM name so it can
        // be called directly; the anonymous struct only provides a namespace for it.
        extern fn @"llvm.x86.bmi.pdep.64"(u64, u64) u64;
    }.@"llvm.x86.bmi.pdep.64"(a, b);
}
chained_blsr_ops_1:   ; bad version
        ; rdi = x, rsi = num_blsr_ops; the running value lives in rax.
        mov     rax, rdi
        ; Zero iterations: return x unchanged.
        test    rsi, rsi
        je      .LBB0_6
        ; ecx = num_blsr_ops mod 8: iterations left over after the 8x-unrolled loop.
        mov     ecx, esi
        and     ecx, 7
        ; Fewer than 8 iterations: skip the unrolled loop entirely.
        cmp     rsi, 8
        jb      .LBB0_4
        ; Round the count down to a multiple of 8 for the unrolled loop.
        and     rsi, -8
.LBB0_3:
        ; Unrolled loop body: 8 dependent blsr steps per pass.
        blsr    rax, rax
        blsr    rax, rax
        blsr    rax, rax
        blsr    rax, rax
        blsr    rax, rax
        blsr    rax, rax
        blsr    rax, rax
        blsr    rax, rax
        add     rsi, -8
        jne     .LBB0_3
.LBB0_4:
        ; Remainder loop; skipped when no iterations are left over.
        test    rcx, rcx
        je      .LBB0_6
.LBB0_5:
        ; One blsr per leftover iteration.
        blsr    rax, rax
        dec     rcx
        jne     .LBB0_5
.LBB0_6:
        ret

chained_blsr_ops_2:   ; good version
        ; rdi = x, rsi = num_blsr_ops.
        ; rax = all ones.
        mov     rax, -1
        ; rax = ~0 << (count mod 64); shlx uses only the low 6 bits of a 64-bit shift count.
        shlx    rax, rax, rsi
        ; rcx = x (the mask) with its lowest (count mod 64) set bits cleared.
        pdep    rcx, rax, rdi
        ; rax = 0, the result for count >= 64.
        xor     eax, eax
        ; Unsigned count < 64: take the pdep result instead of 0.
        cmp     rsi, 64
        cmovb   rax, rcx
        ret
llvmbot commented 1 month ago

@llvm/issue-subscribers-backend-x86

Author: Niles Salter (Validark)

Here is how we could compile a number of blsr's known at compile-time:

```zig
const NUM_BLSR_OPS = 8;

export fn chained_blsr_ops_1(x: u64) u64 {
    var r = x;
    inline for (0..NUM_BLSR_OPS) |_| {
        r &= r -% 1;
    }
    return r;
}

export fn chained_blsr_ops_2(x: u64) u64 {
    if (NUM_BLSR_OPS >= 64) return 0;
    return pdep(~@as(u64, 0) << NUM_BLSR_OPS, x);
}

fn pdep(a: u64, b: u64) u64 {
    return struct {
        extern fn @"llvm.x86.bmi.pdep.64"(u64, u64) u64;
    }.@"llvm.x86.bmi.pdep.64"(a, b);
}
```

```asm
chained_blsr_ops_1:  ; bad version
        blsr    rax, rdi
        blsr    rax, rax
        blsr    rax, rax
        blsr    rax, rax
        blsr    rax, rax
        blsr    rax, rax
        blsr    rax, rax
        blsr    rax, rax
        ret

chained_blsr_ops_2:  ; good version
        mov     rax, -256
        pdep    rax, rax, rdi
        ret
```

Here is how we could compile a loop of blsr's:

```zig
export fn chained_blsr_ops_1(x: u64, num_blsr_ops: u64) u64 {
    var r = x;
    for (0..num_blsr_ops) |_| {
        r &= r -% 1;
    }
    return r;
}

export fn chained_blsr_ops_2(x: u64, num_blsr_ops: u64) u64 {
    const r = pdep(~@as(u64, 0) << @truncate(num_blsr_ops), x);
    return if (num_blsr_ops >= 64) 0 else r;
}

fn pdep(a: u64, b: u64) u64 {
    return struct {
        extern fn @"llvm.x86.bmi.pdep.64"(u64, u64) u64;
    }.@"llvm.x86.bmi.pdep.64"(a, b);
}
```

```asm
chained_blsr_ops_1:   ; bad version
        mov     rax, rdi
        test    rsi, rsi
        je      .LBB0_6
        mov     ecx, esi
        and     ecx, 7
        cmp     rsi, 8
        jb      .LBB0_4
        and     rsi, -8
.LBB0_3:
        blsr    rax, rax
        blsr    rax, rax
        blsr    rax, rax
        blsr    rax, rax
        blsr    rax, rax
        blsr    rax, rax
        blsr    rax, rax
        blsr    rax, rax
        add     rsi, -8
        jne     .LBB0_3
.LBB0_4:
        test    rcx, rcx
        je      .LBB0_6
.LBB0_5:
        blsr    rax, rax
        dec     rcx
        jne     .LBB0_5
.LBB0_6:
        ret

chained_blsr_ops_2:   ; good version
        mov     rax, -1
        shlx    rax, rax, rsi
        pdep    rcx, rax, rdi
        xor     eax, eax
        cmp     rsi, 64
        cmovb   rax, rcx
        ret
```