henry2004y / Vlasiator.jl

Data processor for Vlasiator
https://henry2004y.github.io/Vlasiator.jl/stable/
MIT License
6 stars 4 forks source link

SIMD #170

Closed henry2004y closed 5 months ago

henry2004y commented 5 months ago

I just learned that using the @simd macro is necessary in some loops to achieve optimal performance. Consider the following example:

julia> function mysum_basic(a::Vector)
           # Sum the elements of `a` with a plain `for` loop (no `@simd` annotation).
           # The accumulator starts at the additive identity of the element type,
           # so an empty vector returns `zero(eltype(a))`.
           total = zero(eltype(a))
           for x in a
               # Elements are added one by one in iteration order.
               total += x
           end
           return total
       end
mysum_basic (generic function with 1 method)

julia> function mysum_simd(a::Vector)
           # Sum the elements of `a` with the reduction loop annotated by `@simd`.
           # The accumulator starts at the additive identity of the element type,
           # so an empty vector returns `zero(eltype(a))`.
           total = zero(eltype(a))
           # `@simd` allows the compiler to reorder the additions on the reduction
           # variable `total`; for floating-point element types the result may
           # therefore differ slightly from a strict in-order sum.
           @simd for x in a
               total += x
           end
           return total
       end
mysum_simd (generic function with 1 method)

julia> using BenchmarkTools

julia> rand_array_1D = rand(1000000)

julia> begin
           @btime mysum_basic($rand_array_1D)
           @btime mysum_simd($rand_array_1D)
       end
  814.200 μs (0 allocations: 0 bytes)
  118.500 μs (0 allocations: 0 bytes)

Roughly a 7x difference (814.2 μs vs. 118.5 μs)!

The native code can be checked via @code_native:

julia> @code_native debuginfo=:none dump_module=false mysum_basic(rand_array_1D)
        .text
        push    rbp
        mov     rbp, rsp
        mov     rax, qword ptr [r13 + 16]
        mov     rax, qword ptr [rax + 16]
        mov     rax, qword ptr [rax]
        mov     rdx, qword ptr [rcx + 8]
        test    rdx, rdx
        je      L86
        mov     rax, qword ptr [rcx]
        vxorpd  xmm0, xmm0, xmm0
        vaddsd  xmm0, xmm0, qword ptr [rax]
        cmp     rdx, 1
        je      L208
        lea     r9, [rdx - 1]
        add     rdx, -2
        mov     r8d, r9d
        and     r8d, 7
        cmp     rdx, 7
        jae     L92
        mov     ecx, 2
        mov     edx, 1
        test    r8, r8
        jne     L192
        jmp     L208
L86:
        vxorps  xmm0, xmm0, xmm0
        pop     rbp
        ret
L92:
        and     r9, -8
        mov     edx, 1
        mov     rcx, -2
        nop     dword ptr [rax]
L112:
        vaddsd  xmm0, xmm0, qword ptr [rax + 8*rdx]
        vaddsd  xmm0, xmm0, qword ptr [rax + 8*rdx + 8]
        vaddsd  xmm0, xmm0, qword ptr [rax + 8*rdx + 16]
        vaddsd  xmm0, xmm0, qword ptr [rax + 8*rdx + 24]
        vaddsd  xmm0, xmm0, qword ptr [rax + 8*rdx + 32]
        vaddsd  xmm0, xmm0, qword ptr [rax + 8*rdx + 40]
        vaddsd  xmm0, xmm0, qword ptr [rax + 8*rdx + 48]
        vaddsd  xmm0, xmm0, qword ptr [rax + 8*rdx + 56]
        add     rdx, 8
        lea     r10, [r9 + rcx]
        add     r10, -8
        add     rcx, -8
        cmp     r10, -2
        jne     L112
        neg     rcx
        test    r8, r8
        je      L208
        nop     dword ptr [rax]
L192:
        vaddsd  xmm0, xmm0, qword ptr [rax + 8*rdx]
        mov     rdx, rcx
        inc     rcx
        dec     r8
        jne     L192
L208:
        pop     rbp
        ret
        nop     word ptr cs:[rax + rax]

versus

julia> @code_native debuginfo=:none dump_module=false mysum_simd(rand_array_1D)
        .text
        push    rbp
        mov     rbp, rsp
        push    rsi
        mov     rax, qword ptr [r13 + 16]
        mov     rax, qword ptr [rax + 16]
        mov     rax, qword ptr [rax]
        mov     r9, qword ptr [rcx + 8]
        test    r9, r9
        je      L45
        mov     rcx, qword ptr [rcx]
        cmp     r9, 16
        jae     L54
        vxorpd  xmm0, xmm0, xmm0
        xor     edx, edx
        jmp     L416
L45:
        vxorpd  xmm0, xmm0, xmm0
        jmp     L429
L54:
        mov     rdx, r9
        and     rdx, -16
        lea     rsi, [rdx - 16]
        mov     r10, rsi
        shr     r10, 4
        inc     r10
        mov     r8d, r10d
        and     r8d, 3
        movabs  rax, offset .rodata.cst8
        movabs  r11, offset .rodata.cst32
        vbroadcastsd    ymm0, qword ptr [rax]
        vmovapd ymm1, ymmword ptr [r11]
        cmp     rsi, 48
        jae     L133
        xor     eax, eax
        vmovapd ymm2, ymm0
        vmovapd ymm3, ymm0
        jmp     L305
L133:
        and     r10, -4
        xor     eax, eax
        vmovapd ymm2, ymm0
        vmovapd ymm3, ymm0
        nop     word ptr cs:[rax + rax]
L160:
        vaddpd  ymm1, ymm1, ymmword ptr [rcx + 8*rax]
        vaddpd  ymm0, ymm0, ymmword ptr [rcx + 8*rax + 32]
        vaddpd  ymm2, ymm2, ymmword ptr [rcx + 8*rax + 64]
        vaddpd  ymm3, ymm3, ymmword ptr [rcx + 8*rax + 96]
        vaddpd  ymm1, ymm1, ymmword ptr [rcx + 8*rax + 128]
        vaddpd  ymm0, ymm0, ymmword ptr [rcx + 8*rax + 160]
        vaddpd  ymm2, ymm2, ymmword ptr [rcx + 8*rax + 192]
        vaddpd  ymm3, ymm3, ymmword ptr [rcx + 8*rax + 224]
        vaddpd  ymm1, ymm1, ymmword ptr [rcx + 8*rax + 256]
        vaddpd  ymm0, ymm0, ymmword ptr [rcx + 8*rax + 288]
        vaddpd  ymm2, ymm2, ymmword ptr [rcx + 8*rax + 320]
        vaddpd  ymm3, ymm3, ymmword ptr [rcx + 8*rax + 352]
        vaddpd  ymm1, ymm1, ymmword ptr [rcx + 8*rax + 384]
        vaddpd  ymm0, ymm0, ymmword ptr [rcx + 8*rax + 416]
        vaddpd  ymm2, ymm2, ymmword ptr [rcx + 8*rax + 448]
        vaddpd  ymm3, ymm3, ymmword ptr [rcx + 8*rax + 480]
        add     rax, 64
        add     r10, -4
        jne     L160
L305:
        test    r8, r8
        je      L372
        lea     r10, [rcx + 8*rax]
        add     r10, 96
        shl     r8, 7
        xor     eax, eax
        nop     word ptr cs:[rax + rax]
L336:
        vaddpd  ymm1, ymm1, ymmword ptr [r10 + rax - 96]
        vaddpd  ymm0, ymm0, ymmword ptr [r10 + rax - 64]
        vaddpd  ymm2, ymm2, ymmword ptr [r10 + rax - 32]
        vaddpd  ymm3, ymm3, ymmword ptr [r10 + rax]
        sub     rax, -128
        cmp     r8, rax
        jne     L336
L372:
        vaddpd  ymm0, ymm0, ymm1
        vaddpd  ymm0, ymm2, ymm0
        vaddpd  ymm0, ymm3, ymm0
        vextractf128    xmm1, ymm0, 1
        vaddpd  xmm0, xmm0, xmm1
        vpermilpd       xmm1, xmm0, 1           # xmm1 = xmm0[1,0]
        vaddsd  xmm0, xmm0, xmm1
        cmp     r9, rdx
        je      L429
        nop     dword ptr [rax]
L416:
        vaddsd  xmm0, xmm0, qword ptr [rcx + 8*rdx]
        inc     rdx
        cmp     r9, rdx
        jne     L416
L429:
        pop     rsi
        pop     rbp
        vzeroupper
        ret
        nop     word ptr cs:[rax + rax]

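Comparing the two listings: without `@simd` the loop is unrolled but still performs scalar additions (`vaddsd` on `xmm0`), because floating-point addition is not associative and the compiler is not allowed to reorder the sum on its own. With `@simd` the loop uses packed additions (`vaddpd` on 256-bit `ymm` registers) with four independent accumulators. This means that while the Julia compiler can vectorize some loops automatically, we still need to try `@simd` on reduction loops like this one to make full use of the hardware, wherever reordering the operations is acceptable.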