halide / Halide

a language for fast, portable data-parallel computation
https://halide-lang.org
Other
5.84k stars 1.07k forks source link

Suboptimal hoisting of loop invariants #6073

Open steven-johnson opened 3 years ago

steven-johnson commented 3 years ago

Consider the fragment of a rotate-and-crop pipeline seen here:

#include "Halide.h"
#include <stdio.h>

using namespace Halide;

// Reproduction for "suboptimal hoisting of loop invariants": builds a
// fixed-point (16.16) rotate-and-crop resampling pipeline over a 2-D uint8
// image, schedules a 128x128 output vectorized by 8 along x, realizes it once
// and discards the result. Run with HL_DEBUG_CODEGEN=1 to inspect the
// lowered Stmt and see which loop-invariant lets are (not) hoisted.
int main(int argc, char **argv) {

    // Source image and the scalar parameters of the rotate/crop transform;
    // all of them are bound to concrete values further down, before realize().
    ImageParam input(UInt(8), 2);
    Param<float> rot_center_x;
    Param<float> rot_center_y;
    Param<float> rot_radians;
    Param<float> roi_y_min;
    Param<float> roi_y_max;
    Param<float> roi_x_min;
    Param<float> roi_x_max;
    Param<int> dst_width;
    Param<int> dst_height;

    Var x("x"), y("y");

    // Boundary condition on the input so that any computed (x0, y0) sample
    // position is valid to read.
    Func clamped = Halide::BoundaryConditions::repeat_edge(input);

    Expr sin_theta = Halide::sin(rot_radians);
    Expr cos_theta = Halide::cos(rot_radians);

    // Floating-point coefficients of the affine map from output (x, y) to
    // input coordinates (see x0/y0 below). None of them depends on x or y.
    // NOTE(review): the x-row coefficients scale by input.width() while the
    // y-row coefficients scale by the ROI extents, and coef_xy_f also uses
    // input.width() -- looks asymmetric; confirm this is intended. Left as-is
    // because the IR quoted in this issue was produced from exactly this code.
    Expr coef_xx_f = input.width() * cos_theta / dst_width;
    Expr coef_xy_f = -input.width() * sin_theta / dst_height;
    Expr coef_yx_f = (roi_x_max - roi_x_min) * sin_theta / dst_width;
    Expr coef_yy_f = (roi_y_max - roi_y_min) * cos_theta / dst_height;

    // Convert the coefficients to 16.16 fixed point (scale by 2^16, round).
    Expr coef_xx = cast<int>(Halide::round(coef_xx_f * 65536));
    Expr coef_xy = cast<int>(Halide::round(coef_xy_f * 65536));
    Expr coef_yx = cast<int>(Halide::round(coef_yx_f * 65536));
    Expr coef_yy = cast<int>(Halide::round(coef_yy_f * 65536));

    // Presumably a half-pixel centering correction -- TODO confirm intent.
    Expr x_offset_f = (coef_xx_f + coef_xy_f) * 0.5f - 0.5f;
    Expr y_offset_f = (coef_yx_f + coef_yy_f) * 0.5f - 0.5f;

    // Constant (translation) terms of the affine map. Like the coefficients,
    // these are invariant in both x and y, which is why the issue expects
    // them to be hoisted out of both loops.
    Expr x_const = (roi_x_min * cos_theta -
                    roi_y_min * sin_theta -
                    rot_center_x * cos_theta +
                    rot_center_y * sin_theta +
                    rot_center_x +
                    x_offset_f);
    Expr y_const = (roi_x_min * sin_theta +
                    roi_y_min * cos_theta -
                    rot_center_x * sin_theta -
                    rot_center_y * cos_theta +
                    rot_center_y +
                    y_offset_f);

    Expr x_const_int = cast<int>(Halide::round(x_const * 65536));
    Expr y_const_int = cast<int>(Halide::round(y_const * 65536));

    // Fixed-point affine transform; dividing by 65536 drops the 16 fractional
    // bits (it appears as a shift_right by 16 in the lowered IR).
    Expr x0 = (coef_xx * x + coef_xy * y + x_const_int) / 65536;
    Expr y0 = (coef_yx * x + coef_yy * y + y_const_int) / 65536;

    Func output;
    output(x, y) = clamped(x0, y0);

    constexpr int kEdge = 128;

    // 128x128 output, vectorized by 8 along x. 128 is a multiple of 8, so the
    // RoundUp tail strategy never actually has to round the extent up.
    output
        .bound(x, 0, kEdge)
        .bound(y, 0, kEdge)
        .vectorize(x, 8, TailStrategy::RoundUp)
        .compute_root();

    // All-zero input: pixel values are irrelevant to the codegen under study.
    Buffer<uint8_t> buf(kEdge, kEdge);
    buf.fill(0);

    input.set(buf);

    // Param values arbitrary
    // (kEdge/2, kEdge/4 and 3*kEdge/4 are integer expressions: 64, 32, 96;
    // rot_radians is roughly pi/2.)
    rot_center_x.set(kEdge/2);
    rot_center_y.set(kEdge/2);
    rot_radians.set(3.1415f / 2);
    roi_y_min.set(kEdge/4);
    roi_y_max.set(3*kEdge/4);
    roi_x_min.set(kEdge/4);
    roi_x_max.set(3*kEdge/4);
    dst_width.set(kEdge);
    dst_height.set(kEdge);

    // The result is never inspected; this call exists only to compile and
    // run the pipeline (argc/argv are likewise unused).
    auto result = output.realize({kEdge, kEdge});

    return 0;
}

