Closed yyuting closed 1 week ago
@abadams Could this simplifier rule be to blame?
In the sense that the assumption added by specializing is simplified to the form 1 < min(a,b), which keeps the rest of the pattern matcher from finding the simplification that is actually needed? Potentially because CSE has already hoisted the condition out of the loop?
@yyuting Can you provide us with the output of the generator when the environment variable HL_DEBUG_CODEGEN=1 is set?
@yyuting Does .specialize(upsample && !downsample) give any different results (as opposed to specializing in two steps)?
Hi @mcourteaux, specializing with && cannot remove the select condition either.
Generator output with HL_DEBUG_CODEGEN=1:
(base) yutingyang@Yutings-MacBook-Pro-2 halide_specialize_bug % ./specialize_bug_generator -g specialize_bug_generator -e stmt -o . target=host
Generator specialize_bug_generator has base_path ./specialize_bug_generator
compile_multitarget: single target is arm-64-osx-arm_dot_prod-arm_fp16
Applying autoscheduler (NONE) to Generator specialize_bug_generator ...
Failed to prove, but could not find a counter-example:
(((float32)v0 <= 1.000000f) || (1.000000f < max((float32)v0, (float32)v1)))
Original expression:
(!!(1.000000f < max((float32)scale_factor_x, (float32)scale_factor_y)) || !((float32)scale_factor_x > 1.000000f))
Failed to prove, but could not find a counter-example:
(((float32)v0 <= 1.000000f) || (1.000000f < max((float32)v1, (float32)v0)))
Original expression:
(!!(1.000000f < max((float32)scale_factor_x, (float32)scale_factor_y)) || !((float32)scale_factor_y > 1.000000f))
Failed to prove, but could not find a counter-example:
(((float32)v0 <= 1.000000f) || (1.000000f < max((float32)v0, (float32)v1)))
Original expression:
(!(max((float32)scale_factor_x, (float32)scale_factor_y) <= 1.000000f) || !((float32)scale_factor_x > 1.000000f))
Failed to prove, but could not find a counter-example:
(((float32)v0 <= 1.000000f) || (1.000000f < max((float32)v1, (float32)v0)))
Original expression:
(!(max((float32)scale_factor_x, (float32)scale_factor_y) <= 1.000000f) || !((float32)scale_factor_y > 1.000000f))
Failed to prove, but could not find a counter-example:
((min((float32)v0, (float32)v1) <= 1.000000f) || (1.000000f < (float32)v0))
Original expression:
(!!(min((float32)scale_factor_x, (float32)scale_factor_y) <= 1.000000f) || ((float32)scale_factor_x > 1.000000f))
Failed to prove, but could not find a counter-example:
((min((float32)v0, (float32)v1) <= 1.000000f) || (1.000000f < (float32)v1))
Original expression:
(!!(min((float32)scale_factor_x, (float32)scale_factor_y) <= 1.000000f) || ((float32)scale_factor_y > 1.000000f))
Failed to prove, but could not find a counter-example:
(((float32)v0 <= 1.000000f) || (1.000000f < max((float32)v0, (float32)v1)))
Original expression:
(!!(1.000000f < max((float32)scale_factor_x, (float32)scale_factor_y)) || !((float32)scale_factor_x > 1.000000f))
Failed to prove, but could not find a counter-example:
(((float32)v0 <= 1.000000f) || (1.000000f < max((float32)v1, (float32)v0)))
Original expression:
(!!(1.000000f < max((float32)scale_factor_x, (float32)scale_factor_y)) || !((float32)scale_factor_y > 1.000000f))
Failed to prove, but could not find a counter-example:
(((float32)v0 <= 1.000000f) || (1.000000f < max((float32)v0, (float32)v1)))
Original expression:
(!(max((float32)scale_factor_x, (float32)scale_factor_y) <= 1.000000f) || !((float32)scale_factor_x > 1.000000f))
Failed to prove, but could not find a counter-example:
(((float32)v0 <= 1.000000f) || (1.000000f < max((float32)v1, (float32)v0)))
Original expression:
(!(max((float32)scale_factor_x, (float32)scale_factor_y) <= 1.000000f) || !((float32)scale_factor_y > 1.000000f))
Failed to prove, but could not find a counter-example:
((min((float32)v0, (float32)v1) <= 1.000000f) || (1.000000f < (float32)v0))
Original expression:
(!(1.000000f < min((float32)scale_factor_x, (float32)scale_factor_y)) || (1.000000f < (float32)scale_factor_x))
Failed to prove, but could not find a counter-example:
((min((float32)v0, (float32)v1) <= 1.000000f) || (1.000000f < (float32)v1))
Original expression:
(!(1.000000f < min((float32)scale_factor_x, (float32)scale_factor_y)) || (1.000000f < (float32)scale_factor_y))
Failed to prove, but could not find a counter-example:
((min((float32)v0, (float32)v1) <= 1.000000f) || (1.000000f < (float32)v0))
Original expression:
(!(1.000000f < min((float32)scale_factor_x, (float32)scale_factor_y)) || (1.000000f < (float32)scale_factor_x))
Failed to prove, but could not find a counter-example:
((min((float32)v0, (float32)v1) <= 1.000000f) || (1.000000f < (float32)v1))
Original expression:
(!(1.000000f < min((float32)scale_factor_x, (float32)scale_factor_y)) || (1.000000f < (float32)scale_factor_y))
Failed to prove, but could not find a counter-example:
((min((float32)v0, (float32)v1) <= 1.000000f) || (1.000000f < (float32)v0))
Original expression:
(!(1.000000f < min((float32)scale_factor_x, (float32)scale_factor_y)) || (1.000000f < (float32)scale_factor_x))
Failed to prove, but could not find a counter-example:
((min((float32)v0, (float32)v1) <= 1.000000f) || (1.000000f < (float32)v1))
Original expression:
(!(1.000000f < min((float32)scale_factor_x, (float32)scale_factor_y)) || (1.000000f < (float32)scale_factor_y))
Creating initial loop nests...
Injecting realization of { output }
Inlining input_im
Skipping injecting memoization...
Injecting tracing...
Adding checks for parameters
Computing bounds of each function's value
Clamping unsafe data-dependent accesses
Performing computation bounds inference...
Asserting that all split factors are positive...
Removing extern loops...
Performing sliding window optimization...
Uniquifying variable names...
Simplifying...
Simplifying correlated differences...
Performing allocation bounds inference...
Adding checks for images
Removing code that depends on undef values...
Performing storage folding optimization...
Injecting debug_to_file calls...
Injecting prefetches...
Discarding safe promises...
Dynamically skipping stages...
Forking asynchronous producers...
Destructuring tuple-valued realizations...
Bounding small realizations...
Performing storage flattening...
Adding atomic mutex allocation...
Unpacking buffer arguments...
Skipping rewriting memoized allocations...
Simplifying...
Reduce prefetch dimension...
Simplifying correlated differences...
Bounding constant extent loops...
Unrolling...
Vectorizing...
Detecting vector interleavings...
Partitioning loops to simplify boundary conditions...
Staging strided loads...
Trimming loops to the region over which they do something...
Rebasing loops to zero...
Hoisting loop invariant if statements...
Injecting early frees...
Simplifying correlated differences...
Bounding small allocations...
Simplifying...
Lowering unsafe promises...
Flattening nested ramps...
Removing dead allocations and moving loop invariant code...
Finding intrinsics...
Hoisting prefetches...
Lowering after final simplification:
assert(reinterpret<uint64>((struct halide_buffer_t *)output.buffer) != (uint64)0, halide_error_buffer_argument_is_null("output"))
assert(reinterpret<uint64>((struct halide_buffer_t *)input.buffer) != (uint64)0, halide_error_buffer_argument_is_null("input"))
let input = (void *)_halide_buffer_get_host((struct halide_buffer_t *)input.buffer)
let input.type = (uint32)_halide_buffer_get_type((struct halide_buffer_t *)input.buffer)
let input.device_dirty = (uint1)_halide_buffer_get_device_dirty((struct halide_buffer_t *)input.buffer)
let input.dimensions = _halide_buffer_get_dimensions((struct halide_buffer_t *)input.buffer)
let input.min.0 = _halide_buffer_get_min((struct halide_buffer_t *)input.buffer, 0)
let input.extent.0 = _halide_buffer_get_extent((struct halide_buffer_t *)input.buffer, 0)
let input.stride.0 = _halide_buffer_get_stride((struct halide_buffer_t *)input.buffer, 0)
let input.min.1 = _halide_buffer_get_min((struct halide_buffer_t *)input.buffer, 1)
let input.extent.1 = _halide_buffer_get_extent((struct halide_buffer_t *)input.buffer, 1)
let input.stride.1 = _halide_buffer_get_stride((struct halide_buffer_t *)input.buffer, 1)
let output = (void *)_halide_buffer_get_host((struct halide_buffer_t *)output.buffer)
let output.type = (uint32)_halide_buffer_get_type((struct halide_buffer_t *)output.buffer)
let output.device_dirty = (uint1)_halide_buffer_get_device_dirty((struct halide_buffer_t *)output.buffer)
let output.dimensions = _halide_buffer_get_dimensions((struct halide_buffer_t *)output.buffer)
let output.min.0 = _halide_buffer_get_min((struct halide_buffer_t *)output.buffer, 0)
let output.extent.0 = _halide_buffer_get_extent((struct halide_buffer_t *)output.buffer, 0)
let output.stride.0 = _halide_buffer_get_stride((struct halide_buffer_t *)output.buffer, 0)
let output.min.1 = _halide_buffer_get_min((struct halide_buffer_t *)output.buffer, 1)
let output.extent.1 = _halide_buffer_get_extent((struct halide_buffer_t *)output.buffer, 1)
let output.stride.1 = _halide_buffer_get_stride((struct halide_buffer_t *)output.buffer, 1)
let input.extent.0.required.s = let t102 = (output.extent.0 + output.min.0) in (max((t102 + -1)/2, (t102*2) + -2) - min(output.min.0/2, output.min.0*2))
let input.min.0.required = min(output.min.0/2, output.min.0*2)
let input.extent.1.required.s = let t103 = (output.extent.1 + output.min.1) in (max((t103 + -1)/2, (t103*2) + -2) - min(output.min.1/2, output.min.1*2))
let input.min.1.required = min(output.min.1/2, output.min.1*2)
if ((uint1)_halide_buffer_is_bounds_query((struct halide_buffer_t *)input.buffer)) {
(struct halide_buffer_t *)_halide_buffer_init((struct halide_buffer_t *)input.buffer, (struct halide_dimension_t *)_halide_buffer_get_shape((struct halide_buffer_t *)input.buffer), reinterpret<(void *)>((uint64)0), (uint64)0, reinterpret<(struct halide_device_interface_t *)>((uint64)0), 2, 32, 2, (struct halide_dimension_t *)make_struct(input.min.0.required, input.extent.0.required.s + 1, 1, 0, input.min.1.required, input.extent.1.required.s + 1, input.extent.0.required.s + 1, 0), (uint64)0)
}
if ((uint1)_halide_buffer_is_bounds_query((struct halide_buffer_t *)output.buffer)) {
(struct halide_buffer_t *)_halide_buffer_init((struct halide_buffer_t *)output.buffer, (struct halide_dimension_t *)_halide_buffer_get_shape((struct halide_buffer_t *)output.buffer), reinterpret<(void *)>((uint64)0), (uint64)0, reinterpret<(struct halide_device_interface_t *)>((uint64)0), 2, 32, 2, (struct halide_dimension_t *)make_struct(output.min.0, output.extent.0, 1, 0, output.min.1, output.extent.1, output.extent.0, 0), (uint64)0)
}
if (!((uint1)_halide_buffer_is_bounds_query((struct halide_buffer_t *)input.buffer) || (uint1)_halide_buffer_is_bounds_query((struct halide_buffer_t *)output.buffer))) {
assert(input.type == (uint32)73730, halide_error_bad_type("Input buffer input", input.type, (uint32)73730))
assert(input.dimensions == 2, halide_error_bad_dimensions("Input buffer input", input.dimensions, 2))
assert(output.type == (uint32)73730, halide_error_bad_type("Output buffer output", output.type, (uint32)73730))
assert(output.dimensions == 2, halide_error_bad_dimensions("Output buffer output", output.dimensions, 2))
assert((input.min.0 <= input.min.0.required) && (((input.extent.0.required.s + input.min.0.required) + 1) <= (input.extent.0 + input.min.0)), halide_error_access_out_of_bounds("Input buffer input", 0, input.min.0.required, input.extent.0.required.s + input.min.0.required, input.min.0, (input.extent.0 + input.min.0) + -1))
assert(0 <= input.extent.0, halide_error_buffer_extents_negative("Input buffer input", 0, input.extent.0))
assert((input.min.1 <= input.min.1.required) && (((input.extent.1.required.s + input.min.1.required) + 1) <= (input.extent.1 + input.min.1)), halide_error_access_out_of_bounds("Input buffer input", 1, input.min.1.required, input.extent.1.required.s + input.min.1.required, input.min.1, (input.extent.1 + input.min.1) + -1))
assert(0 <= input.extent.1, halide_error_buffer_extents_negative("Input buffer input", 1, input.extent.1))
assert(0 <= output.extent.0, halide_error_buffer_extents_negative("Output buffer output", 0, output.extent.0))
assert(0 <= output.extent.1, halide_error_buffer_extents_negative("Output buffer output", 1, output.extent.1))
assert(input.stride.0 == 1, halide_error_constraint_violated("input.stride.0", input.stride.0, "1", 1))
assert(output.stride.0 == 1, halide_error_constraint_violated("output.stride.0", output.stride.0, "1", 1))
let input.total_extent.1 = int64(input.extent.1)*int64(input.extent.0)
let output.total_extent.1 = int64(output.extent.1)*int64(output.extent.0)
assert((uint64)abs(int64(input.extent.0)) <= (uint64)2147483647, halide_error_buffer_allocation_too_large("input", (uint64)abs(int64(input.extent.0)), (uint64)2147483647))
assert((uint64)abs(int64(input.extent.1)*int64(input.stride.1)) <= (uint64)2147483647, halide_error_buffer_allocation_too_large("input", (uint64)abs(int64(input.extent.1)*int64(input.stride.1)), (uint64)2147483647))
assert(input.total_extent.1 <= (int64)2147483647, halide_error_buffer_extents_too_large("input", input.total_extent.1, (int64)2147483647))
assert((uint64)abs(int64(output.extent.0)) <= (uint64)2147483647, halide_error_buffer_allocation_too_large("output", (uint64)abs(int64(output.extent.0)), (uint64)2147483647))
assert((uint64)abs(int64(output.extent.1)*int64(output.stride.1)) <= (uint64)2147483647, halide_error_buffer_allocation_too_large("output", (uint64)abs(int64(output.extent.1)*int64(output.stride.1)), (uint64)2147483647))
assert(output.total_extent.1 <= (int64)2147483647, halide_error_buffer_extents_too_large("output", output.total_extent.1, (int64)2147483647))
assert(!input.device_dirty, halide_error_device_dirty_with_no_device_support("Input buffer input"))
assert(!output.device_dirty, halide_error_device_dirty_with_no_device_support("Output buffer output"))
assert(input != reinterpret<(void *)>((uint64)0), halide_error_host_is_null("Input buffer input"))
assert(output != reinterpret<(void *)>((uint64)0), halide_error_host_is_null("Output buffer output"))
produce output {
if (1.000000f < min((float32)scale_factor_x, (float32)scale_factor_y)) {
if (max((float32)scale_factor_x, (float32)scale_factor_y) <= 1.000000f) {
let t61 = 1.000000f < (float32)scale_factor_y
let t60 = 1.000000f < (float32)scale_factor_x
let t63 = 0 - (output.min.1*output.stride.1)
let t62 = (input.min.1*input.stride.1) + input.min.0
for (output.s0.v1.rebased, 0, output.extent.1) {
let t66 = t60 || t61
let t65 = t60 && t61
let t64 = output.min.1 + output.s0.v1.rebased
for (output.s0.v0.rebased, 0, output.extent.0) {
output[((output.stride.1*t64) + t63) + output.s0.v0.rebased] = select(t65, input[((output.min.0 + output.s0.v0.rebased)/2) + (((t64/2)*input.stride.1) - t62)], select(t66, 0.000000f, input[((((input.stride.1*t64) + output.min.0) + output.s0.v0.rebased)*2) - t62]))
}
}
} else if (1.000000f < max((float32)scale_factor_x, (float32)scale_factor_y)) {
let t68 = 1.000000f < (float32)scale_factor_y
let t67 = 1.000000f < (float32)scale_factor_x
let t70 = 0 - (output.min.1*output.stride.1)
let t69 = (input.min.1*input.stride.1) + input.min.0
for (output.s0.v1.rebased, 0, output.extent.1) {
let t73 = t67 || t68
let t72 = t67 && t68
let t71 = output.min.1 + output.s0.v1.rebased
for (output.s0.v0.rebased, 0, output.extent.0) {
output[((output.stride.1*t71) + t70) + output.s0.v0.rebased] = select(t72, input[((output.min.0 + output.s0.v0.rebased)/2) + (((t71/2)*input.stride.1) - t69)], select(t73, 0.000000f, input[((((input.stride.1*t71) + output.min.0) + output.s0.v0.rebased)*2) - t69]))
}
}
} else {
let t75 = 1.000000f < (float32)scale_factor_y
let t74 = 1.000000f < (float32)scale_factor_x
let t77 = 0 - (output.min.1*output.stride.1)
let t76 = (input.min.1*input.stride.1) + input.min.0
for (output.s0.v1.rebased, 0, output.extent.1) {
let t80 = t74 || t75
let t79 = t74 && t75
let t78 = output.min.1 + output.s0.v1.rebased
for (output.s0.v0.rebased, 0, output.extent.0) {
output[((output.stride.1*t78) + t77) + output.s0.v0.rebased] = select(t79, input[((output.min.0 + output.s0.v0.rebased)/2) + (((t78/2)*input.stride.1) - t76)], select(t80, 0.000000f, input[((((input.stride.1*t78) + output.min.0) + output.s0.v0.rebased)*2) - t76]))
}
}
}
} else if (max((float32)scale_factor_x, (float32)scale_factor_y) <= 1.000000f) {
let t82 = 1.000000f < (float32)scale_factor_y
let t81 = 1.000000f < (float32)scale_factor_x
let t84 = 0 - (output.min.1*output.stride.1)
let t83 = (input.min.1*input.stride.1) + input.min.0
for (output.s0.v1.rebased, 0, output.extent.1) {
let t87 = t81 || t82
let t86 = t81 && t82
let t85 = output.min.1 + output.s0.v1.rebased
for (output.s0.v0.rebased, 0, output.extent.0) {
output[((output.stride.1*t85) + t84) + output.s0.v0.rebased] = select(t86, input[((output.min.0 + output.s0.v0.rebased)/2) + (((t85/2)*input.stride.1) - t83)], select(t87, 0.000000f, input[((((input.stride.1*t85) + output.min.0) + output.s0.v0.rebased)*2) - t83]))
}
}
} else if (1.000000f < max((float32)scale_factor_x, (float32)scale_factor_y)) {
let t89 = 1.000000f < (float32)scale_factor_y
let t88 = 1.000000f < (float32)scale_factor_x
let t91 = 0 - (output.min.1*output.stride.1)
let t90 = (input.min.1*input.stride.1) + input.min.0
for (output.s0.v1.rebased, 0, output.extent.1) {
let t94 = t88 || t89
let t93 = t88 && t89
let t92 = output.min.1 + output.s0.v1.rebased
for (output.s0.v0.rebased, 0, output.extent.0) {
output[((output.stride.1*t92) + t91) + output.s0.v0.rebased] = select(t93, input[((output.min.0 + output.s0.v0.rebased)/2) + (((t92/2)*input.stride.1) - t90)], select(t94, 0.000000f, input[((((input.stride.1*t92) + output.min.0) + output.s0.v0.rebased)*2) - t90]))
}
}
} else {
let t96 = 1.000000f < (float32)scale_factor_y
let t95 = 1.000000f < (float32)scale_factor_x
let t98 = 0 - (output.min.1*output.stride.1)
let t97 = (input.min.1*input.stride.1) + input.min.0
for (output.s0.v1.rebased, 0, output.extent.1) {
let t101 = t95 || t96
let t100 = t95 && t96
let t99 = output.min.1 + output.s0.v1.rebased
for (output.s0.v0.rebased, 0, output.extent.0) {
output[((output.stride.1*t99) + t98) + output.s0.v0.rebased] = select(t100, input[((output.min.0 + output.s0.v0.rebased)/2) + (((t99/2)*input.stride.1) - t97)], select(t101, 0.000000f, input[((((input.stride.1*t99) + output.min.0) + output.s0.v0.rebased)*2) - t97]))
}
}
}
}
}
Skipping Hexagon offload...
Skipping GPU offload...
Lowering Parallel Tasks...
Module.compile(): stmt ./specialize_bug_generator.stmt
dir_rmdir: /tmp/xMh3iQ
I simplified this example a bit further into:
#include "Halide.h"
#include <stdio.h>
using namespace Halide;
class SpecializeBugGenerator
    : public Halide::Generator<SpecializeBugGenerator> {
public:
    Input<Buffer<float>> input{"input", 2};
    Input<float> scale_factor_x{"scale_factor_x"};
    Input<float> scale_factor_y{"scale_factor_y"};
    Output<Buffer<float>> output{"output", 2};
    Var x, y;
    void generate() {
        Expr upsample_x = scale_factor_x > 1.0f;
        Expr upsample_y = scale_factor_y > 1.0f;
        Expr upsample = upsample_x && upsample_y;
        Expr downsample = !upsample_x && !upsample_y;
        output(x, y) = select(upsample, input(cast<int>(x / 2), cast<int>(y / 2)),
                              select(downsample, input(x * 2, y * 2), 0.0f));
        output.specialize(upsample /* && !downsample*/ );
    }
};
HALIDE_REGISTER_GENERATOR(SpecializeBugGenerator, specialize_bug_generator)
The issue is essentially what @mcourteaux suspected. When we transform a > 1 && b > 1 into min(a, b) > 1, we can no longer prove the initial fact when necessary, since the simplifier does not know how to derive the former from the latter form. Running with HL_DEBUG_CODEGEN=2:
Failed to prove, but could not find a counter-example:
((min((float32)v0, (float32)v1) <= 1.000000f) || (1.000000f < (float32)v0))
Original expression:
(!(1.000000f < min((float32)scale_factor_x, (float32)scale_factor_y)) || ((float32)scale_factor_x > 1.000000f))
Failed to prove, but could not find a counter-example:
((min((float32)v0, (float32)v1) <= 1.000000f) || (1.000000f < (float32)v1))
Original expression:
(!(1.000000f < min((float32)scale_factor_x, (float32)scale_factor_y)) || ((float32)scale_factor_y > 1.000000f))
~Even if we remove the relevant rule, we end up with a problem here:~
Failed to prove, but could not find a counter-example:
(!((1.000000f < (float32)v0) && (1.000000f < (float32)v1)) || (1.000000f < (float32)v0))
Original expression:
(!((1.000000f < (float32)scale_factor_x) && (1.000000f < (float32)scale_factor_y)) || ((float32)scale_factor_x > 1.000000f))
Failed to prove, but could not find a counter-example:
(!((1.000000f < (float32)v0) && (1.000000f < (float32)v1)) || (1.000000f < (float32)v1))
Original expression:
(!((1.000000f < (float32)scale_factor_x) && (1.000000f < (float32)scale_factor_y)) || ((float32)scale_factor_y > 1.000000f))
Spoke too soon: removing that rule does result in specialize doing the right thing for the modified, even simpler version above.
The rule is good, so learn_true needs to understand min(...) > constant, or specializations should be resolved before the simplifier gets a turn at it.
> specializations should be resolved before the simplifier gets a turn at it.
Yeah, I think we should guarantee that syntactic matches between select and specializations get resolved.
When trying to specialize on the conditions of a nested select, Halide cannot remove the select within the specialization as expected. In the generated stmt file, the specialize condition in the outer if/else branch is rewritten into something different from the select condition. I'm using Halide 18.0.0 on a Mac M2.
Here's a minimal generator example:
Here's part of the generated stmt.