Open abadams opened 3 months ago
Any plans for this? It keeps hitting me. I'm trying to load the weights of one particular filter into GPUShared memory, but bounds inference does not cross the filter_index call. Minimal repro to work with:
```cpp
#include <Halide.h>

using namespace Halide;

int main() {
    Target t = get_target_from_environment();

    Buffer<float> filter_bank(5, 5, 1024, "filter_bank");
    Buffer<float> input(100, 100, "input");

    Var x{"x"}, y{"y"}, fi{"fi"};
    Func filtered{"filtered"};
    RDom filter_rdom{{{0, 5}, {0, 5}}, "filter_rdom"};
    Func reduction{"reduction"};
    Func filter_index{"filter_index"};
    Func filter_bank_wrapper{"filter_bank_wrapper"};

    filter_bank_wrapper(x, y, fi) = filter_bank(x, y, fi);
    filter_index(x, y) = (x + y) % 1024;
    filtered(x, y) = sum(input(x + filter_rdom.x, y + filter_rdom.y) *
                             filter_bank_wrapper(filter_rdom.x, filter_rdom.y,
                                                 filter_index(x, y)),
                         reduction);

    filter_index.compute_at(filtered, x);
    filter_bank_wrapper.in(reduction)
        .compute_at(reduction, y)
        .vectorize(x);

    Buffer<float> out = filtered.realize(
        {50, 50},
        t.with_feature(Target::Feature::NoBoundsQuery)
            .with_feature(Target::Feature::NoRuntime));
    return 0;
}
```
Produces:
```
produce filtered {
 let t17 = 0 - (filtered.min.1*filtered.stride.1)
 for (filtered.s0.y.rebased, 0, filtered.extent.1) {
  let t21 = filtered.min.1 + filtered.s0.y.rebased
  let t20 = (filtered.stride.1*t21) + t17
  let t18 = (filtered.min.0 + filtered.min.1) + filtered.s0.y.rebased
  for (filtered.s0.x.rebased, 0, filtered.extent.0) {
   allocate filter_index[int32 * 1]
   produce filter_index {
    filter_index[0] = (filtered.s0.x.rebased + t18) % 1024
   }
   allocate reduction[float32 * 1]
   produce reduction {
    reduction[0] = 0.000000f
    allocate filter_bank_wrapper_in_reduction$0[float32 * 5 * 5 * 1024]
    produce filter_bank_wrapper_in_reduction$0 {
     for (filter_bank_wrapper_in_reduction$0.s0.fi, 0, 1024) {
      let t22 = filter_bank_wrapper_in_reduction$0.s0.fi*5
      for (filter_bank_wrapper_in_reduction$0.s0.y, 0, 5) {
       let t15 = filter_bank_wrapper_in_reduction$0.s0.y + t22
       filter_bank_wrapper_in_reduction$0[ramp(t15*5, 1, 5) aligned(5, 0)] = filter_bank[ramp(t15*5, 1, 5) aligned(5, 0)]
      }
     }
    }
    consume filter_bank_wrapper_in_reduction$0 {
     consume filter_index {
      let t23 = filtered.min.0 + filtered.s0.x.rebased
      for (reduction.s1.filter_rdom$y, 0, 5) {
       let t25 = reduction.s1.filter_rdom$y*5
       let t24 = ((reduction.s1.filter_rdom$y + t21)*100) + t23
       for (reduction.s1.filter_rdom$x, 0, 5) {
        reduction[0] = reduction[0] + (input[reduction.s1.filter_rdom$x + t24]*filter_bank_wrapper_in_reduction$0[((max(min(filter_index[0], 1023), 0)*25) + t25) + reduction.s1.filter_rdom$x])
       }
      }
     }
    }
    free filter_index
    free filter_bank_wrapper_in_reduction$0
   }
   consume reduction {
    filtered[filtered.s0.x.rebased + t20] = reduction[0]
   }
   free reduction
  }
 }
}
```
So `filter_bank_wrapper` gets bounds corresponding to the full buffer, instead of to the one filter determined by `filter_index`. Specifically:
```
allocate filter_bank_wrapper_in_reduction$0[float32 * 5 * 5 * 1024]
produce filter_bank_wrapper_in_reduction$0 {
 for (filter_bank_wrapper_in_reduction$0.s0.fi, 0, 1024) {
```
@abadams Perhaps a scheduling directive? Something like `transparent_bounds()` / `propagate_value_bounds()` / `exact_bounds()` / `bounds_passthrough()`, so that the number of bounds expressions doesn't explode by default?
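To make the proposal concrete, here is a rough sketch of how such a directive might read in the repro above. None of these directives exist in Halide today; `transparent_bounds()` is just a placeholder for whichever candidate name wins:

```cpp
// Hypothetical schedule: opt this one Func in to having its value bounds
// propagated through call sites, accepting larger bounds expressions for it.
filter_index.compute_at(filtered, x)
            .transparent_bounds();  // placeholder name, not a real directive
```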
Currently, bounds inference behaves differently depending on whether or not a Func is inlined. Some things are only legal if inlined (e.g. #8261).
The cause is that the bounds on the value of a Func are precomputed, and do not depend on the LHS vars of that Func. So while a particular use of a Func may have bounds, because bounded args are passed to it, bounds inference fails to recognize that. We could fix this by using a symbolic interval here: https://github.com/halide/Halide/blob/main/src/Bounds.cpp#L3359 instead of an infinite one, and then binding those symbols with Lets wherever we use FuncValueBounds.
The concern is that this might create O(n^2) worth of bounds expressions. Cutting FuncValueBounds off at Func calls keeps the bounds expressions bounded in size. But it's causing problems, so maybe there's something clever we can do.