If you run this with the environment variable HL_DEBUG_CODEGEN=1 set, the lowered Stmt printed for `output` looks something like:

 produce output {
  let t197 = (float32)sin_f32((float32)rot_radians)
  let t194 = (float32)cos_f32((float32)rot_radians)
  let t201 = (float32)roi_y_max - (float32)roi_y_min
  let t202 = (float32)roi_x_min - (float32)rot_center_x
  let t200 = (float32)roi_x_max - (float32)roi_x_min
  let t204 = (input.min.1*input.stride.1) + input.min.0
  let t206 = input.extent.1 + input.min.1
  let t203 = input.extent.0 + input.min.0
  let t198 = float32((0 - input.extent.0))
  let t195 = float32(input.extent.0)
  let t196 = float32(dst_width)
  let t199 = float32(dst_height)
  for (output.s0.y, 0, 128) {
   let t209 = (t197*t200)/t196
   let t208 = (t197*t198)/t199
   let t210 = (t194*t201)/t199
   let t207 = (t194*t195)/t196
   let t213 = output.s0.y*output.stride.1
   let t212 = ((t197*t202) + (((float32)roi_y_min - (float32)rot_center_y)*t194)) + (float32)rot_center_y
   let t211 = (((float32)rot_center_y*t197) + ((t194*t202) - ((float32)roi_y_min*t197))) + (float32)rot_center_x
   for (output.s0.x.x, 0, 16) {
    let t148 = int32((float32)round_f32(t207*65536.000000f))
    let t153 = int32((float32)round_f32(t209*65536.000000f))
    output[ramp((output.s0.x.x*8) + t213, 1, 8)] = input[(max(min((int32x8)shift_right(ramp(((output.s0.x.x*t153)*8) + ((output.s0.y*int32((float32)round_f32(t210*65536.000000f))) + int32((float32)round_f32(((((t209 + t210)*0.500000f) + t212)*65536.000000f) + -32768.000000f))), t153, 8), x8((uint32)16)), x8(t206 + -1)), x8(input.min.1))*x8(input.stride.1)) + (max(min((int32x8)shift_right(ramp(((output.s0.x.x*t148)*8) + ((output.s0.y*int32((float32)round_f32(t208*65536.000000f))) + int32((float32)round_f32(((((t207 + t208)*0.500000f) + t211)*65536.000000f) + -32768.000000f))), t148, 8), x8((uint32)16)), x8(t203 + -1)), x8(input.min.0)) - x8(t204))]
   }
  }
 }
}

