JuliaLang / julia

The Julia Programming Language
https://julialang.org/
MIT License

LLVM vector intrinsic support for SVE? #40308

Open chriselrod opened 3 years ago

chriselrod commented 3 years ago

It'd be great to get good support for SVE, especially as SVE2 will become standard for ARMv9.

However, early tests using LLVM vector intrinsics on the A64FX did not go well. Here is a minimal example on Godbolt, showing a vectorized (but not unrolled) dot product on the A64FX, which has 512-bit vectors. The problem is that <8 x double> gets translated into 4x <2 x double> NEON instructions instead of an SVE instruction. v registers are NEON, and we see that the single @llvm.fma.v8f64 was broken up into 4 separate fmla instructions. Based on this document, SVE registers would be denoted by z[0-31].

This makes me wonder whether, to actually get intrinsic support for SVE, we'd need to use <vscale x 2 x double>, etc. instead. Scalable vectors aren't compelling in Julia the way they are in C/C++ (or wherever else folks distribute binaries), since we're probably compiling for the specific target machine anyway and can easily find the appropriate vector length using @llvm.vscale.i64.

Furthermore, we don't currently have any way to represent that: NTuple{L,Core.VecElement{T}} maps to <L x T>, but there is no vscale equivalent.

Anyone have any insight into/knowledge about this?
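The Godbolt example itself isn't reproduced here. As a rough Julia-level sketch of the same idea (not the original code; `vmuladd` and `dot8` are hypothetical names), a fixed-width dot product can pass NTuple{8, Core.VecElement{Float64}} values, which Julia lowers to <8 x double>, into Base.llvmcall. The `contract` flags permit LLVM to fuse the multiply and add into an FMA. Per the comment above, on the A64FX without extra flags such a fixed-width vector would be expected to end up as NEON instructions rather than SVE.

```julia
# Julia's fixed-width SIMD representation: NTuple{8, VecElement{Float64}} <-> <8 x double>
const Vec8 = NTuple{8, Core.VecElement{Float64}}

# acc + a * b on <8 x double>; `contract` lets LLVM fuse fmul+fadd into an fma.
@inline vmuladd(acc::Vec8, a::Vec8, b::Vec8) = Base.llvmcall("""
    %p = fmul contract <8 x double> %1, %2
    %r = fadd contract <8 x double> %0, %p
    ret <8 x double> %r
    """, Vec8, Tuple{Vec8, Vec8, Vec8}, acc, a, b)

# Hypothetical sketch: vectorized but not unrolled (one accumulator).
function dot8(x::Vector{Float64}, y::Vector{Float64})
    n = length(x)
    length(y) == n || throw(DimensionMismatch("x and y must have the same length"))
    n % 8 == 0 || throw(ArgumentError("this sketch requires length(x) to be a multiple of 8"))
    acc = ntuple(_ -> Core.VecElement(0.0), Val(8))
    @inbounds for i in 1:8:n
        # Gather 8 consecutive elements into one <8 x double> value each.
        a = ntuple(j -> Core.VecElement(x[i + j - 1]), Val(8))
        b = ntuple(j -> Core.VecElement(y[i + j - 1]), Val(8))
        acc = vmuladd(acc, a, b)
    end
    # Horizontal reduction of the 8 lanes.
    return sum(v -> v.value, acc)
end
```

Running `@code_native dot8(x, y)` on the target is one way to check whether v (NEON) or z (SVE) registers show up in the loop.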

yuyichao commented 3 years ago

The first part is an LLVM issue and should be discussed there. I feel like it's illegal, though. And the biggest obstacle to representing scalable vectors in Julia code is that the size is unknown. There are a lot of places that assume the size of a type is known at compile time, or else it'll be heap allocated. That's the main reason I did not implement it 4-5 years ago.
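A small illustration of that constraint (not from the thread): Julia's fixed-width SIMD type has a size known from the type alone, so its values are isbits and can be stored inline in arrays, whereas a <vscale x 2 x double> has no compile-time size at all. `vec_type` is a hypothetical helper.

```julia
# Fixed-width SIMD type: NTuple{L, Core.VecElement{T}} <-> <L x T>
vec_type(L, T) = NTuple{L, Core.VecElement{T}}

for (L, T) in ((2, Float64), (8, Float64), (16, Float32))
    V = vec_type(L, T)
    # Size and isbits-ness follow from the type alone.
    println(V, ": sizeof = ", sizeof(V), " bytes, isbits = ", isbitstype(V))
end

# Arrays of such vectors store their elements inline:
# total bytes = length × sizeof(eltype).
buf = Vector{vec_type(8, Float64)}(undef, 4)
println("4 × <8 x double>: ", sizeof(buf), " bytes")
```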

chriselrod commented 3 years ago

The first part is an LLVM issue and should be discussed there. I feel like it's illegal, though.

Would the llvm-dev mailing list be the best place to ask?

And the biggest obstacle to representing scalable vectors in Julia code is that the size is unknown. There are a lot of places that assume the size of a type is known at compile time, or else it'll be heap allocated. That's the main reason I did not implement it 4-5 years ago.

Yeah, I think it'd need special handling / a representation like it has in LLVM IR.

chriselrod commented 3 years ago

https://lists.llvm.org/pipermail/llvm-dev/2021-April/149612.html suggested specifying -aarch64-sve-vector-bits-min=, which does work: https://godbolt.org/z/Mo76oWanW Aside from the option not yet existing in LLVM 11, we'd need a way to set this correctly (and preferably automatically) when starting Julia. This sounds (again) like something best handled on the LLVM side of things, e.g. a -aarch64-sve-vector-bits=native option. Although were something like that to be added, it wouldn't be available until at least LLVM 13.

giordano commented 3 years ago

After playing a bit with Chris on an A64FX cluster with Julia v1.7-beta3 (which comes with LLVM 12):

$ julia -q
julia> using BenchmarkTools

julia> function sumsimd(x)
           s = zero(eltype(x))
           @simd for xi in x
               s += xi
           end
           s
       end
sumsimd (generic function with 1 method)

julia> @btime sumsimd(x) setup=(x = rand(1_000_000))
  643.256 μs (0 allocations: 0 bytes)
500273.11451950937

julia> 
$ JULIA_LLVM_ARGS="-aarch64-sve-vector-bits-min=512" julia -q
julia> using BenchmarkTools

julia> function sumsimd(x)
           s = zero(eltype(x))
           @simd for xi in x
               s += xi
           end
           s
       end
sumsimd (generic function with 1 method)

julia> @btime sumsimd(x) setup=(x = rand(1_000_000))
  185.212 μs (0 allocations: 0 bytes)
500240.6910755522

It's sufficient to set JULIA_LLVM_ARGS appropriately to get a nice 3.5x boost.
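A back-of-the-envelope check, not part of the original comment: assuming the default build runs this loop on 128-bit NEON registers (as described for fixed-width vectors at the top of the issue), switching to the A64FX's 512-bit SVE registers quadruples the Float64 lanes per instruction. `lanes` is a hypothetical helper; the timings are the ones reported above.

```julia
# Number of lanes of element type T in a vector register `bits` wide.
lanes(bits::Integer, ::Type{T}) where {T} = bits ÷ (8 * sizeof(T))

neon_lanes = lanes(128, Float64)  # NEON registers are 128 bits wide
sve_lanes  = lanes(512, Float64)  # A64FX SVE registers are 512 bits wide

# Timings reported above for sumsimd on 1_000_000 Float64s, in μs.
t_default = 643.256   # without JULIA_LLVM_ARGS
t_sve     = 185.212   # with -aarch64-sve-vector-bits-min=512

println("lane ratio (SVE / NEON): ", sve_lanes / neon_lanes)
println("measured speedup:        ", round(t_default / t_sve; digits = 2))
```

So the measured ~3.5x is somewhat below the ideal 4x lane ratio.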

milankl commented 2 years ago

@giordano Have you also checked what happens if you use Float32 or Float16? i.e.

julia> @btime sumsimd(x) setup=(x = rand(Float32, 1_000_000))
julia> @btime sumsimd(x) setup=(x = rand(Float16, 1_000_000))

where, for the latter, I assume Float16 support must be enabled (#40216), which I believe is only the default from 1.8 onwards.

giordano commented 2 years ago

Yep, I did in https://github.com/UoB-HPC/BabelStream/pull/106#discussion_r697861796:

$ JULIA_LLVM_ARGS="-aarch64-sve-vector-bits-min=512" julia -q
julia> using BenchmarkTools

julia> function sumsimd(x)
           s = zero(eltype(x))
           @simd for xi in x
               s += xi
           end
           s
       end
sumsimd (generic function with 1 method)

julia> @btime sumsimd(x) setup=(x = randn(Float64, 1_000_000))
  191.912 μs (0 allocations: 0 bytes)
1853.0335322487956

julia> @btime sumsimd(x) setup=(x = randn(Float32, 1_000_000))
  80.330 μs (0 allocations: 0 bytes)
400.9806f0

julia> @btime sumsimd(x) setup=(x = randn(Float16, 1_000_000))
  42.761 μs (0 allocations: 0 bytes)
Float16(1.872e3)
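Normalizing these timings by the number of bytes read gives an effective throughput per element type. This is plain arithmetic on the numbers above (1,000,000 elements each); the helper names are hypothetical, and the figures alone don't say where the bottleneck lies.

```julia
# Effective throughput of a reduction over n elements of type T taking t_us μs.
effective_gbps(::Type{T}, n, t_us) where {T} = n * sizeof(T) / (t_us * 1e-6) / 1e9
elements_per_ns(n, t_us) = n / (t_us * 1e3)

n = 1_000_000
# (element type, time in μs) as reported above with -aarch64-sve-vector-bits-min=512
results = ((Float64, 191.912), (Float32, 80.330), (Float16, 42.761))

for (T, t_us) in results
    println(rpad(string(T), 8),
            round(effective_gbps(T, n, t_us); digits = 1), " GB/s, ",
            round(elements_per_ns(n, t_us); digits = 1), " elements/ns")
end
```

With these timings, all three types land between roughly 42 and 50 GB/s of input read.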
giordano commented 2 years ago

@chriselrod

However, early tests using LLVM vector intrinsics on the A64FX did not go well. Here is a minimal example on Godbolt, showing a vectorized (but not unrolled) dot product on the A64FX, which has 512-bit vectors. The problem is that <8 x double> gets translated into 4x <2 x double> NEON instructions instead of an SVE instruction. v registers are NEON, and we see that the single @llvm.fma.v8f64 was broken up into 4 separate fmla instructions. Based on this document, SVE registers would be denoted by z[0-31].

Using llc trunk (though v13 should already be enough) with -march=aarch64 -mcpu=a64fx -aarch64-sve-vector-bits-min=512: https://godbolt.org/z/ovxhr933G. I think it looks much better?

chriselrod commented 2 years ago
.LBB0_2: // %L34
  ld1d { z1.d }, p0/z, [x13]
  ld1d { z2.d }, p0/z, [x14]
  add x10, x10, #8
  add x14, x14, #64
  add x13, x13, #64
  fmla z0.d, p0/m, z1.d, z2.d
  cmp x10, x12
  b.le .LBB0_2

Yeah, that looks good.

giordano commented 2 years ago

The good news is that Julia nightly

Julia Version 1.9.0-DEV.809
Commit 9b83dd8920 (2022-06-19 19:31 UTC)
Platform Info:
  OS: Linux (aarch64-unknown-linux-gnu)
  CPU: 48 × unknown
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-14.0.3 (ORCJIT, a64fx)
  Threads: 1 on 48 virtual cores

generates the code (see below) for the sumsimd function above with the new SVE registers, without having to use JULIA_LLVM_ARGS="-aarch64-sve-vector-bits-min=512" (an option that is a massive pain, because you run into all sorts of crashes every now and then, like #44401 and #44263).

julia> @code_llvm debuginfo=:none sumsimd(randn(Float64, 1_000_000))
define double @julia_sumsimd_827({}* nonnull align 16 dereferenceable(40) %0) #0 {
top:
  %1 = bitcast {}* %0 to { i8*, i64, i16, i16, i32 }*
  %2 = getelementptr inbounds { i8*, i64, i16, i16, i32 }, { i8*, i64, i16, i16, i32 }* %1, i64 0, i32 1
  %3 = load i64, i64* %2, align 8
  %.not = icmp eq i64 %3, 0
  br i1 %.not, label %L17, label %L10.lr.ph

L10.lr.ph:                                        ; preds = %top
  %4 = bitcast {}* %0 to double**
  %5 = load double*, double** %4, align 8
  %6 = call i64 @llvm.vscale.i64()
  %7 = shl i64 %6, 3
  %min.iters.check = icmp ult i64 %3, %7
  br i1 %min.iters.check, label %scalar.ph, label %vector.ph

vector.ph:                                        ; preds = %L10.lr.ph
  %n.mod.vf = urem i64 %3, %7
  %n.vec = sub nsw i64 %3, %n.mod.vf
  br label %vector.body

vector.body:                                      ; preds = %vector.body, %vector.ph
  %index = phi i64 [ 0, %vector.ph ], [ %index.next, %vector.body ]
  %vec.phi = phi <vscale x 2 x double> [ insertelement (<vscale x 2 x double> zeroinitializer, double 0.000000e+00, i32 0), %vector.ph ], [ %23, %vector.body ]
  %vec.phi9 = phi <vscale x 2 x double> [ zeroinitializer, %vector.ph ], [ %24, %vector.body ]
  %vec.phi10 = phi <vscale x 2 x double> [ zeroinitializer, %vector.ph ], [ %25, %vector.body ]
  %vec.phi11 = phi <vscale x 2 x double> [ zeroinitializer, %vector.ph ], [ %26, %vector.body ]
  %8 = getelementptr inbounds double, double* %5, i64 %index
  %9 = bitcast double* %8 to <vscale x 2 x double>*
  %wide.load = load <vscale x 2 x double>, <vscale x 2 x double>* %9, align 8
  %10 = call i32 @llvm.vscale.i32()
  %11 = shl i32 %10, 1
  %12 = sext i32 %11 to i64
  %13 = getelementptr inbounds double, double* %8, i64 %12
  %14 = bitcast double* %13 to <vscale x 2 x double>*
  %wide.load12 = load <vscale x 2 x double>, <vscale x 2 x double>* %14, align 8
  %15 = shl i32 %10, 2
  %16 = sext i32 %15 to i64
  %17 = getelementptr inbounds double, double* %8, i64 %16
  %18 = bitcast double* %17 to <vscale x 2 x double>*
  %wide.load13 = load <vscale x 2 x double>, <vscale x 2 x double>* %18, align 8
  %19 = mul i32 %10, 6
  %20 = sext i32 %19 to i64
  %21 = getelementptr inbounds double, double* %8, i64 %20
  %22 = bitcast double* %21 to <vscale x 2 x double>*
  %wide.load14 = load <vscale x 2 x double>, <vscale x 2 x double>* %22, align 8
  %23 = fadd fast <vscale x 2 x double> %vec.phi, %wide.load
  %24 = fadd fast <vscale x 2 x double> %vec.phi9, %wide.load12
  %25 = fadd fast <vscale x 2 x double> %vec.phi10, %wide.load13
  %26 = fadd fast <vscale x 2 x double> %vec.phi11, %wide.load14
  %index.next = add nuw i64 %index, %7
  %27 = icmp eq i64 %index.next, %n.vec
  br i1 %27, label %middle.block, label %vector.body

middle.block:                                     ; preds = %vector.body
  %bin.rdx = fadd fast <vscale x 2 x double> %24, %23
  %bin.rdx15 = fadd fast <vscale x 2 x double> %25, %bin.rdx
  %bin.rdx16 = fadd fast <vscale x 2 x double> %26, %bin.rdx15
  %28 = call fast double @llvm.vector.reduce.fadd.nxv2f64(double -0.000000e+00, <vscale x 2 x double> %bin.rdx16)
  %cmp.n = icmp eq i64 %n.mod.vf, 0
  br i1 %cmp.n, label %L17, label %scalar.ph

scalar.ph:                                        ; preds = %middle.block, %L10.lr.ph
  %bc.resume.val = phi i64 [ %n.vec, %middle.block ], [ 0, %L10.lr.ph ]
  %bc.merge.rdx = phi double [ %28, %middle.block ], [ 0.000000e+00, %L10.lr.ph ]
  br label %L10

L10:                                              ; preds = %L10, %scalar.ph
  %value_phi18 = phi i64 [ %bc.resume.val, %scalar.ph ], [ %32, %L10 ]
  %value_phi7 = phi double [ %bc.merge.rdx, %scalar.ph ], [ %31, %L10 ]
  %29 = getelementptr inbounds double, double* %5, i64 %value_phi18
  %30 = load double, double* %29, align 8
  %31 = fadd fast double %value_phi7, %30
  %32 = add nuw nsw i64 %value_phi18, 1
  %exitcond.not = icmp eq i64 %32, %3
  br i1 %exitcond.not, label %L17, label %L10

L17:                                              ; preds = %L10, %middle.block, %top
  %value_phi2 = phi double [ 0.000000e+00, %top ], [ %28, %middle.block ], [ %31, %L10 ]
  ret double %value_phi2
}
julia> @code_native debuginfo=:none sumsimd(randn(Float64, 1_000_000))
        .text
        .file   "sumsimd"
        .globl  julia_sumsimd_831               // -- Begin function julia_sumsimd_831
        .p2align        3
        .type   julia_sumsimd_831,@function
julia_sumsimd_831:                      // @julia_sumsimd_831
        .cfi_startproc
// %bb.0:                               // %top
        ldr     x8, [x0, #8]
        cbz     x8, .LBB0_3
// %bb.1:                               // %L10.lr.ph
        ldr     x9, [x0]
        cnth    x11
        cmp     x8, x11
        b.hs    .LBB0_4
// %bb.2:
        mov     x10, xzr
        movi    d0, #0000000000000000
        b       .LBB0_7
.LBB0_3:
        movi    d0, #0000000000000000
                                        // kill: def $d0 killed $d0 killed $z0
        ret
.LBB0_4:                                // %vector.ph
        udiv    x10, x8, x11
        movi    d1, #0000000000000000
        mov     z0.d, #0                        // =0x0
        ptrue   p0.d, vl1
        cntd    x13
        cntw    x15
        cntd    x16, all, mul #3
        mov     x12, xzr
        add     x14, x9, w13, sxtw #3
        add     x15, x9, w15, sxtw #3
        add     x16, x9, w16, sxtw #3
        sel     z1.d, p0, z1.d, z0.d
        mov     z2.d, z0.d
        mov     z3.d, z0.d
        ptrue   p0.d
        mul     x10, x10, x11
        sub     x13, x8, x10
        .p2align        2
.LBB0_5:                                // %vector.body
                                        // =>This Inner Loop Header: Depth=1
        ld1d    { z4.d }, p0/z, [x9, x12, lsl #3]
        ld1d    { z5.d }, p0/z, [x14, x12, lsl #3]
        fadd    z1.d, z1.d, z4.d
        ld1d    { z6.d }, p0/z, [x15, x12, lsl #3]
        ld1d    { z7.d }, p0/z, [x16, x12, lsl #3]
        fadd    z0.d, z0.d, z5.d
        fadd    z2.d, z2.d, z6.d
        fadd    z3.d, z3.d, z7.d
        add     x12, x12, x11
        cmp     x12, x10
        b.ne    .LBB0_5
// %bb.6:                               // %middle.block
        fadd    z0.d, z0.d, z1.d
        fadd    z0.d, z2.d, z0.d
        fadd    z0.d, z3.d, z0.d
        faddv   d0, p0, z0.d
        cbz     x13, .LBB0_9
.LBB0_7:                                // %L10.preheader
        sub     x8, x8, x10
        add     x9, x9, x10, lsl #3
        .p2align        2
.LBB0_8:                                // %L10
                                        // =>This Inner Loop Header: Depth=1
        ldr     d1, [x9], #8
        fadd    d0, d0, d1
        subs    x8, x8, #1
        b.ne    .LBB0_8
.LBB0_9:                                // %L17
                                        // kill: def $d0 killed $d0 killed $z0
        ret
.Lfunc_end0:
        .size   julia_sumsimd_831, .Lfunc_end0-julia_sumsimd_831
        .cfi_endproc
                                        // -- End function
        .section        ".note.GNU-stack","",@progbits
julia> @code_llvm debuginfo=:none sumsimd(randn(Float16, 1_000_000))
define half @julia_sumsimd_840({}* nonnull align 16 dereferenceable(40) %0) #0 {
top:
  %1 = bitcast {}* %0 to { i8*, i64, i16, i16, i32 }*
  %2 = getelementptr inbounds { i8*, i64, i16, i16, i32 }, { i8*, i64, i16, i16, i32 }* %1, i64 0, i32 1
  %3 = load i64, i64* %2, align 8
  %.not = icmp eq i64 %3, 0
  br i1 %.not, label %L17, label %L10.lr.ph

L10.lr.ph:                                        ; preds = %top
  %4 = bitcast {}* %0 to half**
  %5 = load half*, half** %4, align 8
  %6 = call i64 @llvm.vscale.i64()
  %7 = shl i64 %6, 5
  %min.iters.check = icmp ult i64 %3, %7
  br i1 %min.iters.check, label %scalar.ph, label %vector.ph

vector.ph:                                        ; preds = %L10.lr.ph
  %n.mod.vf = urem i64 %3, %7
  %n.vec = sub nsw i64 %3, %n.mod.vf
  br label %vector.body

vector.body:                                      ; preds = %vector.body, %vector.ph
  %index = phi i64 [ 0, %vector.ph ], [ %index.next, %vector.body ]
  %vec.phi = phi <vscale x 8 x half> [ insertelement (<vscale x 8 x half> zeroinitializer, half 0xH0000, i32 0), %vector.ph ], [ %23, %vector.body ]
  %vec.phi9 = phi <vscale x 8 x half> [ zeroinitializer, %vector.ph ], [ %24, %vector.body ]
  %vec.phi10 = phi <vscale x 8 x half> [ zeroinitializer, %vector.ph ], [ %25, %vector.body ]
  %vec.phi11 = phi <vscale x 8 x half> [ zeroinitializer, %vector.ph ], [ %26, %vector.body ]
  %8 = getelementptr inbounds half, half* %5, i64 %index
  %9 = bitcast half* %8 to <vscale x 8 x half>*
  %wide.load = load <vscale x 8 x half>, <vscale x 8 x half>* %9, align 2
  %10 = call i32 @llvm.vscale.i32()
  %11 = shl i32 %10, 3
  %12 = sext i32 %11 to i64
  %13 = getelementptr inbounds half, half* %8, i64 %12
  %14 = bitcast half* %13 to <vscale x 8 x half>*
  %wide.load12 = load <vscale x 8 x half>, <vscale x 8 x half>* %14, align 2
  %15 = shl i32 %10, 4
  %16 = sext i32 %15 to i64
  %17 = getelementptr inbounds half, half* %8, i64 %16
  %18 = bitcast half* %17 to <vscale x 8 x half>*
  %wide.load13 = load <vscale x 8 x half>, <vscale x 8 x half>* %18, align 2
  %19 = mul i32 %10, 24
  %20 = sext i32 %19 to i64
  %21 = getelementptr inbounds half, half* %8, i64 %20
  %22 = bitcast half* %21 to <vscale x 8 x half>*
  %wide.load14 = load <vscale x 8 x half>, <vscale x 8 x half>* %22, align 2
  %23 = fadd fast <vscale x 8 x half> %vec.phi, %wide.load
  %24 = fadd fast <vscale x 8 x half> %vec.phi9, %wide.load12
  %25 = fadd fast <vscale x 8 x half> %vec.phi10, %wide.load13
  %26 = fadd fast <vscale x 8 x half> %vec.phi11, %wide.load14
  %index.next = add nuw i64 %index, %7
  %27 = icmp eq i64 %index.next, %n.vec
  br i1 %27, label %middle.block, label %vector.body

middle.block:                                     ; preds = %vector.body
  %bin.rdx = fadd fast <vscale x 8 x half> %24, %23
  %bin.rdx15 = fadd fast <vscale x 8 x half> %25, %bin.rdx
  %bin.rdx16 = fadd fast <vscale x 8 x half> %26, %bin.rdx15
  %28 = call fast half @llvm.vector.reduce.fadd.nxv8f16(half 0xH8000, <vscale x 8 x half> %bin.rdx16)
  %cmp.n = icmp eq i64 %n.mod.vf, 0
  br i1 %cmp.n, label %L17, label %scalar.ph

scalar.ph:                                        ; preds = %middle.block, %L10.lr.ph
  %bc.resume.val = phi i64 [ %n.vec, %middle.block ], [ 0, %L10.lr.ph ]
  %bc.merge.rdx = phi half [ %28, %middle.block ], [ 0xH0000, %L10.lr.ph ]
  br label %L10

L10:                                              ; preds = %L10, %scalar.ph
  %value_phi18 = phi i64 [ %bc.resume.val, %scalar.ph ], [ %32, %L10 ]
  %value_phi7 = phi half [ %bc.merge.rdx, %scalar.ph ], [ %31, %L10 ]
  %29 = getelementptr inbounds half, half* %5, i64 %value_phi18
  %30 = load half, half* %29, align 2
  %31 = fadd fast half %value_phi7, %30
  %32 = add nuw nsw i64 %value_phi18, 1
  %exitcond.not = icmp eq i64 %32, %3
  br i1 %exitcond.not, label %L17, label %L10

L17:                                              ; preds = %L10, %middle.block, %top
  %value_phi2 = phi half [ 0xH0000, %top ], [ %28, %middle.block ], [ %31, %L10 ]
  ret half %value_phi2
}
julia> @code_native debuginfo=:none sumsimd(randn(Float16, 1_000_000))
        .text
        .file   "sumsimd"
        .globl  julia_sumsimd_842               // -- Begin function julia_sumsimd_842
        .p2align        3
        .type   julia_sumsimd_842,@function
julia_sumsimd_842:                      // @julia_sumsimd_842
        .cfi_startproc
// %bb.0:                               // %top
        ldr     x8, [x0, #8]
        cbz     x8, .LBB0_3
// %bb.1:                               // %L10.lr.ph
        ldr     x9, [x0]
        rdvl    x11, #2
        cmp     x8, x11
        b.hs    .LBB0_4
// %bb.2:
        mov     x10, xzr
        movi    d0, #0000000000000000
        b       .LBB0_7
.LBB0_3:
        movi    d0, #0000000000000000
                                        // kill: def $h0 killed $h0 killed $z0
        ret
.LBB0_4:                                // %vector.ph
        udiv    x10, x8, x11
        movi    d1, #0000000000000000
        mov     z0.h, #0                        // =0x0
        ptrue   p0.h, vl1
        cnth    x13
        rdvl    x15, #1
        cnth    x16, all, mul #3
        mov     x12, xzr
        add     x14, x9, w13, sxtw #1
        add     x15, x9, w15, sxtw #1
        add     x16, x9, w16, sxtw #1
        sel     z1.h, p0, z1.h, z0.h
        mov     z2.d, z0.d
        mov     z3.d, z0.d
        ptrue   p0.h
        mul     x10, x10, x11
        sub     x13, x8, x10
        .p2align        2
.LBB0_5:                                // %vector.body
                                        // =>This Inner Loop Header: Depth=1
        ld1h    { z4.h }, p0/z, [x9, x12, lsl #1]
        ld1h    { z5.h }, p0/z, [x14, x12, lsl #1]
        fadd    z1.h, z1.h, z4.h
        ld1h    { z6.h }, p0/z, [x15, x12, lsl #1]
        ld1h    { z7.h }, p0/z, [x16, x12, lsl #1]
        fadd    z0.h, z0.h, z5.h
        fadd    z2.h, z2.h, z6.h
        fadd    z3.h, z3.h, z7.h
        add     x12, x12, x11
        cmp     x12, x10
        b.ne    .LBB0_5
// %bb.6:                               // %middle.block
        fadd    z0.h, z0.h, z1.h
        fadd    z0.h, z2.h, z0.h
        fadd    z0.h, z3.h, z0.h
        faddv   h0, p0, z0.h
        cbz     x13, .LBB0_9
.LBB0_7:                                // %L10.preheader
        sub     x8, x8, x10
        add     x9, x9, x10, lsl #1
        .p2align        2
.LBB0_8:                                // %L10
                                        // =>This Inner Loop Header: Depth=1
        ldr     h1, [x9], #2
        fadd    h0, h0, h1
        subs    x8, x8, #1
        b.ne    .LBB0_8
.LBB0_9:                                // %L17
                                        // kill: def $h0 killed $h0 killed $z0
        ret
.Lfunc_end0:
        .size   julia_sumsimd_842, .Lfunc_end0-julia_sumsimd_842
        .cfi_endproc
                                        // -- End function
        .section        ".note.GNU-stack","",@progbits

Performance is exactly the same with and without JULIA_LLVM_ARGS="--aarch64-sve-vector-bits-min=512". This is a huge improvement in terms of usability (until I run into exciting new crashes).
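A worked check of the loop structure in the IR and assembly above (my reading, not stated in the thread): on AArch64, LLVM's vscale counts 128-bit SVE granules, so with the A64FX's 512-bit vectors vscale = 4. With the four accumulators the vectorizer chose, the Float64 loop should consume 4 × 2 × 4 = 32 elements per iteration (`shl i64 %6, 3` in the IR, `cnth x11` in the assembly) and the Float16 loop 4 × 8 × 4 = 128 (`shl i64 %6, 5`, `rdvl x11, #2`). The helper names are hypothetical.

```julia
# SVE vector lengths are multiples of 128 bits; on AArch64, LLVM's vscale
# is the number of these 128-bit granules.
const GRANULE_BITS = 128

vscale(vl_bits::Integer) = vl_bits ÷ GRANULE_BITS

# Lanes of T per granule: the N in <vscale x N x T> (2 for double, 8 for half).
lanes_per_granule(::Type{T}) where {T} = GRANULE_BITS ÷ (8 * sizeof(T))

# Elements consumed per trip through vector.body, assuming the 4-way
# interleave (four accumulators) seen in the IR.
elements_per_iter(::Type{T}, vl_bits; interleave = 4) where {T} =
    interleave * lanes_per_granule(T) * vscale(vl_bits)

for T in (Float64, Float16)
    println(T, ": vscale = ", vscale(512),
            ", <vscale x ", lanes_per_granule(T), " x ", T, ">",
            ", elements per iteration = ", elements_per_iter(T, 512))
end
```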

gbaraldi commented 2 years ago

vscale registers, fancy.

vchuravy commented 2 years ago

Nice to see that LLVM 14 is good for something!

gbaraldi commented 2 years ago

Now we just need to emit Float16 instructions outside of fastmath