halide / Halide

a language for fast, portable data-parallel computation
https://halide-lang.org

We should revisit FuncValueBounds depending on the LHS of the Func #8272

Open abadams opened 3 months ago

abadams commented 3 months ago

Currently, bounds inference behaves differently depending on whether a Func is inlined. Some things are only legal if the Func is inlined (e.g. #8261).

The cause is that the bounds on the value of a Func are precomputed and do not depend on the LHS vars of that Func. So even when a particular use of a Func has tighter value bounds, because the arguments passed to it are bounded, bounds inference fails to take advantage of that. We could fix this by using a symbolic interval here: https://github.com/halide/Halide/blob/main/src/Bounds.cpp#L3359 instead of an infinite one, and then binding those symbols with Lets wherever we use FuncValueBounds.

The concern is that this might create O(n^2) worth of bounds expressions. Cutting FuncValueBounds off at Func calls keeps the bounds expressions bounded in size. But it's causing problems, so maybe there's something clever we can do.
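
One way to read the O(n^2) estimate, under a simplifying assumption that is not taken from Bounds.cpp: consider a chain of n Funcs where f_i calls f_{i-1}, and suppose each symbolic value-bounds expression B_i embeds (rather than shares) its callee's expression plus a constant amount c of its own.

```latex
|B_1| = c, \qquad |B_i| = |B_{i-1}| + c \;\Rightarrow\; |B_i| = c\,i
\sum_{i=1}^{n} |B_i| = c \sum_{i=1}^{n} i = c\,\frac{n(n+1)}{2} = O(n^2)
\text{cut off at Func calls: } |B_i| = c \;\Rightarrow\; \sum_{i=1}^{n} |B_i| = c\,n = O(n)
```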

mcourteaux commented 1 month ago

Any plans for this? It keeps hitting me. I'm trying to load the weights of a particular filter into GPUShared memory, but bounds inference does not see through filter_index. A minimal repro to work with:

#include <Halide.h>

using namespace Halide;

int main() {
    Target t = get_target_from_environment();
    Buffer<float> filter_bank(5, 5, 1024, "filter_bank");
    Buffer<float> input(100, 100, "input");

    Var x{"x"}, y{"y"}, fi{"fi"};
    Func filtered{"filtered"};
    RDom filter_rdom{{{0, 5}, {0, 5}}, "filter_rdom"};
    Func reduction{"reduction"};
    Func filter_index{"filter_index"};
    Func filter_bank_wrapper{"filter_bank_wrapper"};
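    // Wrapper around the filter bank so it can be staged via .in() below.
    filter_bank_wrapper(x, y, fi) = filter_bank(x, y, fi);
    // Each output pixel selects one of the 1024 filters.
    filter_index(x, y) = (x + y) % 1024;
    // 5x5 convolution with the selected filter; 'reduction' is the Func holding the inline sum.
    filtered(x, y) = sum(input(x + filter_rdom.x, y + filter_rdom.y) * filter_bank_wrapper(filter_rdom.x, filter_rdom.y, filter_index(x, y)), reduction);

    // Compute the selected filter index once per output pixel.
    filter_index.compute_at(filtered, x);
    // Stage the weights read by the reduction (ideally only the selected filter).
    filter_bank_wrapper.in(reduction)
        .compute_at(reduction, y)
        .vectorize(x);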

    Buffer<float> out = filtered.realize({50, 50}, t.with_feature(Target::Feature::NoBoundsQuery).with_feature(Target::Feature::NoRuntime));

    return 0;
}

Produces:

produce filtered {
 let t17 = 0 - (filtered.min.1*filtered.stride.1)
 for (filtered.s0.y.rebased, 0, filtered.extent.1) {
  let t21 = filtered.min.1 + filtered.s0.y.rebased
  let t20 = (filtered.stride.1*t21) + t17
  let t18 = (filtered.min.0 + filtered.min.1) + filtered.s0.y.rebased
  for (filtered.s0.x.rebased, 0, filtered.extent.0) {
   allocate filter_index[int32 * 1]
   produce filter_index {
    filter_index[0] = (filtered.s0.x.rebased + t18) % 1024
   }
   allocate reduction[float32 * 1]
   produce reduction {
    reduction[0] = 0.000000f
    allocate filter_bank_wrapper_in_reduction$0[float32 * 5 * 5 * 1024]
    produce filter_bank_wrapper_in_reduction$0 {
     for (filter_bank_wrapper_in_reduction$0.s0.fi, 0, 1024) {
      let t22 = filter_bank_wrapper_in_reduction$0.s0.fi*5
      for (filter_bank_wrapper_in_reduction$0.s0.y, 0, 5) {
       let t15 = filter_bank_wrapper_in_reduction$0.s0.y + t22
       filter_bank_wrapper_in_reduction$0[ramp(t15*5, 1, 5) aligned(5, 0)] = filter_bank[ramp(t15*5, 1, 5) aligned(5, 0)]
      }
     }
    }
    consume filter_bank_wrapper_in_reduction$0 {
     consume filter_index {
      let t23 = filtered.min.0 + filtered.s0.x.rebased
      for (reduction.s1.filter_rdom$y, 0, 5) {
       let t25 = reduction.s1.filter_rdom$y*5
       let t24 = ((reduction.s1.filter_rdom$y + t21)*100) + t23
       for (reduction.s1.filter_rdom$x, 0, 5) {
        reduction[0] = reduction[0] + (input[reduction.s1.filter_rdom$x + t24]*filter_bank_wrapper_in_reduction$0[((max(min(filter_index[0], 1023), 0)*25) + t25) + reduction.s1.filter_rdom$x])
       }
      }
     }
    }
    free filter_index
    free filter_bank_wrapper_in_reduction$0
   }
   consume reduction {
    filtered[filtered.s0.x.rebased + t20] = reduction[0]
   }
   free reduction
  }
 }
}

So filter_bank_wrapper (staged here as filter_bank_wrapper_in_reduction$0) gets bounds covering the full buffer, i.e. all 1024 filters, instead of only the one filter selected by filter_index. Specifically:

    allocate filter_bank_wrapper_in_reduction$0[float32 * 5 * 5 * 1024]
    produce filter_bank_wrapper_in_reduction$0 {
     for (filter_bank_wrapper_in_reduction$0.s0.fi, 0, 1024) {
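
To put numbers on the excerpt above (float32 is 4 bytes): the allocation covers the whole bank, whereas the one filter actually read needs only a 5x5 slice. In the Stmt, this copy also sits inside the filtered.s0.x.rebased loop, so it is repeated for every output pixel, and 100 KiB is likely far more than a GPU block's shared memory budget.

```latex
\text{allocated: } 5 \times 5 \times 1024 \times 4\ \text{bytes} = 102{,}400\ \text{bytes} = 100\ \text{KiB}
\text{needed: } 5 \times 5 \times 1 \times 4\ \text{bytes} = 100\ \text{bytes}
```
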

mcourteaux commented 1 month ago

> The concern is that this might create O(n^2) worth of bounds expressions. Cutting FuncValueBounds off at Func calls keeps the bounds expressions bounded in size. But it's causing problems, so maybe there's something clever we can do.

@abadams Perhaps a scheduling directive? Something like transparent_bounds()/propagate_value_bounds()/exact_bounds()/bounds_passthrough(), so that the number of bounds expressions doesn't explode by default?