
LLVM arm code generation #3136

Open · mohamedadaly opened 6 years ago

mohamedadaly commented 6 years ago

I am trying to force LLVM to generate vmlal instructions using scalar operands. So far I can get it to generate something like this:

vdup.16 q14, d2[0]
vmlal.u16   q10, d7, d29
vmlal.u16   q8, d6, d28

which clearly should be optimized to something like this (no need for the vdup instruction):

vmlal.u16   q10, d7, d2[0]
vmlal.u16   q8, d6, d2[0]

LLVM generates the second version only when I vectorize the left operand by 4, not 8, which I guess is based on the pattern here (it expects vectors of length 4, not 8). Overall, this produces inefficient code when loading from memory, because the left operand is bound_extent(8) anyway.

I guess a better question is this: is there a way to write custom LLVM IR code for a particular block of Halide IR? I know it can be done for a Func, but the issue is that the pattern I am trying to optimize materializes only after the scheduling part (after tiling, unrolling, vectorizing, etc.). I would be happy to try to implement something like this, if you could point me to the relevant pieces that need to be put together...

mohamedadaly commented 6 years ago

I was also looking into a similar Halide-generated function that does the same thing but for floating point numbers, and found that LLVM generated vfma.f32 instructions; manually changing these into vmla.f32 with scalar operands was more than 20% faster on a Cortex-A7 system.

mohamedadaly commented 6 years ago

I guess a better question is this: is there a way to write custom LLVM IR code for a particular block of Halide IR? I know it can be done for a Func, but the issue is that the pattern I am trying to optimize materializes only after the scheduling part (after tiling, unrolling, vectorizing, etc.). I would be happy to try to implement something like this, if you could point me to the relevant pieces that need to be put together...

@abadams @zvookin Any help, or thoughts on why this is (or isn't) possible/feasible/a good idea/etc., would be greatly appreciated.

abadams commented 6 years ago

If you can get the right code generated by vectorizing by a factor of 4... can you then just unroll that by a factor of two?
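
For concreteness, a minimal sketch of what that schedule might look like, assuming a reduction over some Var x in a Func f (the names here are illustrative, not taken from the issue):

#include <Halide.h>
using namespace Halide;

// Vectorize by 4 (the width LLVM pattern-matches for the scalar-operand
// vmlal), then unroll the leftover factor of 2 so 8 lanes are still
// processed per iteration.
void schedule_example(Func f, Var x) {
    Var xo, xi;
    f.split(x, xo, xi, 8)   // 8 elements per outer iteration
     .vectorize(xi, 4)      // 4-wide vectors for the inner lanes
     .unroll(xi);           // unroll the remaining factor of 2
}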

abadams commented 6 years ago

For the general question, I really don't think we have a good way to implement small snippets of IR. It'd mean adding a user-specific peephole operation in the backend and then adding raw .ll to the runtime...

Fixing the compiler to generate the right IR for the scalar-second-arg variant of vmlal might be feasible. You'd have to peephole-recognize the pattern in the arm backend and slice things up into vectors of size 4 to pass to llvm.

mohamedadaly commented 6 years ago

Yep, I was able to get the scalar operands working with vmlal after some acrobatics with Halide. But that caused another problem. Basically, I forced the computation of a vector with compute_at, which is then shuffled to extract one element that is then broadcast, which is the pattern LLVM expects. However, the generated code is quite weird: in the inner kernel the vector of 4 uint8 values is read, then stored on the stack, then read again into the NEON registers, and the computation is performed with scalar operands as desired. I set MemoryType::Register but it doesn't seem to affect this.

ldr r3, [r10], r8  @load 4 uint8 values
str r3, [sp, #464]  @write on the stack
vld1.32 {d4[0]}, [r7:32] @load again from the stack instead of loading here directly

This also doesn't solve the other problem of the inefficient register allocation in LLVM, which shows up with certain parameter settings in the algorithm.

I found that getting LLVM to emit the desired assembly instructions is very tedious, and when one problem is solved another is introduced. That is why I think it would be much easier, and would make Halide much more flexible, if custom LLVM IR (possibly with inline assembly) could be defined for specific nodes in the Halide IR, after all the lowering passes are done. I would be happy to work on adding that support to Halide.

Thanks!

abadams commented 6 years ago

I really can't think of any reasonable way to do that, unfortunately. I'll keep thinking about it. In the meantime I would like to understand a little better a case where LLVM is doing a bad job of register allocation. Maybe we can give it a little more metadata to tell it more clearly what the hot loops are, or something. Do you have a small example pipeline that trips it up?

mohamedadaly commented 6 years ago

Sure, here is an example (it includes the stmt html, LLVM IR, and arm32 assembly). The inner loop (in consume right_u16_in_elwise) computes the outer product of a 12x1 vector with a 1x4 vector and accumulates into a 12x4 matrix. All the operands should easily fit in the 16 q registers, but two spill to the stack.
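
In scalar form that inner kernel is just a rank-1 update; a plain sketch of the arithmetic (types and names here are for illustration only):

#include <cstdint>

// Accumulate the outer product of a 12-element column of A with a 4-element
// row of B into a 12x4 block of 32-bit accumulators. Each row of 4 products
// is what should map onto a single vmlal.u16 with a scalar lane operand.
void rank1_update_12x4(uint32_t C[12][4], const uint16_t a[12], const uint16_t b[4]) {
    for (int i = 0; i < 12; i++) {
        for (int j = 0; j < 4; j++) {
            C[i][j] += (uint32_t)a[i] * (uint32_t)b[j];
        }
    }
}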

Also, with some tweaks (e.g. vectorize by 4 instead of 12) it can use scalar operands with vmlal, but that brings back the other problem of storing the values to the stack before reading them again. And when I vectorize by 4 and unroll by 3, it doesn't use scalar operands with vmlal any more :) It's really frustrating that such small changes can have such a big impact on the generated code.

Anyway, thanks for your help~

abadams commented 6 years ago

Not sure how this differs from your code (I assumed some things about layout that made life simpler), but by explicitly staging the 4-wide vector using Func::in() I get pretty good register usage:

#include <Halide.h>

using namespace Halide;

int main(int argc, char** argv) {
    ImageParam A(UInt(16), 2), B(UInt(16), 2);

    Func C;

    Var i, j;

    RDom k(0, 1024);

    // Assume A stored or staged col-major. C and B are row-major.
    C(j, i) += cast<uint32_t>(A(i, k)) * B(j, k);

    Var ii, ji;
    C.in().tile(j, i, ji, ii, 12, 4).vectorize(ji, 4).unroll(ji).unroll(ii);
    C.compute_at(C.in(), j).vectorize(j, 4).unroll(i).unroll(j)
        .update().reorder(j, i, k).vectorize(j, 4).unroll(i).unroll(j);

    A.in().compute_at(C, k).vectorize(_0);

    C.in().compile_to_assembly("/dev/stdout", {A, B}, Target("arm-linux-32-no_runtime-no_asserts-no_bounds_query"));
    return 0;
}

produces

.LBB0_5: 
    sub r7, r5, #8
    sub r5, r5, #16
    vld1.16 {d0}, [r8], r0
    subs    r6, r6, #1
    vld1.16 {d1}, [r10], r12
    vld1.16 {d11}, [r5]
    vld1.16 {d10}, [r7]
    mov r5, r10
    vmlal.u16   q8, d1, d0[3]
    vmlal.u16   q11, d1, d0[2]
    vmlal.u16   q14, d1, d0[1]
    vmlal.u16   q1, d1, d0[0]
    vmlal.u16   q9, d10, d0[3]
    vmlal.u16   q10, d11, d0[3]
    vmlal.u16   q12, d10, d0[2]
    vmlal.u16   q13, d11, d0[2]
    vmlal.u16   q15, d10, d0[1]
    vmlal.u16   q4, d11, d0[1]
    vmlal.u16   q2, d10, d0[0]
    vmlal.u16   q3, d11, d0[0]
    bne .LBB0_5

mohamedadaly commented 6 years ago

I don't quite get the same output. This is what I get with the latest GitHub master, with Halide compiled against LLVM 6.0. The register use is fine here, but some vmlal instructions don't use scalar operands:

.LBB0_5:                                @ %"for f2.s1.r4$x"
                                        @   Parent Loop BB0_2 Depth=1
                                        @     Parent Loop BB0_4 Depth=2
                                        @ =>    This Inner Loop Header: Depth=3
        vld1.16 {d0}, [r10], r0
        sub     r7, r6, #8
        vld1.16 {d1}, [r11], lr
        sub     r6, r6, #16
        vmlal.u16       q8, d1, d0[3]
        subs    r12, r12, #1
        vmlal.u16       q11, d1, d0[2]
        vmlal.u16       q14, d1, d0[1]
        vmlal.u16       q1, d1, d0[0]
        vdup.16 d1, d0[3]
        vld1.16 {d10}, [r7]
        vmlal.u16       q9, d10, d0[3]
        vmlal.u16       q12, d10, d0[2]
        vmlal.u16       q15, d10, d0[1]
        vmlal.u16       q2, d10, d0[0]
        vld1.16 {d10}, [r6]
        mov     r6, r11
        vmlal.u16       q10, d1, d10
        vdup.16 d1, d0[2]
        vmlal.u16       q13, d1, d10
        vdup.16 d1, d0[1]
        vdup.16 d0, d0[0]
        vmlal.u16       q4, d1, d10
        vmlal.u16       q3, d0, d10
        bne     .LBB0_5

and the example was compiled like so

g++ x.cpp -o x -I include/ -std=c++11 -lHalide -Llib/ -ldl -lpthread -lz -ltinfo -O3

Am I doing something wrong?

The input matrices are actually uint8, and when I change C accordingly

C(j, i) += cast<uint32_t>(cast<uint16_t>(A(i, k)) * cast<uint16_t>(B(j, k)));

The generated assembly becomes crazy

.LBB0_5:                                @ %"for f2.s1.r4$x"
                                        @   Parent Loop BB0_2 Depth=1
                                        @     Parent Loop BB0_4 Depth=2
                                        @ =>    This Inner Loop Header: Depth=3
        add     r2, sp, #112
        ldr     r4, [r6]
        vstmia  r2, {d20, d21}          @ 16-byte Spill
        add     r2, sp, #128
        subs    r3, r3, #1
        vstmia  r2, {d22, d23}          @ 16-byte Spill
        add     r2, sp, #144
        vstmia  r2, {d16, d17}          @ 16-byte Spill
        ldr     r5, [r6, #-8]
        ldr     r2, [r6, #-4]
        add     r6, r6, r0
        str     r4, [sp, #188]
        vld1.32 {d10[0]}, [r8:32]
        ldr     r4, [r7], r1
        str     r5, [sp, #196]
        vmovl.u8        q9, d10
        vld1.32 {d12[0]}, [r9:32]
        str     r2, [sp, #192]
        add     r2, sp, #96
        vmov.i32        q11, #0xff
        vorr    q5, q4, q4
        vorr    q4, q3, q3
        vorr    q3, q14, q14
        vorr    q14, q2, q2
        vorr    q2, q1, q1
        vorr    q1, q0, q0
        vstmia  r2, {d18, d19}          @ 16-byte Spill
        lsr     r2, r4, #16
        vdup.16 d16, r2
        lsr     r2, r4, #8
        vmovl.u16       q8, d16
        vmovl.u16       q9, d18
        vdup.16 d20, r2
        add     r2, sp, #160
        vand    q8, q8, q11
        vld1.32 {d13[0]}, [r10:32]
        vmovl.u16       q10, d20
        vmovl.u8        q7, d13
        vmovl.u8        q6, d12
        vmovl.u16       q11, d14
        vmovl.u16       q12, d12
        vldmia  r2, {d0, d1}            @ 16-byte Reload
        add     r2, sp, #160
        vmla.i32        q0, q8, q9
        vstmia  r2, {d0, d1}            @ 16-byte Spill
        add     r2, sp, #112
        vorr    q0, q1, q1
        vmov.i32        q1, #0xff
        vmla.i32        q0, q8, q11
        vand    q10, q10, q1
        vorr    q1, q2, q2
        vorr    q2, q14, q14
        vorr    q14, q3, q3
        vorr    q3, q4, q4
        vorr    q4, q5, q5
        vmla.i32        q4, q8, q12
        vdup.16 d16, r4
        vmovl.u16       q8, d16
        vmla.i32        q15, q10, q9
        vmla.i32        q1, q10, q11
        vmla.i32        q2, q10, q12
        vmov.i32        q10, #0xff
        vand    q8, q8, q10
        vmla.i32        q3, q8, q11
        vmla.i32        q14, q8, q9
        vmla.i32        q13, q8, q12
        vldmia  r2, {d20, d21}          @ 16-byte Reload
        add     r2, sp, #128
        vldmia  r2, {d22, d23}          @ 16-byte Reload
        lsr     r2, r4, #24
        vdup.16 d18, r2
        add     r2, sp, #96
        vmlal.u16       q11, d18, d14
        vldmia  r2, {d10, d11}          @ 16-byte Reload
        add     r2, sp, #144
        vldmia  r2, {d16, d17}          @ 16-byte Reload
        vmlal.u16       q10, d18, d10
        vmlal.u16       q8, d18, d12
        bne     .LBB0_5

abadams commented 6 years ago

I was using llvm 5. With 6 I get the same output as you :(

abadams commented 6 years ago

I'll try llvm trunk. We may have to open an llvm bug.

abadams commented 6 years ago

For the 8-bit case, to get a reasonable result I had to drop down to an 8x4 block of accumulators. It's still not great: there are dups on the args to the widening multiply, and for some reason llvm is sticking it in a gpr and then loading one byte at a time out of it.

Can you just call out to something like gemmlowp using Func::define_extern? 8-bit gemms are pretty well optimized by hand.

#include <Halide.h>

using namespace Halide;

int main(int argc, char** argv) {
    ImageParam A(UInt(8), 2), B(UInt(8), 2);

    Func C;

    Var i, j, k;

    Func A16;
    A16(i, k) = cast<uint16_t>(A(i, k));

    Func prod;
    prod(j, i, k) = A16(i, k) * B(j, k);

    RDom r(0, 1024);
    C(j, i) += cast<uint32_t>(prod(j, i, r));

    Var ii, ji;
    C.in().tile(j, i, ji, ii, 8, 4).vectorize(ji, 4).unroll(ji).unroll(ii);
    C.compute_at(C.in(), j).vectorize(j, 4).unroll(i).unroll(j)
        .update().reorder(j, i, r).vectorize(j, 4).unroll(i).unroll(j);
    prod.compute_at(C, i).vectorize(j, 8).unroll(j).unroll(i).unroll(k);

    A.in().compute_at(C, r).vectorize(_0);

    C.in().compile_to_assembly("/dev/stdout", {A, B}, Target("arm-linux-32-no_runtime-no_asserts-no_bounds_query"));
    return 0;
}

produces

        ldr     r0, [r9], r7
        subs    r10, r10, #1
        vld1.8  {d0}, [r12], r1
        lsr     lr, r0, #16
        vdup.8  d1, r0
        lsr     r5, r0, #8
        lsr     r0, r0, #24
        vmull.u8        q1, d1, d0
        vdup.8  d1, r0
        vmull.u8        q3, d1, d0
        vdup.8  d4, lr
        vdup.8  d1, r5
        vmull.u8        q2, d4, d0
        vmull.u8        q0, d1, d0
        vaddw.u16       q14, q14, d3
        vaddw.u16       q15, q15, d2
        vaddw.u16       q8, q8, d7
        vaddw.u16       q9, q9, d6
        vaddw.u16       q10, q10, d5
        vaddw.u16       q11, q11, d4
        vaddw.u16       q12, q12, d1
        vaddw.u16       q13, q13, d0
        bne     .LBB0_5

mohamedadaly commented 6 years ago

@abadams Yes, I have been trying several things over the past few weeks to get it to emit the right instructions. But LLVM seems a bit finicky and very sensitive to how the different Funcs are stored, computed at, etc.

I know it might be easier to just make a call to gemmlowp or something, but I am basically trying to replicate gemmlowp with Halide. I did something similar with floating point data, and the result was a lot faster than using e.g. Eigen, with careful parameter tuning for specific input matrix sizes. I believe I can get better results using the same approach for uint8 inputs, but so far the hand-written assembly in gemmlowp is faster than the code LLVM generates, which is sub-optimal.

mohamedadaly commented 6 years ago

So let's say I want to write a custom implementation for this produce block that calls a custom LLVM IR function (during ARM codegen):

      produce C {
        C[ramp(0, 1, 4)] = x4((uint32)0)
        C[ramp(4, 1, 4)] = x4((uint32)0)
        C[ramp(8, 1, 4)] = x4((uint32)0)
        C[ramp(12, 1, 4)] = x4((uint32)0)
        C[ramp(16, 1, 4)] = x4((uint32)0)
        C[ramp(20, 1, 4)] = x4((uint32)0)
        C[ramp(24, 1, 4)] = x4((uint32)0)
        C[ramp(28, 1, 4)] = x4((uint32)0)
        C[ramp(32, 1, 4)] = x4((uint32)0)
        C[ramp(36, 1, 4)] = x4((uint32)0)
        C[ramp(40, 1, 4)] = x4((uint32)0)
        C[ramp(44, 1, 4)] = x4((uint32)0)
        let t43 = (C_global_wrapper.s0.j.ji.base - t19)
        for (C.s1.k$x, 0, 1024) {
          allocate A_im_global_wrapper[uint16 * 4]
          produce A_im_global_wrapper {
            A_im_global_wrapper[ramp(0, 1, 4)] = A[ramp((t21 + (C.s1.k$x*A.stride.1)), 1, 4)]
          }
          C[ramp(0, 1, 4)] = (C[ramp(0, 1, 4)] + (x4(uint32(A_im_global_wrapper[0]))*uint32x4(B[ramp((t43 + (C.s1.k$x*B.stride.1)), 1, 4)])))
          C[ramp(4, 1, 4)] = (C[ramp(4, 1, 4)] + (x4(uint32(A_im_global_wrapper[0]))*uint32x4(B[ramp(((t43 + (C.s1.k$x*B.stride.1)) + 4), 1, 4)])))
          C[ramp(8, 1, 4)] = (C[ramp(8, 1, 4)] + (x4(uint32(A_im_global_wrapper[0]))*uint32x4(B[ramp(((t43 + (C.s1.k$x*B.stride.1)) + 8), 1, 4)])))
          C[ramp(12, 1, 4)] = (C[ramp(12, 1, 4)] + (x4(uint32(A_im_global_wrapper[1]))*uint32x4(B[ramp((t43 + (C.s1.k$x*B.stride.1)), 1, 4)])))
          C[ramp(16, 1, 4)] = (C[ramp(16, 1, 4)] + (x4(uint32(A_im_global_wrapper[1]))*uint32x4(B[ramp(((t43 + (C.s1.k$x*B.stride.1)) + 4), 1, 4)])))
          C[ramp(20, 1, 4)] = (C[ramp(20, 1, 4)] + (x4(uint32(A_im_global_wrapper[1]))*uint32x4(B[ramp(((t43 + (C.s1.k$x*B.stride.1)) + 8), 1, 4)])))
          C[ramp(24, 1, 4)] = (C[ramp(24, 1, 4)] + (x4(uint32(A_im_global_wrapper[2]))*uint32x4(B[ramp((t43 + (C.s1.k$x*B.stride.1)), 1, 4)])))
          C[ramp(28, 1, 4)] = (C[ramp(28, 1, 4)] + (x4(uint32(A_im_global_wrapper[2]))*uint32x4(B[ramp(((t43 + (C.s1.k$x*B.stride.1)) + 4), 1, 4)])))
          C[ramp(32, 1, 4)] = (C[ramp(32, 1, 4)] + (x4(uint32(A_im_global_wrapper[2]))*uint32x4(B[ramp(((t43 + (C.s1.k$x*B.stride.1)) + 8), 1, 4)])))
          C[ramp(36, 1, 4)] = (C[ramp(36, 1, 4)] + (x4(uint32(A_im_global_wrapper[3]))*uint32x4(B[ramp((t43 + (C.s1.k$x*B.stride.1)), 1, 4)])))
          C[ramp(40, 1, 4)] = (C[ramp(40, 1, 4)] + (x4(uint32(A_im_global_wrapper[3]))*uint32x4(B[ramp(((t43 + (C.s1.k$x*B.stride.1)) + 4), 1, 4)])))
          C[ramp(44, 1, 4)] = (C[ramp(44, 1, 4)] + (x4(uint32(A_im_global_wrapper[3]))*uint32x4(B[ramp(((t43 + (C.s1.k$x*B.stride.1)) + 8), 1, 4)])))
          free A_im_global_wrapper
        }
      }

The way I understand it so far, from looking at the debug codegen output, is that this LLVM IR goes through a bunch of optimization passes, and one of them would promote the C variable to registers instead of the stack. At the codegen stage (when CodeGen_ARM gets called) the information that C is in registers is not there yet, is that correct? If the custom IR routine uses inline assembly to provide an optimized implementation of this for loop, is there a way to make the subsequent code use those registers? Or is there a way to manually promote C to registers before calling the custom IR routine? Or do I just not understand the whole code generation pipeline? :)

abadams commented 6 years ago

I've been looking into it too, and the best I can figure is that the only way we can actually influence register allocation is to write a custom llvm register allocator and somehow swap that in (e.g. by making a custom subtarget of llvm's arm target). Not sure if llvm is even extensible to that extent.

But the reason that llvm's register allocator is failing is that it wants to reorder the instructions (to account for instruction latencies) in a way in which it ends up needing more registers than it has. The instruction reordering is something we can control to some extent, by inserting barriers to code motion (empty asm volatile blocks prevent code motion). We use that trick in the profiler. I'm experimenting with that now for the gemm inner loop. With 16-bit inputs it produces:

        sub     r1, r10, #8
        @APP
        @NO_APP
        vld1.16 {d0}, [r2], lr
        @APP
        @NO_APP
        subs    r6, r6, #1
        vld1.16 {d10}, [r1]
        sub     r1, r10, #16
        vmlal.u16       q9, d10, d0[3]
        vld1.16 {d1}, [r5], r12
        vmlal.u16       q12, d10, d0[2]
        vmlal.u16       q8, d1, d0[3]
        vld1.16 {d11}, [r1]
        vmlal.u16       q11, d1, d0[2]
        vmlal.u16       q14, d1, d0[1]
        mov     r10, r5
        vmlal.u16       q1, d1, d0[0]
        vmlal.u16       q10, d11, d0[3]
        vmlal.u16       q13, d11, d0[2]
        vmlal.u16       q15, d10, d0[1]
        vmlal.u16       q4, d11, d0[1]
        vmlal.u16       q2, d10, d0[0]
        vmlal.u16       q3, d11, d0[0]
        bne     .LBB3_7

@APP / @NO_APP is a no-op that indicates where the barriers were inserted. It's not perfect but it's pretty good. It forced llvm to load a vector of A entirely before proceeding with the rest of the inner loop, which means that all uses of it come out of the lanes instead of requiring vdup instructions later. I'll try for a reasonable 8-bit schedule now.
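
For reference, the barrier being described is essentially just an empty inline asm statement; a plain C++ sketch of the idea (the general trick, not Halide's internal mechanism):

// The empty asm volatile statement is opaque to the compiler, so it is very
// reluctant to move instructions across it, and the "memory" clobber stops
// loads and stores from being reordered around it. In the generated assembly
// it shows up as the @APP / @NO_APP markers above.
inline void code_motion_barrier() {
    asm volatile("" ::: "memory");
}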

abadams commented 6 years ago

No luck with that approach. It may be necessary to write an extern stage in assembly that computes the full matrix multiply for a single 12x4 or maybe 8x6 tile. :(

mohamedadaly commented 6 years ago

What's the best way to do this?

abadams commented 6 years ago

Something like this:

#include <Halide.h>
#include <cstdio>

using namespace Halide;

#ifdef _WIN32
#define DLLEXPORT __declspec(dllexport)
#else
#define DLLEXPORT
#endif

extern "C" DLLEXPORT int gemm_12x4(halide_buffer_t *A, halide_buffer_t *B, int K, halide_buffer_t *C) {
    // TODO: replace loop with inline assembly that computes the
    // single 12x4 tile of the output specified by
    // C.dim[*].{min|extent}
    printf("Computing output tile [%d..%d]x[%d..%d]\n",
           C->dim[0].min, C->dim[0].min + C->dim[0].extent - 1,
           C->dim[1].min, C->dim[1].min + C->dim[1].extent - 1);
    for (int k = 0; k < K; k++) {
        for (int j = 0; j < 12; j++) {
            for (int i = 0; i < 4; i++) {
                // Something involving A, B, and C
            }
        }
    }
    return 0;
}

int main(int argc, char** argv) {
    ImageParam A(UInt(8), 2), B(UInt(8), 2);

    Func C;

    Var i, j, ji, ii;

    Func single_tile;
    single_tile.define_extern("gemm_12x4", {A, B, B.dim(1).extent()}, UInt(32), 2);

    // Use a secret internal API to avoid having to do bounds
    // queries. Just needs to be an expression that has the same
    // footprint on the inputs as will actually be accessed.
    single_tile.function().extern_definition_proxy_expr() = A(_1, 0) * B(_0, 0) + A(_1, 1023) * B(_0, 1023);    

    C(j, i) = single_tile(j, i);
    C.tile(j, i, ji, ii, 12, 4).vectorize(ji, 4).unroll(ji).unroll(ii).parallel(i, 4);
    single_tile.compute_at(C, j);    

    A.set(Buffer<uint8_t>(1024, 1024));
    B.set(Buffer<uint8_t>(1024, 1024));    

    C.realize(1024, 1024);

    return 0;
}

The problem is that the extern stage isn't going to get inlined, so you need to do enough work in it to amortize the function call overhead. The entire loop over k is sufficient.

mohamedadaly commented 6 years ago

Great. Thanks a lot! I will play with that and see how it goes ...

mohamedadaly commented 6 years ago

Is there a specific way to put this in a generator, e.g. define the extern function in a different object file and then link it with the binary that uses the generator? Are there specific compilation flags?

I keep getting errors related to Halide::Internal::check_introspection:

extern.o:extern.cpp:function Halide::Internal::check_introspection(void const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, int) [clone .constprop.29]: error: undefined reference to 'Halide::Internal::Introspection::get_source_location[abi:cxx11]()'
mohamedadaly commented 6 years ago

Never mind. I was including Halide.h instead of HalideRuntime.h...
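
For the record, a sketch of what the standalone extern translation unit ends up looking like (same gemm_12x4 signature as above; the kernel body is elided):

// extern.cpp: only needs the runtime header (for halide_buffer_t etc.), not
// Halide.h, so it links cleanly into the program that uses the generator.
#include <HalideRuntime.h>
#include <cstdio>

extern "C" int gemm_12x4(halide_buffer_t *A, halide_buffer_t *B, int K,
                         halide_buffer_t *C) {
    printf("Computing output tile [%d..%d]x[%d..%d]\n",
           C->dim[0].min, C->dim[0].min + C->dim[0].extent - 1,
           C->dim[1].min, C->dim[1].min + C->dim[1].extent - 1);
    // ... hand-written 12x4 kernel goes here ...
    return 0;
}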

mohamedadaly commented 6 years ago

Another quick question:

I am trying to stage A by splitting it into horizontal blocks using A.in(single_tile), but that doesn't seem to have any effect on the input to gemm_12x4, which is passed the whole A buffer each time (not the horizontal slice as intended). Is there a trick to get this to work?

#include <Halide.h>
#include <cstdio>

using namespace Halide;

#ifdef _WIN32
#define DLLEXPORT __declspec(dllexport)
#else
#define DLLEXPORT
#endif

extern "C" DLLEXPORT int gemm_12x4(halide_buffer_t *A, halide_buffer_t *B, int K, halide_buffer_t *C) {
    // TODO: replace loop with inline assembly that computes the
    // single 12x4 tile of the output specified by
    // C.dim[*].{min|extent}
    printf("Computing output tile [%d..%d]x[%d..%d]\n",
           C->dim[0].min, C->dim[0].min + C->dim[0].extent - 1,
           C->dim[1].min, C->dim[1].min + C->dim[1].extent - 1);
    printf("A [%d..%d]x[%d..%d]\n",
           A->dim[0].min, A->dim[0].min + A->dim[0].extent - 1,
           A->dim[1].min, A->dim[1].min + A->dim[1].extent - 1);
    printf("B [%d..%d]x[%d..%d]\n",
           B->dim[0].min, B->dim[0].min + B->dim[0].extent - 1,
           B->dim[1].min, B->dim[1].min + B->dim[1].extent - 1);
    for (int k = 0; k < K; k++) {
        for (int j = 0; j < 12; j++) {
            for (int i = 0; i < 4; i++) {
                // Something involving A, B, and C
            }
        }
    }
    return 0;
}

int main(int argc, char** argv) {
    ImageParam A_(UInt(8), 2), B_(UInt(8), 2);

    Var i, j, ji, ii;

    Func C;
    Func A, B;

    A(i, j) = A_(i, j);
    B(j, i) = B_(j, i);

    A_.dim(0).set_min(0);

    Func single_tile;
    single_tile.define_extern("gemm_12x4", {A, B, B_.dim(1).extent()}, UInt(32), 2);

    // Use a secret internal API to avoid having to do bounds
    // queries. Just needs to be an expression that has the same
    // footprint on the inputs as will actually be accessed.
    single_tile.function().extern_definition_proxy_expr() = A(_1, 0) * B(_0, 0) + A(_1, 1023) * B(_0, 1023);    

    C(j, i) = single_tile(j, i);
    C.tile(j, i, ji, ii, 12, 4).vectorize(ji, 4).unroll(ji).unroll(ii); //.parallel(i, 4);
    single_tile.compute_at(C, j);    

    //A.split(i, i, ii, 12).compute_at(C, j);
    A.in().compute_at(C, j).bound_extent(i, 12).bound_extent(j, 1024);

    A_.set(Buffer<uint8_t>(1024, 1024));
    B_.set(Buffer<uint8_t>(1024, 1024));    

    C.realize(1024, 1024);

    return 0;
}

produces

Computing output tile [0..11]x[0..3]
A [0..1023]x[0..1023]
B [0..1023]x[0..1023]
Computing output tile [12..23]x[0..3]
A [0..1023]x[0..1023]
B [0..1023]x[0..1023]
Computing output tile [24..35]x[0..3]
A [0..1023]x[0..1023]
B [0..1023]x[0..1023]
Computing output tile [36..47]x[0..3]
A [0..1023]x[0..1023]
B [0..1023]x[0..1023]
Computing output tile [48..59]x[0..3]
A [0..1023]x[0..1023]
B [0..1023]x[0..1023]
...

abadams commented 6 years ago

Hrm, maybe the redirection to the wrapper isn't happening properly for extern calls. Try replacing A in the argument to define_extern with A.in(single_tile).
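
i.e. something along these lines, applied to the program above (sketch only, not tested):

single_tile.define_extern("gemm_12x4", {A.in(single_tile), B, B_.dim(1).extent()},
                          UInt(32), 2);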

mohamedadaly commented 6 years ago

Yep, that worked! Thanks!

abadams commented 6 years ago

@psuriana looks like Func::in isn't working for Func args to extern stages.