halide / Halide

a language for fast, portable data-parallel computation
https://halide-lang.org

Code generation for interleaves of broadcasts #5069

Open joesavage opened 4 years ago

joesavage commented 4 years ago

I've been doing some prototyping recently around Halide GEMM kernels for AArch64, and in doing so, seem to have uncovered some code generation issues around interleaves of broadcasts.

To illustrate this, I've put together the following test case:

#include "Halide.h"

using namespace Halide;

int main(int argc, char **argv)
{
    Func AB("AB");
    Var y("y"), x("x");

    // Declare arguments
    ImageParam A_in(type_of<uint8_t>(), 2, "A_in");
    ImageParam B_in(type_of<uint8_t>(), 2, "B_in");
    std::vector<Argument> args(2);
    args[0] = A_in;
    args[1] = B_in;

    // Calculate
    RDom rv(0, 4);
    AB(x, y) += cast(UInt(32), A_in(rv, y)) * B_in(rv, x);

    // Set constraints
    OutputImageParam output = AB.output_buffer();
    output.dim(0).set_bounds(0, 4);
    A_in.dim(1).set_min(0);
    B_in.dim(1).set_min(0);

    // Schedule
    RVar fused("fused");
    AB.update(0).reorder(rv, x, y).fuse(rv, x, fused);
    AB.update(0).atomic().vectorize(fused);

    // Compile for bare metal AArch64
    Target target(Target::NoOS, Target::ARM, 64);
    std::vector<Target::Feature> arm_features;
    arm_features.push_back(Target::ARMDotProd);
    target.set_features(arm_features);
    AB.compile_to_file("AB", args, "", target);

    return 0;
}

If we compile this with the atomic_vectorization branch from #4628, we end up with something like the following Halide IR in the inner loop:

AB[ramp(t22, 1, 4)] += (uint32x4)vector_reduce(Add, (uint32x16(interleave_vectors(x4((uint8)A_in[t21]),
                                                                                  x4((uint8)A_in[t21 + 1]),
                                                                                  x4((uint8)A_in[t21 + 2]),
                                                                                  x4((uint8)A_in[t21 + 3]))) *
                                                     uint32x16(interleave_vectors((uint8x4)B_in[ramp(t25, B_in.stride.1, 4)],
                                                                                  (uint8x4)B_in[ramp(t26, B_in.stride.1, 4)],
                                                                                  (uint8x4)B_in[ramp(t27, B_in.stride.1, 4)],
                                                                                  (uint8x4)B_in[ramp(t28, B_in.stride.1, 4)]))))
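For readers less familiar with vector_reduce, here is a rough scalar model of what that statement computes. It is only a sketch in plain C++ (no Halide), and it assumes that vector_reduce sums each group of four adjacent lanes into one output lane, which is how I read the IR above. The function name multiply_and_reduce is made up for illustration.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

// Scalar model of the inner-loop statement: widen both 16-lane uint8
// operands to 32 bits, multiply lane-wise, then add each group of four
// adjacent products into the matching lane of the accumulator.
// Assumption: lanes 4*x .. 4*x+3 feed output lane x.
std::array<uint32_t, 4> multiply_and_reduce(const std::array<uint8_t, 16> &a,
                                            const std::array<uint8_t, 16> &b,
                                            std::array<uint32_t, 4> acc) {
    for (std::size_t lane = 0; lane < 16; ++lane) {
        acc[lane / 4] += static_cast<uint32_t>(a[lane]) * static_cast<uint32_t>(b[lane]);
    }
    return acc;
}
```

Each output lane is a four-element dot product of uint8 values accumulated into a uint32, which is the shape of operation the ARMDotProd feature is meant to cover.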

Zooming in on the first input to the multiply, with elements loaded from A_in, we see that the interleave here actually represents a repeating pattern. We might intuitively think about this as a load from A_in at index x4(ramp(t21, 1, 4)), where the resultant vector contains the pattern A_in[t21], A_in[t21 + 1], A_in[t21 + 2], A_in[t21 + 3] repeated four times. Since Halide doesn't yet support generating multi-dimensional broadcasts like this, however, we're stuck with this slightly odd interleave-of-broadcasts representation.
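To make the equivalence concrete, the following is a small sketch in plain C++ that models vectors as std::vector<uint8_t>. The helpers broadcast, interleave, and repeat are illustrative stand-ins written for this example, not Halide API calls.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Vec = std::vector<uint8_t>;

// n copies of the scalar v (printed as xN(v) in the IR above).
Vec broadcast(uint8_t v, std::size_t n) { return Vec(n, v); }

// Lane 0 of every input, then lane 1 of every input, and so on.
Vec interleave(const std::vector<Vec> &inputs) {
    assert(!inputs.empty());
    const std::size_t lanes = inputs[0].size();
    Vec out;
    for (std::size_t lane = 0; lane < lanes; ++lane) {
        for (const Vec &in : inputs) {
            assert(in.size() == lanes);
            out.push_back(in[lane]);
        }
    }
    return out;
}

// The whole vector `pattern` repeated n times, i.e. a broadcast of a vector.
Vec repeat(const Vec &pattern, std::size_t n) {
    Vec out;
    for (std::size_t i = 0; i < n; ++i) {
        out.insert(out.end(), pattern.begin(), pattern.end());
    }
    return out;
}
```

With these definitions, interleave({broadcast(a, 4), broadcast(b, 4), broadcast(c, 4), broadcast(d, 4)}) and repeat({a, b, c, d}, 4) produce the same 16 lanes.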

Unfortunately, as a result of this, the compiler as of today generates the following LLVM IR:

%load.a = load i8, i8* %348, align 1, !tbaa !179                                       // i.e. | a |
%insr.a = insertelement <4 x i8> undef, i8 %load.a, i32 0                              // i.e. | a | - | - | - |
%dupl.a = shufflevector <4 x i8> %insr.a, <4 x i8> undef, <4 x i32> zeroinitializer    // i.e. | a | a | a | a |
%addr.b = getelementptr inbounds i8, i8* %348, i64 1
%load.b = load i8, i8* %addr.b, align 1, !tbaa !179
%insr.b = insertelement <4 x i8> undef, i8 %load.b, i32 0
%dupl.b = shufflevector <4 x i8> %insr.b, <4 x i8> undef, <4 x i32> zeroinitializer
%addr.c = getelementptr inbounds i8, i8* %348, i64 2
%load.c = load i8, i8* %addr.c, align 1, !tbaa !179
%insr.c = insertelement <4 x i8> undef, i8 %load.c, i32 0
%dupl.c = shufflevector <4 x i8> %insr.c, <4 x i8> undef, <4 x i32> zeroinitializer
%addr.d = getelementptr inbounds i8, i8* %348, i64 3
%load.d = load i8, i8* %addr.d, align 1, !tbaa !179
%insr.d = insertelement <4 x i8> undef, i8 %load.d, i32 0
%dupl.d = shufflevector <4 x i8> %insr.d, <4 x i8> undef, <4 x i32> zeroinitializer

%shuffle.a.c = shufflevector <4 x i8> %dupl.a, <4 x i8> %dupl.c, <8 x i32> <i32 0, i32 4, i32 1, i32 5, i32 2, i32 6, i32 3, i32 7> // i.e. | a | c | a | c | a | c | a | c |
%shuffle.b.d = shufflevector <4 x i8> %dupl.b, <4 x i8> %dupl.d, <8 x i32> <i32 0, i32 4, i32 1, i32 5, i32 2, i32 6, i32 3, i32 7> // i.e. | b | d | b | d | b | d | b | d |
%shuffle.a.b.c.d = shufflevector <8 x i8> %shuffle.a.c, <8 x i8> %shuffle.b.d, <16 x i32> <i32 0, i32 8, i32 1, i32 9, i32 2, i32 10, i32 3, i32 11, i32 4, i32 12, i32 5, i32 13, i32 6, i32 14, i32 7, i32 15>
// => | a | b | c | d | a | b | c | d | a | b | c | d | a | b | c | d |

This vastly over-complicates what is really quite a simple operation. As a result, I'm seeing instructions like ldrb, zip1, zip2, and individual lane loads and moves in the final assembly, when we really just want a vector load followed by some indexing. In this case, probably something like the following LLVM IR:

%load.composite = load i32, i32* %351, align 4, !tbaa !179                                             // i.e. || a | b | c | d ||
%insr.composite = insertelement <4 x i32> undef, i32 %load.composite, i32 0                            // i.e. || a | b | c | d || - | - | - | - || - | - | - | - || - | - | - | - ||
%dupl.composite = shufflevector <4 x i32> %insr.composite, <4 x i32> undef, <4 x i32> zeroinitializer  // i.e. || a | b | c | d || a | b | c | d || a | b | c | d || a | b | c | d ||
%shuffle.a.b.c.d = bitcast <4 x i32> %dupl.composite to <16 x i8>
// => | a | b | c | d | a | b | c | d | a | b | c | d | a | b | c | d |
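As a sanity check on the idea, here is a minimal sketch in plain C++ that compares the lane-by-lane construction with the composite-load construction. The function names are invented for this example. It uses memcpy for the load and the reinterpretation, so it makes no alignment assumption, whereas the IR above marks the i32 load as align 4, which presumably only holds if the address is known to be 4-byte aligned.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>

// Lane-by-lane version, mirroring the four scalar loads, dups, and zips.
std::array<uint8_t, 16> tile_by_lanes(const uint8_t *src) {
    assert(src != nullptr);
    std::array<uint8_t, 16> out{};
    for (std::size_t i = 0; i < out.size(); ++i) {
        out[i] = src[i % 4];
    }
    return out;
}

// Composite version: one 32-bit load, a broadcast, then a reinterpretation.
std::array<uint8_t, 16> tile_by_composite_load(const uint8_t *src) {
    assert(src != nullptr);
    uint32_t composite;
    std::memcpy(&composite, src, sizeof(composite));  // load i32
    std::array<uint32_t, 4> dup;
    dup.fill(composite);                              // broadcast to <4 x i32>
    std::array<uint8_t, 16> out;
    std::memcpy(out.data(), dup.data(), out.size());  // bitcast to <16 x i8>
    return out;
}
```

Because the bytes are copied in and out in memory order, the two functions agree regardless of host endianness.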

In my prototyping, I'm currently working around this with some hacked together code in the back-end that detects shuffles of this type and emits the right IR. However, given that this is quite a complex pattern to peephole optimise, it feels like maybe it should belong somewhere else, perhaps in simplification. Of course, I don't think that's possible today without supporting multi-dimensional broadcasts or casts between vectors of different lengths, but maybe it's a better long term solution.

Does anyone have any thoughts on how this should be improved, or indeed whether there's a better way around this that I'm not seeing? If it doesn't end up being a huge task, I'm happy to work on this myself.

dsharletg commented 4 years ago

> However, given that this is quite a complex pattern to peephole optimise, it feels like maybe it should belong somewhere else, perhaps in simplification. Of course, I don't think that's possible today without supporting multi-dimensional broadcasts or casts between vectors of different lengths, but maybe it's a better long term solution.

This is very close to being merged (#4873). I think once this is merged, the best way to fix this is for the simplifier to rewrite interleave(broadcast(x1), broadcast(x2), ...) to broadcast(interleave(x1, x2, ...)).

We might then still need to make sure that generates good code, but this seems more straightforward, no pattern matching required. It also seems familiar from vrmpy codegen on Hexagon, which does work (concats of scalar loads do get optimized the way we need them to here, using the same mechanism of "interleaving" the scalars).
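The rewrite described above can be sketched on a toy expression tree. This is plain C++ with a made-up three-node IR (Expr, leaf, make_broadcast, make_interleave, simplify_interleave, and show are all hypothetical names); it is not Halide's simplifier or its IR classes, just an illustration of the shape of the rule.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>

struct Expr;
using ExprPtr = std::shared_ptr<const Expr>;

// Toy IR node: a named leaf, a broadcast of one operand, or an interleave.
struct Expr {
    enum Kind { Leaf, Broadcast, Interleave } kind;
    std::string name;           // Leaf only
    int lanes;                  // Broadcast only: repetition count
    std::vector<ExprPtr> args;  // Broadcast: one operand; Interleave: many
};

ExprPtr leaf(std::string name) {
    return std::make_shared<const Expr>(Expr{Expr::Leaf, std::move(name), 1, {}});
}
ExprPtr make_broadcast(ExprPtr value, int lanes) {
    return std::make_shared<const Expr>(Expr{Expr::Broadcast, "", lanes, {std::move(value)}});
}
ExprPtr make_interleave(std::vector<ExprPtr> args) {
    return std::make_shared<const Expr>(Expr{Expr::Interleave, "", 0, std::move(args)});
}

// interleave(xN(a), xN(b), ...) -> xN(interleave(a, b, ...)).
// Returns the input unchanged unless every operand is a broadcast with the
// same repetition count.
ExprPtr simplify_interleave(const ExprPtr &e) {
    if (e->kind != Expr::Interleave || e->args.empty()) return e;
    const int lanes = e->args[0]->lanes;
    std::vector<ExprPtr> inner;
    for (const ExprPtr &arg : e->args) {
        if (arg->kind != Expr::Broadcast || arg->lanes != lanes) return e;
        inner.push_back(arg->args[0]);
    }
    return make_broadcast(make_interleave(std::move(inner)), lanes);
}

// Print an expression using the xN(...) notation from the IR dumps above.
std::string show(const ExprPtr &e) {
    switch (e->kind) {
    case Expr::Leaf:
        return e->name;
    case Expr::Broadcast:
        return "x" + std::to_string(e->lanes) + "(" + show(e->args[0]) + ")";
    case Expr::Interleave: {
        std::string out = "interleave(";
        for (std::size_t i = 0; i < e->args.size(); ++i) {
            if (i > 0) out += ", ";
            out += show(e->args[i]);
        }
        return out + ")";
    }
    }
    return "";
}
```

In this toy model, the rule only fires when every operand is a broadcast of the same width, so it is a structural check rather than a search over shuffle indices.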