halide / Halide

a language for fast, portable data-parallel computation
https://halide-lang.org

Forcing register allocation #2934

Open mohamedadaly opened 6 years ago

mohamedadaly commented 6 years ago

I am trying to write a generator to perform a GEMM operation. The outline is below.

#include "Halide.h"

using namespace Halide;

class GemmRowMajor : public Halide::Generator<GemmRowMajor> {
public:
  Input<Buffer<float>> lhs{"lhs", 2};
  Input<Buffer<float>> rhs{"rhs", 2};
  Input<Buffer<float>> bias{"bias", 1};

  Input<int> mc{"mc", 256};
  Input<int> nc{"nc", 256};
  Input<int> kc_{"kc", 128};

  Output<Buffer<float>> relu{"ReLU", 2};

  void generate() {

    Var c("c"), x("x"), y("y"), n("n"), c_out("c_out"), 
    rl("row_left"), rr("row_right"), cl("col_left"), cr("col_right"),
      p("pixel"), k;

    Expr rows = lhs.dim(1).extent();
    Expr cols = rhs.dim(0).extent();
    Expr depth = lhs.dim(0).extent();

    Expr left_rows = rows; 
    Expr left_cols = depth;
    Expr right_rows = depth;
    Expr right_cols = cols;

    constexpr int mr = 4;
    constexpr int nr = 8;

    // If not splitting depth: set kc = depth regardless of the input
    Expr kc = depth;

    Func elwise("elwise");
    elwise(x, y, k) = lhs(k, y) * rhs(x, k);

    // The convolution is just the matrix multiplication
    RDom r(0, kc);
    // Conv has dimensions left_rows * out_depth
    Func conv("conv");
    conv(x, y) = 0.f;
    conv(x, y) += elwise(x, y, r);

    // Relu
    relu(x, y) = conv(x, y) + bias(x);

    Var yi, xi, xii, yii, xiii, yiii, tile_l2, tile_l1,
        ci, cii, rowi, rowii, xc, xr, yc, yr;
    RVar r_out, r_in, ri, rii, riii;

    const int tile_size = 4;

    // Parallelize y (rows) such that each chunk has about 2K pixels
    // Make sure it's between [1, rows]
    Expr y_chunk = min(left_rows, max(1, (1 << 11) / right_cols));

    lhs.in(elwise)
        .compute_at(conv, yc)
        .reorder(y, x)
        .reorder_storage(y, x)
        .vectorize(y, mr);

    rhs.in(elwise)
        .compute_at(conv, xc)
        .vectorize(x, nr);

    conv
        .compute_root()
        .parallel(y, y_chunk, TailStrategy::GuardWithIf)
        .vectorize(x, 8, TailStrategy::GuardWithIf);

    conv.update(0)
        .tile(x, y, xc, yc, x, y, nc, mc, TailStrategy::GuardWithIf)
        .tile(x, y, xr, yr, x, y, nr, mr, TailStrategy::GuardWithIf)
        .split(r, r, ri, kc)
        .reorder(x, y, ri, xr, yr, yc, xc, r)
        .unroll(y, mr)
        .vectorize(x, nr)
        .parallel(yc);

    relu
        .bound(y, 0, left_rows)
        .bound(x, 0, right_cols)
        .parallel(y, y_chunk, TailStrategy::GuardWithIf)
        .vectorize(x, 8, TailStrategy::GuardWithIf);

  }
}; // RowMajor

The problem with the generated code is that the inner kernel (the loop in conv inside the inner [xr, yr] tile) keeps loading and storing the accumulator to memory. This wastes a lot of CPU cycles and could be made much faster: for example, loading all 4x4 elements of the conv block into registers before this loop, and then multiply-accumulating with vmla.f32 inside the loop, would be much faster than loading/storing on every iteration.

Is there a way to do this with scheduling, i.e. to force the accumulator of the inner loop to be kept in registers?

Thanks.
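Roughly what the intended micro-kernel would look like, written as plain C++ for illustration (the function and its parameters are made up; the point is that the 4x4 accumulator block stays in registers across the whole k loop, with loads/stores of the output hoisted out):

// Plain C++ sketch of the desired inner kernel, not Halide output.
void micro_kernel_4x4(const float *a, const float *b, float *c,
                      int depth, int lda, int ldb, int ldc) {
    float acc[4][4] = {};                              // should live in NEON registers
    for (int k = 0; k < depth; k++)                    // reduction loop
        for (int i = 0; i < 4; i++)                    // rows of the output block
            for (int j = 0; j < 4; j++)                // cols of the output block
                acc[i][j] += a[i * lda + k] * b[k * ldb + j];  // maps to vmla.f32 when j is vectorized
    for (int i = 0; i < 4; i++)                        // write the block back once, after the loop
        for (int j = 0; j < 4; j++)
            c[i * ldc + j] += acc[i][j];
}

By contrast, the inner loop Halide currently generates is below; note that the accumulator q8 is reloaded and stored back on every iteration: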


.LBB2_17:                               @ %"for conv.s1.r$x.ri.us"
                                        @   Parent Loop BB2_14 Depth=1
                                        @     Parent Loop BB2_16 Depth=2
                                        @ =>    This Inner Loop Header: Depth=3
    vld1.32 {d16, d17}, [r2]
    sub lr, r6, #12
    vld1.32 {d18[], d19[]}, [lr:32]
    sub r5, r6, #8
    vld1.32 {d20, d21}, [r4], r9
    subs    r12, r12, #1
    vmla.f32    q8, q10, q9
    vst1.32 {d16, d17}, [r2]
    vld1.32 {d16, d17}, [r7]
    vld1.32 {d18[], d19[]}, [r5:32]
    sub r5, r6, #4
    vmla.f32    q8, q10, q9
    vst1.32 {d16, d17}, [r7]
    vld1.32 {d16, d17}, [r3]
    vld1.32 {d18[], d19[]}, [r5:32]
    vmla.f32    q8, q10, q9
    vst1.32 {d16, d17}, [r3]
    vld1.32 {d16, d17}, [r1]
    vld1.32 {d18[], d19[]}, [r6:32], r0
    vmla.f32    q8, q10, q9
    vst1.32 {d16, d17}, [r1]
    bne .LBB2_17
abadams commented 6 years ago

It's going back and forth to memory because conv is scheduled compute_root. If you schedule it compute_at tiles of relu instead, such that the size of conv needed for one tile is small enough to fit in registers, it should get promoted into registers. See apps/linear_algebra for example gemm schedules.
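A minimal sketch of what that could look like for the pipeline above, with made-up tile sizes (8x4) and new loop names (xo, yo, xi, yi); this would replace the compute_root schedule on conv:

Var xo("xo"), yo("yo"), xi("xi"), yi("yi");

// Tile the consumer so that one tile only needs a small, constant-size
// piece of conv.
relu.tile(x, y, xo, yo, xi, yi, 8, 4)
    .vectorize(xi)
    .unroll(yi);

// Compute conv per tile of relu. Its footprint there is 8x4 at compile
// time, so once it is fully vectorized/unrolled the accumulator can be
// promoted to registers instead of a scratch buffer.
conv.compute_at(relu, xo)
    .vectorize(x)
    .unroll(y);
conv.update()
    .reorder(x, y, r)   // keep the reduction loop outermost within the tile
    .vectorize(x)
    .unroll(y);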

mohamedadaly commented 6 years ago

Thanks!

But the problem is that relu doesn't know about the RDom in conv (the reduction over the columns of lhs, i.e. the rows of rhs), and I want to split over that dimension using this order

.reorder(x, y, ri, xr, yr, yc, xc, r)

basically to achieve a blocking/tiling structure like this.

Is there a trick to achieve this ordering and at the same time promote the inner loop to registers?

abadams commented 6 years ago

So you want to do some of the summation in the innermost loop, and more in the outermost loop. I think you need to factor that reduction into two stages, using Func::rfactor, so that ri and r belong to two distinct Funcs.
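A sketch of that approach on a stripped-down version of the pipeline (A, B, kc and the variable names here are illustrative, not the generator's):

// Stand-ins for lhs/rhs and the reduction chunk size from the discussion above.
ImageParam A(Float(32), 2, "A"), B(Float(32), 2, "B");
const int kc = 128;

Var x("x"), y("y"), ko("ko");
Func mm("mm");
RDom r(0, A.dim(0).extent());

mm(x, y) = 0.f;
mm(x, y) += A(r, y) * B(x, r);

// Split the reduction, then factor the inner chunk out into its own Func.
RVar rxo("rxo"), rxi("rxi");
mm.update().split(r, rxo, rxi, kc);
Func intm = mm.update().rfactor(rxo, ko);

// intm (roughly intm(x, y, ko)) now owns the inner rxi accumulation and can
// be scheduled on its own, e.g. compute_at a small tile of the consumer,
// vectorized and unrolled so its accumulator stays in registers, while mm's
// update only sums the per-chunk partials over rxo.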

mohamedadaly commented 6 years ago

Func::rfactor did the trick. Thanks a lot :)

One more question: is it possible to specialize the compute_at level of a Func? When I try to do that, I get this error:

conv.specialize(split).compute_root();

error: ‘class Halide::Stage’ has no member named ‘compute_root’

I can do it with a GeneratorParam at compile time to fork two separate paths, but it would be easier if there were a way to do it conditionally at run time.
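For reference, the compile-time fork mentioned above could look roughly like this inside the generator (the parameter name is made up):

// Declared alongside the other generator members:
GeneratorParam<bool> split_depth{"split_depth", false};

// ...and inside generate(), two separate schedules are emitted depending on
// the value chosen when the generator is invoked:
if (split_depth) {
    conv.compute_root();
} else {
    conv.compute_at(relu, xo);   // xo as in the tiled schedule sketched earlier
}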

abadams commented 6 years ago

You can, but it's a bit counterintuitive. You need to make a compute_at on a single variable that's valid on both sides of the specialization, but has a different meaning by virtue of where that variable shows up in the loop nest. E.g.

Var x, y, dummy;
consumer(x, y) = producer(x, y);
// Make a dummy var of size 1, and then conditionally reorder it to be outermost
consumer.split(x, x, dummy, 1);
consumer.specialize(c).reorder(x, y, dummy); // if c is true, dummy is the outermost loop
consumer.reorder(dummy, x, y);               // otherwise it's the innermost loop
producer.compute_at(consumer, dummy);
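Applied to this thread's pipeline, that might look roughly as follows (the condition split is the one from the earlier specialize attempt; the rest of relu's schedule is omitted):

Var dummy("dummy");
relu.split(x, x, dummy, 1);                   // size-1 loop we can move around
relu.specialize(split).reorder(x, y, dummy);  // split == true: dummy is outermost,
                                              // so conv behaves like compute_root
relu.reorder(dummy, x, y);                    // otherwise dummy stays innermost,
                                              // so conv is computed per output element
conv.compute_at(relu, dummy);                 // one compute_at serves both branches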
mohamedadaly commented 6 years ago

Is there a similar trick to make it also work for bound_extent, store_in, etc., which are not defined for Halide::Stage? Or, more generally, is there a way to have conditional branching at run time, similar to specialize, but that allows arbitrary statements?

Thanks again!