If you examine this closely, several expressions could be hoisted out of both loops but are not. For example, t209 is loop-invariant but is recomputed on every iteration of the y loop, and t153 is also loop-invariant but is recomputed on every iteration of the inner x loop.

The cause of this is not clear to me from brief experimentation -- it appears that the logic in LiftLoopInvariants::visit_lets is deliberately conservative -- but there's clearly room for improvement here. (@abadams suggests that LetStmts that don't depend on the loop variable could perhaps be lifted in a prepass, which might be an easy experiment for someone to try.)

rootjalex commented 3 years ago

t209 isn't hoisted because loop invariants are currently hoisted top-down instead of bottom-up, which is a relatively easy fix (one that hopefully shouldn't break anything). The reason t153 isn't hoisted is a bit trickier (or at least requires a more involved fix): currently, if a LetStmt or Let is hoisted, the variables that depend on it still treat that variable as varying with the current For loop.

A prepass doesn't work that well, because LICM (loop-invariant code motion) mucks things up by introducing a bunch of LetStmts in the wrong places. I have a possible fix that generates the following IR:

 produce output {
  let t197 = (float32)sin_f32((float32)rot_radians)
  let t194 = (float32)cos_f32((float32)rot_radians)
  let t201 = (float32)roi_y_max - (float32)roi_y_min
  let t202 = (float32)roi_x_min - (float32)rot_center_x
  let t200 = (float32)roi_x_max - (float32)roi_x_min
  let t204 = (input.min.1*input.stride.1) + input.min.0
  let t206 = input.extent.1 + input.min.1
  let t203 = input.extent.0 + input.min.0
  let t198 = float32((0 - input.extent.0))
  let t195 = float32(input.extent.0)
  let t196 = float32(dst_width)
  let t199 = float32(dst_height)
  let t147 = (t194*t195)/t196
  let t148 = int32((float32)round_f32(t147*65536.000000f))
  let t150 = (t197*t198)/t199
  let t152 = (t197*t200)/t196
  let t153 = int32((float32)round_f32(t152*65536.000000f))
  let t154 = (t194*t201)/t199
  for (output.s0.y, 0, 128) {
   let t209 = output.s0.y*output.stride.1
   let t208 = (output.s0.y*int32((float32)round_f32(t154*65536.000000f))) + int32((float32)round_f32(((((t152 + t154)*0.500000f) + (((t197*t202) + (((float32)roi_y_min - (float32)rot_center_y)*t194)) + (float32)rot_center_y))*65536.000000f) + -32768.000000f))
   let t207 = (output.s0.y*int32((float32)round_f32(t150*65536.000000f))) + int32((float32)round_f32(((((t147 + t150)*0.500000f) + ((((float32)rot_center_y*t197) + ((t194*t202) - ((float32)roi_y_min*t197))) + (float32)rot_center_x))*65536.000000f) + -32768.000000f))
   for (output.s0.x.x, 0, 16) {
    output[ramp((output.s0.x.x*8) + t209, 1, 8)] = input[(max(min((int32x8)shift_right(ramp(((output.s0.x.x*t153)*8) + t208, t153, 8), x8((uint32)16)), x8(t206 + -1)), x8(input.min.1))*x8(input.stride.1)) + (max(min((int32x8)shift_right(ramp(((output.s0.x.x*t148)*8) + t207, t148, 8), x8((uint32)16)), x8(t203 + -1)), x8(input.min.0)) - x8(t204))]
   }
  }
 }
}

But I am hesitant to open a PR for it because it's an ugly fix: it requires running the LetStmt lifter both before and after LICM runs, when I think LICM itself should be fixed. I would appreciate feedback on that.

rootjalex commented 3 years ago

Oh hmm, that actually seems to have removed some of the invariants that LICM should have gathered somehow...

JoelLinn commented 1 year ago

