BlueBrain / nmodl

Code Generation Framework For NEURON MODeling Language
https://bluebrain.github.io/nmodl/
Apache License 2.0
55 stars 15 forks source link

Support SVE/SVE2 targets in LLVM code generation pipeline #637

Open georgemitenkov opened 3 years ago

georgemitenkov commented 3 years ago

Example of generating simple SVE: https://godbolt.org/z/jc8v8vbrx

To fully support SVE backends, there are (at least) three points that need to be taken care of:

Support scalable vector types

LLVM has a `FixedVectorType` and a `ScalableVectorType`. The former is intended for vectors whose width is known at compile time; for SVE, we need to use the latter, whose element count is a compile-time minimum multiplied by a runtime factor `vscale`.

Support induction variable increment for vectorised code

Since the vector width is not known at compile time, the induction variable increment has to be handled separately. LLVM provides the `@llvm.vscale.i64` intrinsic for this (https://llvm.org/docs/LangRef.html#llvm-vscale-intrinsic).

Support scalable vector constants

Since the width is not known at compile time, vector constants also have to be handled separately. Currently, this is done using shufflevector.

Some links: https://hps.vi4io.org/_media/events/2020/llvm-cth20_lovett.pdf

georgemitenkov commented 3 years ago

The `llvm.vscale` intrinsic returns the runtime value of `vscale` for scalable vector types such as `<vscale x 4 x i8>`. Hence, to get the number of elements, we need to multiply the result by the type's minimum element count (here 4).

georgemitenkov commented 3 years ago

Constants:

For constant vector values, we cannot specify all the elements as we can for
fixed-length vectors; fortunately only a small number of easily synthesized
patterns are required for autovectorization. The `zeroinitializer` constant
can be used as with fixed-length vectors for a constant zero splat. This can
then be combined with `insertelement` and `shufflevector` to create arbitrary
value splats in the same manner as fixed-length vectors.
castigli commented 3 years ago

As I mentioned earlier, I implemented a simple DAXPY kernel with the ACLE (Arm C Language Extensions) that might be useful as a reference for SVE assembly: https://godbolt.org/z/PMTP8c67q. You can also fix the vector length (VL) at compile time by adding the `-msve-vector-bits=<bits>` flag and uncommenting the typedefs.

// -march=armv8-a+sve -std=c++17 -O3 
// -msve-vector-bits=256

#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <limits>

#ifdef __ARM_FEATURE_SVE
#include <arm_acle.h>
#include <arm_sve.h>
#endif

// requires clang > 12.0.0 or gcc > 10
// requires -msve-vector-bits=<length>
// #if __ARM_FEATURE_SVE_BITS > 0
// typedef svfloat64_t vec __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS)));
// typedef svbool_t pred __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS)));
// #endif 

namespace scalar
{
/// Reference (scalar) DAXPY: ret[i] = a * x[i] + y[i] for every i in [0, size).
///
/// @param a     scalar multiplier applied to x
/// @param x     input array with at least `size` elements; only read
/// @param y     input array with at least `size` elements; only read
/// @param size  number of elements to process; 0 makes the call a no-op
/// @param ret   output array with at least `size` elements
///
/// x and y are taken as pointers-to-const: the kernel never writes them, and this
/// lets callers pass const data while remaining source-compatible with double*.
void daxpy(double a, const double* x, const double* y, std::size_t size, double* ret)
{
    for (std::size_t i = 0; i < size; ++i)
    {
        ret[i] = a * x[i] + y[i];
    }
}
}  // namespace scalar

#ifdef __ARM_FEATURE_SVE
namespace acle
{
/// SVE (ACLE) DAXPY: ret[i] = a * x[i] + y[i] for every i in [0, size).
///
/// Vector-length agnostic: each iteration advances by svcntd() elements (the
/// number of 64-bit lanes of the hardware vector). The tail is handled by the
/// svwhilelt predicate instead of a separate scalar loop. Lanes with index >= size
/// are inactive, so they are neither loaded nor stored.
///
/// Note: svmla_x computes y + x * a as a fused multiply-add (single rounding), so
/// results can differ in the last bit from an unfused a * x + y.
///
/// @param a     scalar multiplier applied to x
/// @param x     input array with at least `size` elements; only read
/// @param y     input array with at least `size` elements; only read
/// @param size  number of elements to process; 0 makes the call a no-op
/// @param ret   output array with at least `size` elements
void daxpy(double a, const double* x, const double* y, std::size_t size, double* ret)
{
    const svbool_t all = svptrue_b64();
    svbool_t pg;

    // Loop while the first lane of the current predicate is active, i.e. while i < size.
    for (std::size_t i = 0; svptest_first(all, pg = svwhilelt_b64(i, size)); i += svcntd())
    {
        // ret[i] = a * x[i] + y[i];
        const svfloat64_t x_vec = svld1(pg, &x[i]);
        const svfloat64_t y_vec = svld1(pg, &y[i]);
        svst1(pg, &ret[i], svmla_x(pg, y_vec, x_vec, a));
    }
}
}  // namespace acle
#endif

/// Cross-checks acle::daxpy against scalar::daxpy on a small input.
///
/// The check only runs when compiled with SVE enabled (__ARM_FEATURE_SVE).
/// The results are compared within a rounding-error bound rather than with ==.
/// The SVE kernel uses a fused multiply-add (one rounding), while the scalar kernel
/// may be compiled to a separate multiply and add (two roundings). Bit-exact
/// equality therefore only holds for special inputs such as the ones below.
int main()
{
    constexpr std::size_t size = 5;
    double a = 1;
    std::array<double, size> x = {1, 2, 3, 4, 5};
    std::array<double, size> y = {-2, -4, -6, -8, -10};

    std::array<double, size> res{};
    std::array<double, size> res_sve{};

#ifdef __ARM_FEATURE_SVE
    scalar::daxpy(a, x.data(), y.data(), size, res.data());
    acle::daxpy(a, x.data(), y.data(), size, res_sve.data());

    for (std::size_t i = 0; i < size; ++i)
    {
        // |fused - unfused| is bounded by a few ulps of the operand magnitudes.
        // Scale by |a*x| + |y| rather than by the result, so that cancellation
        // (a*x close to -y) does not make the bound too tight.
        const double magnitude = std::fabs(a * x[i]) + std::fabs(y[i]);
        const double tolerance = 4 * std::numeric_limits<double>::epsilon() * magnitude;
        assert(std::fabs(res[i] - res_sve[i]) <= tolerance);
    }
#endif
}

The interesting parts. Scalar version (actually scalar + NEON, since the compiler auto-vectorized the loop):

scalar::daxpy(double, double*, double*, unsigned long, double*):            // @scalar::daxpy(double, double*, double*, unsigned long, double*)
        // Arguments (AAPCS64): d0 = a, x0 = x, x1 = y, x2 = size, x3 = ret.
        // size == 0: nothing to do.
        cbz     x2, .LBB0_10
        // size < 4: too short for the 4-element NEON loop; run the scalar loop from i = 0.
        cmp     x2, #4                          // =4
        b.hs    .LBB0_3
        mov     x8, xzr
        b       .LBB0_8
.LBB0_3:
        // Runtime alias checks: if ret[0, size) overlaps x[0, size) (w11 & w12) or
        // y[0, size) (w9 & w10), fall back to the scalar loop from i = 0 (x8 = 0).
        lsl     x9, x2, #3
        add     x11, x0, x9
        add     x10, x3, x9
        cmp     x11, x3
        add     x9, x1, x9
        cset    w11, hi
        cmp     x10, x0
        cset    w12, hi
        cmp     x9, x3
        cset    w9, hi
        cmp     x10, x1
        mov     x8, xzr
        and     w11, w11, w12
        cset    w10, hi
        tbnz    w11, #0, .LBB0_8
        and     w9, w9, w10
        tbnz    w9, #0, .LBB0_8
        // Vector setup: x8 = size rounded down to a multiple of 4, v1 = {a, a},
        // and the pointers are biased by +16 bytes for the ldp/stp addressing below.
        and     x8, x2, #0xfffffffffffffffc
        dup     v1.2d, v0.d[0]
        add     x9, x3, #16                     // =16
        add     x10, x1, #16                    // =16
        add     x11, x0, #16                    // =16
        mov     x12, x8
.LBB0_6:                                // =>This Inner Loop Header: Depth=1
        // NEON body: 4 doubles per iteration (two 2-lane q registers), computed as a
        // separate fmul + fadd (two roundings, not fused).
        ldp     q2, q3, [x11, #-16]
        ldp     q4, q5, [x10, #-16]
        subs    x12, x12, #4                    // =4
        add     x10, x10, #32                   // =32
        fmul    v2.2d, v2.2d, v1.2d
        fmul    v3.2d, v3.2d, v1.2d
        fadd    v2.2d, v2.2d, v4.2d
        fadd    v3.2d, v3.2d, v5.2d
        stp     q2, q3, [x9, #-16]
        add     x9, x9, #32                     // =32
        add     x11, x11, #32                   // =32
        b.ne    .LBB0_6
        // If the vector loop covered every element (size % 4 == 0), return.
        cmp     x8, x2
        b.eq    .LBB0_10
.LBB0_8:
        // Scalar remainder setup: start at element x8; x9 = size - x8 iterations remain.
        lsl     x11, x8, #3
        sub     x9, x2, x8
        add     x8, x3, x11
        add     x10, x1, x11
        add     x11, x0, x11
.LBB0_9:                                // =>This Inner Loop Header: Depth=1
        // Scalar body: one double per iteration via post-incremented pointers.
        ldr     d1, [x11], #8
        ldr     d2, [x10], #8
        subs    x9, x9, #1                      // =1
        fmul    d1, d1, d0
        fadd    d1, d1, d2
        str     d1, [x8], #8
        b.ne    .LBB0_9
.LBB0_10:
        ret

SVE version:

acle::daxpy(double, double*, double*, unsigned long, double*):              // @acle::daxpy(double, double*, double*, unsigned long, double*)
        // Arguments (AAPCS64): d0 = a, x0 = x, x1 = y, x2 = size, x3 = ret.
        // p0 = lanes whose index (0 + lane) is < size; "mi" is set when the first lane is active.
        whilelo p0.d, xzr, x2
        cset    w8, mi
        cmp     w8, #1                          // =1
        // First lane inactive (size == 0): return without entering the loop.
        b.ne    .LBB1_3
        // Setup: x9 = doubles per SVE vector (cntd), x8 = byte offset 0,
        // z0 = a broadcast to every lane, x10 = bytes per vector (cntd * 8),
        // x11 = element index of the next vector.
        cntd    x9
        mov     x8, xzr
        mov     z0.d, d0
        cntd    x10, all, mul #8
        mov     x11, x9
.LBB1_2:                                // =>This Inner Loop Header: Depth=1
        // Predicated body: load x and y under p0 (inactive lanes zeroed, not read),
        // y += x * a with a fused fmla, then store to ret under p0.
        add     x12, x0, x8
        add     x13, x1, x8
        ld1d    { z1.d }, p0/z, [x12]
        ld1d    { z2.d }, p0/z, [x13]
        add     x12, x3, x8
        add     x8, x8, x10
        fmla    z2.d, p0/m, z1.d, z0.d
        st1d    { z2.d }, p0, [x12]
        // Predicate for the next vector; loop while its first lane is active.
        // No separate scalar tail loop is needed.
        whilelo p0.d, x11, x2
        add     x11, x11, x9
        b.mi    .LBB1_2
.LBB1_3:
        ret
georgemitenkov commented 3 years ago

@castigli Thank you for the reference! I have adapted the kernel generated by our fixed-width vector pipeline, and it seems that SVE code is generated (reading Arm assembly is not my strong suit, so it would be great if you could confirm!)

// kernel
VOID nrn_state_test(INSTANCE_STRUCT *mech){
    INTEGER id
    // Vectorised loop from the fixed-width pipeline: steps by 2 and runs while
    // id + 1 < node_count. Presumably each iteration copies elements id and id+1.
    for(id = 0; id<mech->node_count-1; id = id+2) {
        mech->m[id] = mech->x[id]
    }
    // Scalar remainder loop: continues from the id where the first loop stopped.
    for(; id<mech->node_count; id = id+1) {
        mech->m[id] = mech->x[id]
    }
}

Here I am trying to generate SVE instructions operating on `<vscale x 2 x double>` vectors, i.e. a multiple of 2 elements (link: https://godbolt.org/z/dWzrebbWx). I am also including the LLVM IR and assembly for clarity:

; ModuleID = 'test'
source_filename = "test"

%test__instance_var__type = type { double*, double*, double*, double*, double*, double*, i32*, double, double, double, i32, i32 }

; Copies field 0 into field 1 element-wise for indices [0, field 11).
;   for.cond / for.body / for.inc:    vector loop over <vscale x 2 x double>,
;                                     i.e. 2 * vscale doubles per iteration;
;   for.cond2 / for.body3 / for.inc4: scalar remainder loop, one double per iteration.
; Function Attrs: nofree nounwind
define void @nrn_state_test(%test__instance_var__type* noalias nocapture readonly %mech1) #0 {
  %mech = alloca %test__instance_var__type*, align 8
  %id = alloca i32, align 4
  store %test__instance_var__type* %mech1, %test__instance_var__type** %mech, align 8
  store i32 0, i32* %id, align 4
  br label %for.cond

for.cond:                                         ; preds = %for.inc, %0
  %1 = load %test__instance_var__type*, %test__instance_var__type** %mech, align 8
  %2 = getelementptr inbounds %test__instance_var__type, %test__instance_var__type* %1, i32 0, i32 11
  %3 = load i32, i32* %2, align 4

  ; handle scalable vector width: enter the vector body only while a whole vector
  ; fits, i.e. while id < n - (2 * vscale - 1). This matters because for.body
  ; loads and stores full, unpredicated vectors.
  %scale0 = call i32 @llvm.vscale.i32()
  %width0 = mul i32 %scale0, 2
  %bound = sub i32 %width0, 1
  %4 = sub i32 %3, %bound

  %5 = load i32, i32* %id, align 4
  %6 = icmp slt i32 %5, %4
  br i1 %6, label %for.body, label %for.exit

for.body:                                         ; preds = %for.cond
  %7 = load %test__instance_var__type*, %test__instance_var__type** %mech, align 8
  %8 = getelementptr inbounds %test__instance_var__type, %test__instance_var__type* %7, i32 0, i32 0
  %9 = load i32, i32* %id, align 4
  %10 = sext i32 %9 to i64
  %11 = load double*, double** %8, align 8
  %12 = getelementptr inbounds double, double* %11, i64 %10
  %13 = bitcast double* %12 to <vscale x 2 x double>*
  ; The vector accesses state the real alignment of the underlying double* (8).
  ; An omitted `align` means the type's ABI alignment, which may exceed what a
  ; double* guarantees.
  %14 = load <vscale x 2 x double>, <vscale x 2 x double>* %13, align 8
  %15 = load %test__instance_var__type*, %test__instance_var__type** %mech, align 8
  %16 = getelementptr inbounds %test__instance_var__type, %test__instance_var__type* %15, i32 0, i32 1
  %17 = load i32, i32* %id, align 4
  %18 = sext i32 %17 to i64
  %19 = load double*, double** %16, align 8
  %20 = getelementptr inbounds double, double* %19, i64 %18
  %21 = bitcast double* %20 to <vscale x 2 x double>*
  store <vscale x 2 x double> %14, <vscale x 2 x double>* %21, align 8
  br label %for.inc

for.inc:                                          ; preds = %for.body
  %22 = load i32, i32* %id, align 4

  ; handle scalable vector width: advance by the runtime element count, 2 * vscale
  %scale1 = call i32 @llvm.vscale.i32()
  %width1 = mul i32 %scale1, 2
  %23 = add i32 %22, %width1
  store i32 %23, i32* %id, align 4
  br label %for.cond

for.exit:                                         ; preds = %for.cond
  br label %for.cond2

for.cond2:                                        ; preds = %for.inc4, %for.exit
  %24 = load %test__instance_var__type*, %test__instance_var__type** %mech, align 8
  %25 = getelementptr inbounds %test__instance_var__type, %test__instance_var__type* %24, i32 0, i32 11
  %26 = load i32, i32* %25, align 4
  %27 = load i32, i32* %id, align 4
  %28 = icmp slt i32 %27, %26
  br i1 %28, label %for.body3, label %for.exit5

for.body3:                                        ; preds = %for.cond2
  %29 = load %test__instance_var__type*, %test__instance_var__type** %mech, align 8
  %30 = getelementptr inbounds %test__instance_var__type, %test__instance_var__type* %29, i32 0, i32 0
  %31 = load i32, i32* %id, align 4
  %32 = sext i32 %31 to i64
  %33 = load double*, double** %30, align 8
  %34 = getelementptr inbounds double, double* %33, i64 %32
  %35 = load double, double* %34, align 8
  %36 = load %test__instance_var__type*, %test__instance_var__type** %mech, align 8
  %37 = getelementptr inbounds %test__instance_var__type, %test__instance_var__type* %36, i32 0, i32 1
  %38 = load i32, i32* %id, align 4
  %39 = sext i32 %38 to i64
  %40 = load double*, double** %37, align 8
  %41 = getelementptr inbounds double, double* %40, i64 %39
  store double %35, double* %41, align 8
  br label %for.inc4

for.inc4:                                         ; preds = %for.body3
  %42 = load i32, i32* %id, align 4
  %43 = add i32 %42, 1
  store i32 %43, i32* %id, align 4
  br label %for.cond2

for.exit5:                                        ; preds = %for.cond2
  ret void
}

declare i32 @llvm.vscale.i32() #1

attributes #0 = { nofree nounwind "target-features"="+sve" }
attributes #1 = { nofree nosync nounwind readnone willreturn }
; llc -O3 -mtriple=arm64-apple-darwin
; %mech and %id still live on the stack (presumably because llc alone runs no
; mem2reg), hence the repeated ldr/str of [sp, #8] and [sp, #4].
_nrn_state_test:                        ; @nrn_state_test
        ; Stack slots: %mech at [sp, #8], %id at [sp, #4].
        sub     sp, sp, #16                     ; =16
        ; x8 = doubles per SVE vector (2 * vscale), x9 = -x8, p0 = all lanes active.
        cntd    x8
        neg     x9, x8
        ptrue   p0.d
        str     x0, [sp, #8]
        str     wzr, [sp, #4]
LBB0_1:                                 ; %for.cond
        ; Vector loop test: id < node_count - width + 1 (node_count = field 11 at offset 84).
        ; If that fails, including when the array is shorter than one vector,
        ; go to the scalar loop.
        ldr     x10, [sp, #8]
        ldr     w11, [sp, #4]
        ldr     w10, [x10, #84]
        add     w10, w10, w9
        add     w10, w10, #1                    ; =1
        cmp     w11, w10
        b.ge    LBB0_4
        ; Vector body: copy one full vector (all lanes, p0) from field 0 to field 1,
        ; then id += cntd.
        ldr     x10, [sp, #8]
        ldrsw   x11, [sp, #4]
        ldp     x12, x10, [x10]
        lsl     x11, x11, #3
        add     x12, x12, x11
        ld1d    { z0.d }, p0/z, [x12]
        add     x10, x10, x11
        st1d    { z0.d }, p0, [x10]
        ldr     w10, [sp, #4]
        add     w10, w10, w8
        str     w10, [sp, #4]
        b       LBB0_1
LBB0_3:                                 ; %for.body3
        ; Scalar remainder body: copy one double, id += 1, then fall through to the test.
        ldr     x8, [sp, #8]
        ldrsw   x9, [sp, #4]
        ldp     x10, x8, [x8]
        lsl     x9, x9, #3
        ldr     d0, [x10, x9]
        str     d0, [x8, x9]
        ldr     w8, [sp, #4]
        add     w8, w8, #1                      ; =1
        str     w8, [sp, #4]
LBB0_4:                                 ; %for.cond2
        ; Scalar remainder test: loop while id < node_count.
        ldr     x8, [sp, #8]
        ldr     w9, [sp, #4]
        ldr     w8, [x8, #84]
        cmp     w9, w8
        b.lt    LBB0_3
        add     sp, sp, #16                     ; =16
        ret
castigli commented 3 years ago

I read a bit more: the instruction that gets the number of (64-bit) elements in an SVE vector is

cntd    x8

so this suggests two things. First, the loop index, and hence the load base address, is incremented properly (one vector's worth of elements at a time). Second, if the array is smaller than the vector width, execution jumps directly to the scalar code.