halide / Halide

a language for fast, portable data-parallel computation
https://halide-lang.org

Bounds inference doesn't take `select()` condition into account. #8301

Open mcourteaux opened 1 week ago

mcourteaux commented 1 week ago
#include "Halide.h"
#include <stdio.h>
using namespace Halide;

int main(int argc, char **argv) {

    Buffer<float, 3> bufA({10, 10, 25}, "bufA");
    Buffer<float, 3> bufB({10, 10, 25}, "bufB");
    Buffer<float, 2> bufC({10, 10}, "bufC");
    Var x("x"), y("y"), i("i");

    Func concat("concat");
    concat(x, y, i) = select(
        i < 25, bufA(x, y, i),
        i < 50, bufB(x, y, i - 25),
                bufC(x, y));

    concat.realize({10, 10, 51});

    printf("Success!\n");
    return 0;
}

fails with:

Error: Input buffer bufA is accessed at 50, which is beyond the max (24) in dimension 2

This leads me to conclude that bounds inference doesn't use the conditions in the select().
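The error can be read through a toy model. This is a plain-C++ sketch, not Halide's implementation. Bounds inference works out which values an index expression can take over the realized region and checks that range against the buffer's extent. Here `i` spans [0, 50] because of `realize({10, 10, 51})`. Because the `i < 25` condition plays no part (the reply below explains why), the `bufA(x, y, i)` access is treated as covering that whole range. The names `Interval`, `access_range` and `fits` are made up for this sketch.

```cpp
#include <cstdio>

// Toy model (hypothetical, not Halide's code): a closed integer interval [min, max].
struct Interval {
    int min;
    int max;
};

// Range touched by an index expression of the form i + offset, given the range of i.
Interval access_range(Interval i, int offset) {
    return {i.min + offset, i.max + offset};
}

// True if the touched range lies within a buffer dimension of the given extent
// (valid indices 0 .. extent - 1); otherwise prints a message modeled on Halide's.
bool fits(const char *name, Interval r, int extent) {
    if (r.max > extent - 1) {
        std::printf("%s is accessed at %d, which is beyond the max (%d)\n",
                    name, r.max, extent - 1);
        return false;
    }
    if (r.min < 0) {
        std::printf("%s is accessed at %d, which is before the min (0)\n", name, r.min);
        return false;
    }
    return true;
}
```

For example, `fits("bufA", access_range({0, 50}, 0), 25)` prints the bufA message and returns false. `access_range({0, 50}, -25)` gives [-25, 25], which matches the bufB error further down.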

Similarly, if you silence this particular error by simply making the buffers bigger:

#include "Halide.h"
#include <stdio.h>
using namespace Halide;

int main(int argc, char **argv) {

    Buffer<float, 3> bufA({10, 10, 51}, "bufA"); // < change here!
    Buffer<float, 3> bufB({10, 10, 51}, "bufB"); // < change here!
    Buffer<float, 2> bufC({10, 10}, "bufC");
    Var x("x"), y("y"), i("i");

    Func concat("concat");
    concat(x, y, i) = select(
        i < 25, bufA(x, y, i),
        i < 50, bufB(x, y, i - 25),
                bufC(x, y));

    concat.realize({10, 10, 51});

    printf("Success!\n");
    return 0;
}

It now errors with:

Error: Input buffer bufB is accessed at -25, which is before the min (0) in dimension 2

This shows that it also doesn't use the negation of the earlier conditions in the select.
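For contrast, here is a sketch of the condition-aware reasoning the report expected. It is a toy plain-C++ model of that expectation, not what Halide's bounds inference does. The range of `i` is narrowed by each branch's condition, including the negation of the earlier ones, before the access range is computed. The first branch sees `i < 25`, so `i` is in [0, 24]. The second sees `!(i < 25) && i < 50`, so `i` is in [25, 49] and `i - 25` stays in [0, 24]. All names are hypothetical.

```cpp
#include <algorithm>

// Toy closed integer interval (hypothetical, for illustration only).
struct Interval {
    int min;
    int max;
};

// Narrow a range with another one (an empty result has min > max).
Interval intersect(Interval a, Interval b) {
    return {std::max(a.min, b.min), std::min(a.max, b.max)};
}

// Range of i + offset.
Interval shift(Interval a, int offset) { return {a.min + offset, a.max + offset}; }

// Condition-aware ranges for the three select branches of the reproducer.
struct BranchRanges {
    Interval bufA;  // index i       under i < 25
    Interval bufB;  // index i - 25  under 25 <= i < 50
    Interval bufC;  // range of i    under i >= 50 (bufC is not indexed by i)
};

BranchRanges condition_aware(Interval i) {
    Interval a = intersect(i, {i.min, 24});  // i < 25
    Interval b = intersect(i, {25, 49});     // !(i < 25) && i < 50
    Interval c = intersect(i, {50, i.max});  // neither earlier condition holds
    return {a, shift(b, -25), c};
}
```

With `i` in [0, 50], both the `bufA` and `bufB` accesses would land in [0, 24] under this reasoning, which is why the reproducer looked valid to its author.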

abadams commented 1 week ago

A select always evaluates both sides, so bounds inference is actually correct here: every load in every branch has to be in bounds. This does indeed make concatenation awkward. You have to clamp the indices instead:

    concat(x, y, i) = select(
        i < 25, bufA(x, y, min(i, 24)),
        i < 50, bufB(x, y, clamp(i - 25, 0, 24)),
                bufC(x, y));
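
A scalar plain-C++ model helps explain why the clamps work. It is hypothetical and not Halide code. A blend computes all three loads for every `i` and only then picks one, so each load must be legal even when its result is discarded. Clamping the index keeps every load in range without changing the value that is finally chosen. For brevity the sketch collapses the `x`/`y` dimensions to one element per slice. `concat_blend` is a made-up name, and `std::vector::at` stands in for a bounds-checked load.

```cpp
#include <algorithm>
#include <vector>

// Toy model of the workaround: bufA and bufB hold 25 slices each, bufC is a single
// value standing in for the 2-D bufC, and the output has 51 slices.
std::vector<float> concat_blend(const std::vector<float> &bufA,
                                const std::vector<float> &bufB,
                                float bufC) {
    std::vector<float> out(51);
    for (int i = 0; i < 51; i++) {
        // Like a select, evaluate every operand unconditionally. The clamped
        // indices keep each .at() in bounds for all i in [0, 50].
        float a = bufA.at(std::min(i, 24));
        float b = bufB.at(std::min(std::max(i - 25, 0), 24));
        float c = bufC;
        // Then blend: pick one of the already-computed values.
        out[i] = (i < 25) ? a : (i < 50) ? b : c;
    }
    return out;
}
```

Without the clamps, `bufA.at(i)` would throw `std::out_of_range` at `i == 25` even though that value is never selected. That mirrors the first error in the report.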

mcourteaux commented 1 week ago

Ah, I see. Would it be feasible or sensible to add a variant of select() that behaves more like an if-then-else than a blend? That would better match what I expected Halide to do. I'm not sure whether it would be compatible with vectorization. I imagine it could be, if it were always treated like guard-with-if and then later optimized away when the vector size aligns with the extent, or when the loop gets unrolled.
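The proposed if-then-else semantics can be contrasted with the blend using a scalar model. This is plain C++ that illustrates the idea, not a Halide API. Each element evaluates only the taken branch, so the original unclamped indices from the reproducer never go out of range. `concat_branch` is a made-up name. A scalar loop like this avoids the vectorization question raised above: once several values of `i` share one vector, different lanes can need different branches.

```cpp
#include <vector>

// Toy model of an if-then-else (hypothetical): only the chosen branch is evaluated,
// so the unclamped indices i and i - 25 are always valid when they are used.
std::vector<float> concat_branch(const std::vector<float> &bufA,
                                 const std::vector<float> &bufB,
                                 float bufC) {
    std::vector<float> out(51);
    for (int i = 0; i < 51; i++) {
        if (i < 25) {
            out[i] = bufA.at(i);
        } else if (i < 50) {
            out[i] = bufB.at(i - 25);
        } else {
            out[i] = bufC;
        }
    }
    return out;
}
```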

abadams commented 1 week ago

We could, yes. It would interfere with vectorization, as you say, but not too badly if we make sure it only generates predicates on the loads. Bounds inference wouldn't handle it very well, so you'd need unsafe_promise_clamped in places. I recently tried this for constant_exterior, but it didn't actually make it faster.
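"Predicates on the loads" can be modeled in plain C++ instead of real SIMD. This is a hypothetical sketch, not Halide's code generation. A vector of lanes computes a mask from the condition, and each lane loads only where its mask bit is set, so inactive lanes never touch memory. The function name `predicated_load`, the 8-lane width and the fallback value are assumptions chosen for illustration.

```cpp
#include <array>
#include <cstddef>
#include <vector>

constexpr std::size_t kLanes = 8;  // assumed vector width for this sketch

// Load buf[base + k] for each lane k whose mask bit is set. Inactive lanes skip
// the memory access entirely and receive `fallback` instead.
std::array<float, kLanes> predicated_load(const std::vector<float> &buf, int base,
                                          const std::array<bool, kLanes> &mask,
                                          float fallback) {
    std::array<float, kLanes> out{};
    for (std::size_t k = 0; k < kLanes; k++) {
        int idx = base + static_cast<int>(k);
        out[k] = mask[k] ? buf.at(static_cast<std::size_t>(idx)) : fallback;
    }
    return out;
}
```

For a vector covering `i` = 20..27 with the mask `i < 25`, lanes 0 to 4 read `bufA[20..24]` and lanes 5 to 7 skip the out-of-range reads of `bufA[25..27]`.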

This actually already exists, but it's not exposed to the front-end. The way to try it out is to replace select(a, b, c) with Internal::Call::make(b.type(), Internal::Call::if_then_else, {a, b, c}, Internal::Call::Intrinsic);

mcourteaux commented 1 week ago

So yeah, the lesson here is to RTFM. 😛 I wonder how beneficial this if-then-else alternative could be, since the simplifier could then also insert the scoped_truths into each of the branches...