IMHO, in the following code Halide should automatically hoist the "invalid" Func to the y loop. I found a case in a much more complex pipeline where LLVM fails to hoist a similar expression. Also, the statements become very hard to read, especially when unrolling is used. Of course, with unrolling, common subexpression elimination is probably also needed. Manual hoisting with compute_at has downsides too: for example, it generates unreasonable overhead (a function call) when using the -profile target (imagine hoisting out of a dimension with an extent of only 4).

#include "Halide.h"

using namespace Halide;
using namespace Halide::ConciseCasts;

// Test case for automatic hoisting. `invalid` depends only on y, so it only
// needs to be evaluated once per output row. With manualHoist=false it is
// inlined into `output`, and the flags[y] comparison stays inside the x loop.
// With manualHoist=true it is scheduled at output's y loop and kept in
// registers.
class HoistTestGenerator : public Halide::Generator<HoistTestGenerator> {
public:
    // Output extent: numCols along x, numRows along y.
    Input<int32_t> numRows{"numRows"};
    Input<int32_t> numCols{"numCols"};

    // One flag byte per row; a row whose flag equals 64 is written as zeros.
    Input<Buffer<uint8_t, 1>> flags{"flags"};
    // One multiplier per column, shared by every row.
    Input<Buffer<uint16_t, 1>> input1{"input1"};
    // Dense row-major matrix, multiplied element-wise by input1 along x.
    Input<Buffer<uint16_t, 2>> input2{"input2"};

    Output<Buffer<uint16_t, 2>> output{"output"};

    // When true, hoist `invalid` out of the x loop by hand (see schedule()).
    GeneratorParam<bool> manualHoist{"manualHoist", false};

    // Per-row predicate: flags(y) == 64.
    Func invalid{"invalid"};

    Var x{"x"};
    Var y{"y"};

    void generate() {
        const Expr row_flag_hit = flags(y) == 64;
        invalid(y) = row_flag_hit;

        // Zero out invalid rows; elsewhere output the element-wise product.
        const Expr row_is_invalid = invalid(y);
        const Expr scaled = input1(x) * input2(x, y);
        output(x, y) = select(row_is_invalid, 0, scaled);
    }

    void schedule() {
        const Expr rows = numRows;
        const Expr cols = numCols;

        // Declare zero-based, dense layouts (unit stride along x, rows of
        // `cols` elements) so addressing simplifies to numCols*y + x.
        flags.dim(0).set_bounds(0, rows).set_stride(1);
        input1.dim(0).set_bounds(0, cols).set_stride(1);
        input2.dim(0).set_bounds(0, cols).set_stride(1);
        input2.dim(1).set_bounds(0, rows).set_stride(cols);
        output.dim(0).set_bounds(0, cols).set_stride(1);
        output.dim(1).set_bounds(0, rows).set_stride(cols);

        if (manualHoist) {
            // Evaluate the row predicate once per y iteration, in registers.
            invalid.compute_at(output, y).store_in(MemoryType::Register);
        }
    }
};

HALIDE_REGISTER_GENERATOR(HoistTestGenerator, hoist_test_generator)

manualHoist=false

produce output {
 for (output.s0.y, 0, numRows) {
  for (output.s0.x, 0, numCols) {
   let t5 = (numCols*output.s0.y) + output.s0.x
   output[t5] = select(flags[output.s0.y] == (uint8)64, (uint16)0, input1[output.s0.x]*input2[t5])
  }
 }
}

manualHoist=true

produce output {
 for (output.s0.y, 0, numRows) {
  allocate invalid[uint8 * 1] in Register
  produce invalid {
   invalid[0] = uint8((flags[output.s0.y] == (uint8)64))
  }
  consume invalid {
   let t7 = numCols*output.s0.y
   for (output.s0.x, 0, numCols) {
    let t6 = output.s0.x + t7
    output[t6] = select(uint1(invalid[0]), (uint16)0, input1[output.s0.x]*input2[t6])
   }
  }
  free invalid
 }
}