halide / Halide

a language for fast, portable data-parallel computation
https://halide-lang.org

Specialize does not remove select conditions as expected #8443

Closed: yyuting closed this issue 1 week ago

yyuting commented 1 month ago

When I specialize on the conditions of a nested select, Halide does not remove the select inside the specialized branches as expected. In the generated stmt file, the specialization condition in the outer if/else branch is rewritten into a form that differs from the select condition.

I'm using Halide 18.0.0 on Mac M2.

Here's a minimal generator example:

// specialize_generator.cpp
#include "Halide.h"
#include <stdio.h>

using namespace Halide;

class SpecializeBugGenerator
    : public Halide::Generator<SpecializeBugGenerator> {
public:
  Input<Buffer<float>> input{"input", 2};
  Input<float> scale_factor_x{"scale_factor_x"};
  Input<float> scale_factor_y{"scale_factor_y"};

  Output<Buffer<float>> output{"output", 2};

  Var x, y;

  void generate() {

    Expr upsample_x = scale_factor_x > 1.0f;
    Expr upsample_y = scale_factor_y > 1.0f;
    Expr upsample = upsample_x && upsample_y;
    Expr downsample = !upsample_x && !upsample_y;

    output(x, y) = select(upsample, input(cast<int>(x / 2), cast<int>(y / 2)),
                          select(downsample, input(x * 2, y * 2), 0.0f));

    output.specialize(upsample).specialize(downsample);
    output.specialize(upsample).specialize(!downsample);
    output.specialize(!upsample).specialize(downsample);
    output.specialize(!upsample).specialize(!downsample);
  }
};

HALIDE_REGISTER_GENERATOR(SpecializeBugGenerator, specialize_bug_generator)

Here's part of the generated stmt.

...
if (1.000000f < min((float32)scale_factor_x, (float32)scale_factor_y)) {
   if (max((float32)scale_factor_x, (float32)scale_factor_y) <= 1.000000f) {
    let t61 = 1.000000f < (float32)scale_factor_y
    let t60 = 1.000000f < (float32)scale_factor_x
    let t63 = 0 - (output.min.1*output.stride.1)
    let t62 = (input.min.1*input.stride.1) + input.min.0
    for (output.s0.v1.rebased, 0, output.extent.1) {
     let t66 = t60 || t61
     let t65 = t60 && t61
     let t64 = output.min.1 + output.s0.v1.rebased
     for (output.s0.v0.rebased, 0, output.extent.0) {
      output[((output.stride.1*t64) + t63) + output.s0.v0.rebased] = select(t65, input[((output.min.0 + output.s0.v0.rebased)/2) + (((t64/2)*input.stride.1) - t62)], select(t66, 0.000000f, input[((((input.stride.1*t64) + output.min.0) + output.s0.v0.rebased)*2) - t62]))
     }
    }
   }
...
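For comparison, here is a hand-written sketch in plain standard C++ (not Halide output) of what the loop nest would look like if each specialization resolved its select: the branch on the two scale factors is taken once, outside the loops, and each loop body is a single load or a constant. The `Image` type and both function names are hypothetical. Only non-negative coordinates are used, so C++ integer division agrees with Halide's floor division here.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical dense 2-D float image (row-major, origin at (0, 0)).
// A stand-in for a Halide buffer, not Halide's Buffer API.
struct Image {
    int width, height;
    std::vector<float> data;
    Image(int w, int h)
        : width(w), height(h), data(static_cast<std::size_t>(w) * h, 0.0f) {}
    float &at(int x, int y) {
        assert(x >= 0 && x < width && y >= 0 && y < height);
        return data[static_cast<std::size_t>(y) * width + x];
    }
    float at(int x, int y) const {
        assert(x >= 0 && x < width && y >= 0 && y < height);
        return data[static_cast<std::size_t>(y) * width + x];
    }
};

// Mirrors the current stmt: one loop nest, nested select evaluated per pixel.
void resample_with_select(const Image &in, float sx, float sy, Image &out) {
    const bool up_x = sx > 1.0f, up_y = sy > 1.0f;
    const bool upsample = up_x && up_y;
    const bool downsample = !up_x && !up_y;
    for (int y = 0; y < out.height; y++)
        for (int x = 0; x < out.width; x++)
            out.at(x, y) = upsample     ? in.at(x / 2, y / 2)
                           : downsample ? in.at(x * 2, y * 2)
                                        : 0.0f;
}

// What the specializations should produce: the condition is decided once,
// outside the loops, and no loop body contains a select.
void resample_specialized(const Image &in, float sx, float sy, Image &out) {
    const bool up_x = sx > 1.0f, up_y = sy > 1.0f;
    if (up_x && up_y) {
        for (int y = 0; y < out.height; y++)
            for (int x = 0; x < out.width; x++)
                out.at(x, y) = in.at(x / 2, y / 2);
    } else if (!up_x && !up_y) {
        for (int y = 0; y < out.height; y++)
            for (int x = 0; x < out.width; x++)
                out.at(x, y) = in.at(x * 2, y * 2);
    } else {
        for (int y = 0; y < out.height; y++)
            for (int x = 0; x < out.width; x++)
                out.at(x, y) = 0.0f;
    }
}
```

Both functions should produce identical outputs for the upsample, downsample, and mixed cases; the specialized version just avoids re-evaluating the nested condition for every pixel.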
mcourteaux commented 2 weeks ago

@abadams Could this simplifier rule be to blame?

https://github.com/halide/Halide/blob/72749c5503300fa577a599938b70e18ddc021fbe/src/Simplify_And.cpp#L75-L75

My thinking: the assumption added by specializing gets simplified into this 1 < min(a,b) form, so the rest of the pattern matcher no longer finds the simplification that is actually needed. Possibly because CSE has already hoisted the condition out of the loop?
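Judging from the stmt above, the linked rule appears to fold `c < a && c < b` into `c < min(a, b)`; this is an inference from the output, not a quote of Simplify_And.cpp. The brute-force sketch below, in plain C++ rather than Halide's simplifier, checks that the two forms agree on a handful of finite float samples. NaN is deliberately excluded because `std::min` treats it asymmetrically (the result depends on argument order). The function names are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Original form of the specialization condition: both factors exceed c.
bool both_greater(float a, float b, float c) { return c < a && c < b; }

// Rewritten form, as the rule appears to produce it: compare c with min(a, b).
bool min_greater(float a, float b, float c) { return c < std::min(a, b); }

// Counts (a, b) pairs drawn from `samples` (assumed finite, no NaN)
// on which the two forms disagree.
int count_mismatches(const std::vector<float> &samples, float c) {
    int mismatches = 0;
    for (float a : samples)
        for (float b : samples)
            if (both_greater(a, b, c) != min_greater(a, b, c)) mismatches++;
    return mismatches;
}
```

Zero mismatches suggests the rewrite itself is sound, which points at what later passes can still prove from the rewritten form rather than at the rule's correctness.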

mcourteaux commented 2 weeks ago

@yyuting Can you provide us with the output of the generator but with HL_DEBUG_CODEGEN=1 as environment variable present?

@yyuting Does .specialize(upsample && !downsample) give any different results? (As opposed to specializing in two steps?)

yyuting commented 2 weeks ago

Hi @mcourteaux, specializing with && does not remove the select condition either.

Generator output with HL_DEBUG_CODEGEN=1:

(base) yutingyang@Yutings-MacBook-Pro-2 halide_specialize_bug % ./specialize_bug_generator -g specialize_bug_generator -e stmt -o . target=host
Generator specialize_bug_generator has base_path ./specialize_bug_generator
compile_multitarget: single target is arm-64-osx-arm_dot_prod-arm_fp16
Applying autoscheduler (NONE) to Generator specialize_bug_generator ...
Failed to prove, but could not find a counter-example:
 (((float32)v0 <= 1.000000f) || (1.000000f < max((float32)v0, (float32)v1)))
Original expression:
(!!(1.000000f < max((float32)scale_factor_x, (float32)scale_factor_y)) || !((float32)scale_factor_x > 1.000000f))
Failed to prove, but could not find a counter-example:
 (((float32)v0 <= 1.000000f) || (1.000000f < max((float32)v1, (float32)v0)))
Original expression:
(!!(1.000000f < max((float32)scale_factor_x, (float32)scale_factor_y)) || !((float32)scale_factor_y > 1.000000f))
Failed to prove, but could not find a counter-example:
 (((float32)v0 <= 1.000000f) || (1.000000f < max((float32)v0, (float32)v1)))
Original expression:
(!(max((float32)scale_factor_x, (float32)scale_factor_y) <= 1.000000f) || !((float32)scale_factor_x > 1.000000f))
Failed to prove, but could not find a counter-example:
 (((float32)v0 <= 1.000000f) || (1.000000f < max((float32)v1, (float32)v0)))
Original expression:
(!(max((float32)scale_factor_x, (float32)scale_factor_y) <= 1.000000f) || !((float32)scale_factor_y > 1.000000f))
Failed to prove, but could not find a counter-example:
 ((min((float32)v0, (float32)v1) <= 1.000000f) || (1.000000f < (float32)v0))
Original expression:
(!!(min((float32)scale_factor_x, (float32)scale_factor_y) <= 1.000000f) || ((float32)scale_factor_x > 1.000000f))
Failed to prove, but could not find a counter-example:
 ((min((float32)v0, (float32)v1) <= 1.000000f) || (1.000000f < (float32)v1))
Original expression:
(!!(min((float32)scale_factor_x, (float32)scale_factor_y) <= 1.000000f) || ((float32)scale_factor_y > 1.000000f))
Failed to prove, but could not find a counter-example:
 (((float32)v0 <= 1.000000f) || (1.000000f < max((float32)v0, (float32)v1)))
Original expression:
(!!(1.000000f < max((float32)scale_factor_x, (float32)scale_factor_y)) || !((float32)scale_factor_x > 1.000000f))
Failed to prove, but could not find a counter-example:
 (((float32)v0 <= 1.000000f) || (1.000000f < max((float32)v1, (float32)v0)))
Original expression:
(!!(1.000000f < max((float32)scale_factor_x, (float32)scale_factor_y)) || !((float32)scale_factor_y > 1.000000f))
Failed to prove, but could not find a counter-example:
 (((float32)v0 <= 1.000000f) || (1.000000f < max((float32)v0, (float32)v1)))
Original expression:
(!(max((float32)scale_factor_x, (float32)scale_factor_y) <= 1.000000f) || !((float32)scale_factor_x > 1.000000f))
Failed to prove, but could not find a counter-example:
 (((float32)v0 <= 1.000000f) || (1.000000f < max((float32)v1, (float32)v0)))
Original expression:
(!(max((float32)scale_factor_x, (float32)scale_factor_y) <= 1.000000f) || !((float32)scale_factor_y > 1.000000f))
Failed to prove, but could not find a counter-example:
 ((min((float32)v0, (float32)v1) <= 1.000000f) || (1.000000f < (float32)v0))
Original expression:
(!(1.000000f < min((float32)scale_factor_x, (float32)scale_factor_y)) || (1.000000f < (float32)scale_factor_x))
Failed to prove, but could not find a counter-example:
 ((min((float32)v0, (float32)v1) <= 1.000000f) || (1.000000f < (float32)v1))
Original expression:
(!(1.000000f < min((float32)scale_factor_x, (float32)scale_factor_y)) || (1.000000f < (float32)scale_factor_y))
Failed to prove, but could not find a counter-example:
 ((min((float32)v0, (float32)v1) <= 1.000000f) || (1.000000f < (float32)v0))
Original expression:
(!(1.000000f < min((float32)scale_factor_x, (float32)scale_factor_y)) || (1.000000f < (float32)scale_factor_x))
Failed to prove, but could not find a counter-example:
 ((min((float32)v0, (float32)v1) <= 1.000000f) || (1.000000f < (float32)v1))
Original expression:
(!(1.000000f < min((float32)scale_factor_x, (float32)scale_factor_y)) || (1.000000f < (float32)scale_factor_y))
Failed to prove, but could not find a counter-example:
 ((min((float32)v0, (float32)v1) <= 1.000000f) || (1.000000f < (float32)v0))
Original expression:
(!(1.000000f < min((float32)scale_factor_x, (float32)scale_factor_y)) || (1.000000f < (float32)scale_factor_x))
Failed to prove, but could not find a counter-example:
 ((min((float32)v0, (float32)v1) <= 1.000000f) || (1.000000f < (float32)v1))
Original expression:
(!(1.000000f < min((float32)scale_factor_x, (float32)scale_factor_y)) || (1.000000f < (float32)scale_factor_y))
Creating initial loop nests...
Injecting realization of { output }
Inlining input_im
Skipping injecting memoization...
Injecting tracing...
Adding checks for parameters
Computing bounds of each function's value
Clamping unsafe data-dependent accesses
Performing computation bounds inference...
Asserting that all split factors are positive...
Removing extern loops...
Performing sliding window optimization...
Uniquifying variable names...
Simplifying...
Simplifying correlated differences...
Performing allocation bounds inference...
Adding checks for images
Removing code that depends on undef values...
Performing storage folding optimization...
Injecting debug_to_file calls...
Injecting prefetches...
Discarding safe promises...
Dynamically skipping stages...
Forking asynchronous producers...
Destructuring tuple-valued realizations...
Bounding small realizations...
Performing storage flattening...
Adding atomic mutex allocation...
Unpacking buffer arguments...
Skipping rewriting memoized allocations...
Simplifying...
Reduce prefetch dimension...
Simplifying correlated differences...
Bounding constant extent loops...
Unrolling...
Vectorizing...
Detecting vector interleavings...
Partitioning loops to simplify boundary conditions...
Staging strided loads...
Trimming loops to the region over which they do something...
Rebasing loops to zero...
Hoisting loop invariant if statements...
Injecting early frees...
Simplifying correlated differences...
Bounding small allocations...
Simplifying...
Lowering unsafe promises...
Flattening nested ramps...
Removing dead allocations and moving loop invariant code...
Finding intrinsics...
Hoisting prefetches...
Lowering after final simplification:
assert(reinterpret<uint64>((struct halide_buffer_t *)output.buffer) != (uint64)0, halide_error_buffer_argument_is_null("output"))
assert(reinterpret<uint64>((struct halide_buffer_t *)input.buffer) != (uint64)0, halide_error_buffer_argument_is_null("input"))
let input = (void *)_halide_buffer_get_host((struct halide_buffer_t *)input.buffer)
let input.type = (uint32)_halide_buffer_get_type((struct halide_buffer_t *)input.buffer)
let input.device_dirty = (uint1)_halide_buffer_get_device_dirty((struct halide_buffer_t *)input.buffer)
let input.dimensions = _halide_buffer_get_dimensions((struct halide_buffer_t *)input.buffer)
let input.min.0 = _halide_buffer_get_min((struct halide_buffer_t *)input.buffer, 0)
let input.extent.0 = _halide_buffer_get_extent((struct halide_buffer_t *)input.buffer, 0)
let input.stride.0 = _halide_buffer_get_stride((struct halide_buffer_t *)input.buffer, 0)
let input.min.1 = _halide_buffer_get_min((struct halide_buffer_t *)input.buffer, 1)
let input.extent.1 = _halide_buffer_get_extent((struct halide_buffer_t *)input.buffer, 1)
let input.stride.1 = _halide_buffer_get_stride((struct halide_buffer_t *)input.buffer, 1)
let output = (void *)_halide_buffer_get_host((struct halide_buffer_t *)output.buffer)
let output.type = (uint32)_halide_buffer_get_type((struct halide_buffer_t *)output.buffer)
let output.device_dirty = (uint1)_halide_buffer_get_device_dirty((struct halide_buffer_t *)output.buffer)
let output.dimensions = _halide_buffer_get_dimensions((struct halide_buffer_t *)output.buffer)
let output.min.0 = _halide_buffer_get_min((struct halide_buffer_t *)output.buffer, 0)
let output.extent.0 = _halide_buffer_get_extent((struct halide_buffer_t *)output.buffer, 0)
let output.stride.0 = _halide_buffer_get_stride((struct halide_buffer_t *)output.buffer, 0)
let output.min.1 = _halide_buffer_get_min((struct halide_buffer_t *)output.buffer, 1)
let output.extent.1 = _halide_buffer_get_extent((struct halide_buffer_t *)output.buffer, 1)
let output.stride.1 = _halide_buffer_get_stride((struct halide_buffer_t *)output.buffer, 1)
let input.extent.0.required.s = let t102 = (output.extent.0 + output.min.0) in (max((t102 + -1)/2, (t102*2) + -2) - min(output.min.0/2, output.min.0*2))
let input.min.0.required = min(output.min.0/2, output.min.0*2)
let input.extent.1.required.s = let t103 = (output.extent.1 + output.min.1) in (max((t103 + -1)/2, (t103*2) + -2) - min(output.min.1/2, output.min.1*2))
let input.min.1.required = min(output.min.1/2, output.min.1*2)
if ((uint1)_halide_buffer_is_bounds_query((struct halide_buffer_t *)input.buffer)) {
 (struct halide_buffer_t *)_halide_buffer_init((struct halide_buffer_t *)input.buffer, (struct halide_dimension_t *)_halide_buffer_get_shape((struct halide_buffer_t *)input.buffer), reinterpret<(void *)>((uint64)0), (uint64)0, reinterpret<(struct halide_device_interface_t *)>((uint64)0), 2, 32, 2, (struct halide_dimension_t *)make_struct(input.min.0.required, input.extent.0.required.s + 1, 1, 0, input.min.1.required, input.extent.1.required.s + 1, input.extent.0.required.s + 1, 0), (uint64)0)
}
if ((uint1)_halide_buffer_is_bounds_query((struct halide_buffer_t *)output.buffer)) {
 (struct halide_buffer_t *)_halide_buffer_init((struct halide_buffer_t *)output.buffer, (struct halide_dimension_t *)_halide_buffer_get_shape((struct halide_buffer_t *)output.buffer), reinterpret<(void *)>((uint64)0), (uint64)0, reinterpret<(struct halide_device_interface_t *)>((uint64)0), 2, 32, 2, (struct halide_dimension_t *)make_struct(output.min.0, output.extent.0, 1, 0, output.min.1, output.extent.1, output.extent.0, 0), (uint64)0)
}
if (!((uint1)_halide_buffer_is_bounds_query((struct halide_buffer_t *)input.buffer) || (uint1)_halide_buffer_is_bounds_query((struct halide_buffer_t *)output.buffer))) {
 assert(input.type == (uint32)73730, halide_error_bad_type("Input buffer input", input.type, (uint32)73730))
 assert(input.dimensions == 2, halide_error_bad_dimensions("Input buffer input", input.dimensions, 2))
 assert(output.type == (uint32)73730, halide_error_bad_type("Output buffer output", output.type, (uint32)73730))
 assert(output.dimensions == 2, halide_error_bad_dimensions("Output buffer output", output.dimensions, 2))
 assert((input.min.0 <= input.min.0.required) && (((input.extent.0.required.s + input.min.0.required) + 1) <= (input.extent.0 + input.min.0)), halide_error_access_out_of_bounds("Input buffer input", 0, input.min.0.required, input.extent.0.required.s + input.min.0.required, input.min.0, (input.extent.0 + input.min.0) + -1))
 assert(0 <= input.extent.0, halide_error_buffer_extents_negative("Input buffer input", 0, input.extent.0))
 assert((input.min.1 <= input.min.1.required) && (((input.extent.1.required.s + input.min.1.required) + 1) <= (input.extent.1 + input.min.1)), halide_error_access_out_of_bounds("Input buffer input", 1, input.min.1.required, input.extent.1.required.s + input.min.1.required, input.min.1, (input.extent.1 + input.min.1) + -1))
 assert(0 <= input.extent.1, halide_error_buffer_extents_negative("Input buffer input", 1, input.extent.1))
 assert(0 <= output.extent.0, halide_error_buffer_extents_negative("Output buffer output", 0, output.extent.0))
 assert(0 <= output.extent.1, halide_error_buffer_extents_negative("Output buffer output", 1, output.extent.1))
 assert(input.stride.0 == 1, halide_error_constraint_violated("input.stride.0", input.stride.0, "1", 1))
 assert(output.stride.0 == 1, halide_error_constraint_violated("output.stride.0", output.stride.0, "1", 1))
 let input.total_extent.1 = int64(input.extent.1)*int64(input.extent.0)
 let output.total_extent.1 = int64(output.extent.1)*int64(output.extent.0)
 assert((uint64)abs(int64(input.extent.0)) <= (uint64)2147483647, halide_error_buffer_allocation_too_large("input", (uint64)abs(int64(input.extent.0)), (uint64)2147483647))
 assert((uint64)abs(int64(input.extent.1)*int64(input.stride.1)) <= (uint64)2147483647, halide_error_buffer_allocation_too_large("input", (uint64)abs(int64(input.extent.1)*int64(input.stride.1)), (uint64)2147483647))
 assert(input.total_extent.1 <= (int64)2147483647, halide_error_buffer_extents_too_large("input", input.total_extent.1, (int64)2147483647))
 assert((uint64)abs(int64(output.extent.0)) <= (uint64)2147483647, halide_error_buffer_allocation_too_large("output", (uint64)abs(int64(output.extent.0)), (uint64)2147483647))
 assert((uint64)abs(int64(output.extent.1)*int64(output.stride.1)) <= (uint64)2147483647, halide_error_buffer_allocation_too_large("output", (uint64)abs(int64(output.extent.1)*int64(output.stride.1)), (uint64)2147483647))
 assert(output.total_extent.1 <= (int64)2147483647, halide_error_buffer_extents_too_large("output", output.total_extent.1, (int64)2147483647))
 assert(!input.device_dirty, halide_error_device_dirty_with_no_device_support("Input buffer input"))
 assert(!output.device_dirty, halide_error_device_dirty_with_no_device_support("Output buffer output"))
 assert(input != reinterpret<(void *)>((uint64)0), halide_error_host_is_null("Input buffer input"))
 assert(output != reinterpret<(void *)>((uint64)0), halide_error_host_is_null("Output buffer output"))
 produce output {
  if (1.000000f < min((float32)scale_factor_x, (float32)scale_factor_y)) {
   if (max((float32)scale_factor_x, (float32)scale_factor_y) <= 1.000000f) {
    let t61 = 1.000000f < (float32)scale_factor_y
    let t60 = 1.000000f < (float32)scale_factor_x
    let t63 = 0 - (output.min.1*output.stride.1)
    let t62 = (input.min.1*input.stride.1) + input.min.0
    for (output.s0.v1.rebased, 0, output.extent.1) {
     let t66 = t60 || t61
     let t65 = t60 && t61
     let t64 = output.min.1 + output.s0.v1.rebased
     for (output.s0.v0.rebased, 0, output.extent.0) {
      output[((output.stride.1*t64) + t63) + output.s0.v0.rebased] = select(t65, input[((output.min.0 + output.s0.v0.rebased)/2) + (((t64/2)*input.stride.1) - t62)], select(t66, 0.000000f, input[((((input.stride.1*t64) + output.min.0) + output.s0.v0.rebased)*2) - t62]))
     }
    }
   } else if (1.000000f < max((float32)scale_factor_x, (float32)scale_factor_y)) {
    let t68 = 1.000000f < (float32)scale_factor_y
    let t67 = 1.000000f < (float32)scale_factor_x
    let t70 = 0 - (output.min.1*output.stride.1)
    let t69 = (input.min.1*input.stride.1) + input.min.0
    for (output.s0.v1.rebased, 0, output.extent.1) {
     let t73 = t67 || t68
     let t72 = t67 && t68
     let t71 = output.min.1 + output.s0.v1.rebased
     for (output.s0.v0.rebased, 0, output.extent.0) {
      output[((output.stride.1*t71) + t70) + output.s0.v0.rebased] = select(t72, input[((output.min.0 + output.s0.v0.rebased)/2) + (((t71/2)*input.stride.1) - t69)], select(t73, 0.000000f, input[((((input.stride.1*t71) + output.min.0) + output.s0.v0.rebased)*2) - t69]))
     }
    }
   } else {
    let t75 = 1.000000f < (float32)scale_factor_y
    let t74 = 1.000000f < (float32)scale_factor_x
    let t77 = 0 - (output.min.1*output.stride.1)
    let t76 = (input.min.1*input.stride.1) + input.min.0
    for (output.s0.v1.rebased, 0, output.extent.1) {
     let t80 = t74 || t75
     let t79 = t74 && t75
     let t78 = output.min.1 + output.s0.v1.rebased
     for (output.s0.v0.rebased, 0, output.extent.0) {
      output[((output.stride.1*t78) + t77) + output.s0.v0.rebased] = select(t79, input[((output.min.0 + output.s0.v0.rebased)/2) + (((t78/2)*input.stride.1) - t76)], select(t80, 0.000000f, input[((((input.stride.1*t78) + output.min.0) + output.s0.v0.rebased)*2) - t76]))
     }
    }
   }
  } else if (max((float32)scale_factor_x, (float32)scale_factor_y) <= 1.000000f) {
   let t82 = 1.000000f < (float32)scale_factor_y
   let t81 = 1.000000f < (float32)scale_factor_x
   let t84 = 0 - (output.min.1*output.stride.1)
   let t83 = (input.min.1*input.stride.1) + input.min.0
   for (output.s0.v1.rebased, 0, output.extent.1) {
    let t87 = t81 || t82
    let t86 = t81 && t82
    let t85 = output.min.1 + output.s0.v1.rebased
    for (output.s0.v0.rebased, 0, output.extent.0) {
     output[((output.stride.1*t85) + t84) + output.s0.v0.rebased] = select(t86, input[((output.min.0 + output.s0.v0.rebased)/2) + (((t85/2)*input.stride.1) - t83)], select(t87, 0.000000f, input[((((input.stride.1*t85) + output.min.0) + output.s0.v0.rebased)*2) - t83]))
    }
   }
  } else if (1.000000f < max((float32)scale_factor_x, (float32)scale_factor_y)) {
   let t89 = 1.000000f < (float32)scale_factor_y
   let t88 = 1.000000f < (float32)scale_factor_x
   let t91 = 0 - (output.min.1*output.stride.1)
   let t90 = (input.min.1*input.stride.1) + input.min.0
   for (output.s0.v1.rebased, 0, output.extent.1) {
    let t94 = t88 || t89
    let t93 = t88 && t89
    let t92 = output.min.1 + output.s0.v1.rebased
    for (output.s0.v0.rebased, 0, output.extent.0) {
     output[((output.stride.1*t92) + t91) + output.s0.v0.rebased] = select(t93, input[((output.min.0 + output.s0.v0.rebased)/2) + (((t92/2)*input.stride.1) - t90)], select(t94, 0.000000f, input[((((input.stride.1*t92) + output.min.0) + output.s0.v0.rebased)*2) - t90]))
    }
   }
  } else {
   let t96 = 1.000000f < (float32)scale_factor_y
   let t95 = 1.000000f < (float32)scale_factor_x
   let t98 = 0 - (output.min.1*output.stride.1)
   let t97 = (input.min.1*input.stride.1) + input.min.0
   for (output.s0.v1.rebased, 0, output.extent.1) {
    let t101 = t95 || t96
    let t100 = t95 && t96
    let t99 = output.min.1 + output.s0.v1.rebased
    for (output.s0.v0.rebased, 0, output.extent.0) {
     output[((output.stride.1*t99) + t98) + output.s0.v0.rebased] = select(t100, input[((output.min.0 + output.s0.v0.rebased)/2) + (((t99/2)*input.stride.1) - t97)], select(t101, 0.000000f, input[((((input.stride.1*t99) + output.min.0) + output.s0.v0.rebased)*2) - t97]))
    }
   }
  }
 }
}

Skipping Hexagon offload...
Skipping GPU offload...
Lowering Parallel Tasks...
Module.compile(): stmt ./specialize_bug_generator.stmt
dir_rmdir: /tmp/xMh3iQ
shoaibkamil commented 2 weeks ago

I simplified this example a bit further into:

#include "Halide.h"
#include <stdio.h>

using namespace Halide;

class SpecializeBugGenerator
    : public Halide::Generator<SpecializeBugGenerator> {
public:
  Input<Buffer<float>> input{"input", 2};
  Input<float> scale_factor_x{"scale_factor_x"};
  Input<float> scale_factor_y{"scale_factor_y"};

  Output<Buffer<float>> output{"output", 2};

  Var x, y;

  void generate() {

    Expr upsample_x = scale_factor_x > 1.0f;
    Expr upsample_y = scale_factor_y > 1.0f;
    Expr upsample = upsample_x && upsample_y;
    Expr downsample = !upsample_x && !upsample_y;

    output(x, y) = select(upsample, input(cast<int>(x / 2), cast<int>(y / 2)),
                          select(downsample, input(x * 2, y * 2), 0.0f));

    output.specialize(upsample /* && !downsample*/ );

  }
};

HALIDE_REGISTER_GENERATOR(SpecializeBugGenerator, specialize_bug_generator)

The issue is essentially what @mcourteaux suspected. Once a > 1 && b > 1 is transformed into min(a, b) > 1, we can no longer prove the original facts (a > 1, b > 1) when they are needed, because the simplifier cannot derive them from the min(a, b) > 1 form. Running with HL_DEBUG_CODEGEN=2:

Failed to prove, but could not find a counter-example:
 ((min((float32)v0, (float32)v1) <= 1.000000f) || (1.000000f < (float32)v0))
Original expression:
(!(1.000000f < min((float32)scale_factor_x, (float32)scale_factor_y)) || ((float32)scale_factor_x > 1.000000f))
Failed to prove, but could not find a counter-example:
 ((min((float32)v0, (float32)v1) <= 1.000000f) || (1.000000f < (float32)v1))
Original expression:
(!(1.000000f < min((float32)scale_factor_x, (float32)scale_factor_y)) || ((float32)scale_factor_y > 1.000000f))
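A deliberately naive toy model makes this failure mode concrete. In the hypothetical `NaiveFacts` class below (not Halide's simplifier, which is far more capable), a learned conjunction is split into its parts and a query succeeds only if it matches a stored fact exactly. Learning the original `1 < x && 1 < y` lets it prove `1 < x`; learning the rewritten `1 < min(x, y)` does not, mirroring the "Failed to prove" messages above.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <string>

// Toy fact store that remembers predicates verbatim, as strings.
// Hypothetical illustration only; not Halide's simplifier.
class NaiveFacts {
public:
    void learn_true(const std::string &pred) {
        // Split a conjunction "p && q" into separately known facts.
        const std::string sep = " && ";
        const std::size_t pos = pred.find(sep);
        if (pos != std::string::npos) {
            learn_true(pred.substr(0, pos));
            learn_true(pred.substr(pos + sep.size()));
        } else {
            facts_.insert(pred);
        }
    }
    // "Proves" a predicate only if exactly that predicate was learned.
    bool can_prove(const std::string &pred) const {
        return facts_.count(pred) > 0;
    }

private:
    std::set<std::string> facts_;
};
```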

~Even if we remove the relevant rule, we end up with a problem here:~

Failed to prove, but could not find a counter-example:
 (!((1.000000f < (float32)v0) && (1.000000f < (float32)v1)) || (1.000000f < (float32)v0))
Original expression:
(!((1.000000f < (float32)scale_factor_x) && (1.000000f < (float32)scale_factor_y)) || ((float32)scale_factor_x > 1.000000f))
Failed to prove, but could not find a counter-example:
 (!((1.000000f < (float32)v0) && (1.000000f < (float32)v1)) || (1.000000f < (float32)v1))
Original expression:
(!((1.000000f < (float32)scale_factor_x) && (1.000000f < (float32)scale_factor_y)) || ((float32)scale_factor_y > 1.000000f))

Spoke too soon: removing that rule does make specialize do the right thing for the even simpler version above.

abadams commented 2 weeks ago

The rule is good, so learn_true needs to understand min(...) > constant, or specializations should be resolved before the simplifier gets a turn at it.
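One way to read the first option: when `learn_true` receives `c < min(a, b)`, it could also record `c < a` and `c < b`, since a minimum exceeds `c` only if every operand does. The sketch below shows that idea on a toy fact store; `Facts`, `LowerBound`, and this `learn_true` signature are hypothetical and do not reflect the real function in Halide's simplifier.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Toy predicate "c < min(vars...)"; a single entry means "c < var".
// Hypothetical representation, not Halide's IR.
struct LowerBound {
    float c;
    std::vector<std::string> vars;
};

class Facts {
public:
    // Learn "c < min(v0, v1, ...)": record "c < vi" for every operand,
    // keeping the tightest (largest) bound seen per variable.
    void learn_true(const LowerBound &p) {
        for (const std::string &v : p.vars) {
            auto it = lower_.find(v);
            if (it == lower_.end() || it->second < p.c) lower_[v] = p.c;
        }
    }
    // "c < v" is provable if a learned bound c' >= c exists, since v > c' >= c.
    bool can_prove(const std::string &v, float c) const {
        auto it = lower_.find(v);
        return it != lower_.end() && it->second >= c;
    }

private:
    std::map<std::string, float> lower_;  // var -> strongest strict lower bound
};
```

With this decomposition, learning the specialization fact in either form (`1 < x && 1 < y` split by the caller, or `1 < min(x, y)`) leaves the same per-variable facts available to later proofs.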

alexreinking commented 1 week ago

> specializations should be resolved before the simplifier gets a turn at it.

Yeah, I think we should guarantee that syntactic matches between select conditions and specialization conditions get resolved.
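As a rough illustration of this second option, the toy sketch below resolves selects syntactically: inside a specialization on condition `cond`, any subexpression structurally equal to `cond` is replaced by the known constant, and selects whose condition becomes constant are folded, before any algebraic rewrite could change the condition's shape. The `Node` tree and all helper names are hypothetical, not Halide's IR; leaves stand in for whole subexpressions such as `upsample` or `input(x * 2, y * 2)`.

```cpp
#include <cassert>
#include <memory>
#include <string>

struct Node;
using NodePtr = std::shared_ptr<const Node>;

// Tiny hypothetical expression tree: a named leaf, a boolean constant,
// or select(cond, t, f).
struct Node {
    enum class Kind { Leaf, Const, Select };
    Kind kind;
    std::string name;    // Leaf: stands in for a whole subexpression
    bool value;          // Const: the known truth value
    NodePtr cond, t, f;  // Select: condition and the two branches
};

NodePtr leaf(const std::string &name) {
    return std::make_shared<Node>(Node{Node::Kind::Leaf, name, false, nullptr, nullptr, nullptr});
}
NodePtr constant(bool value) {
    return std::make_shared<Node>(Node{Node::Kind::Const, "", value, nullptr, nullptr, nullptr});
}
NodePtr make_select(NodePtr c, NodePtr t, NodePtr f) {
    return std::make_shared<Node>(Node{Node::Kind::Select, "", false, c, t, f});
}

// Structural (syntactic) equality of two trees.
bool same_expr(const NodePtr &a, const NodePtr &b) {
    if (a == b) return true;
    if (!a || !b || a->kind != b->kind) return false;
    switch (a->kind) {
    case Node::Kind::Leaf: return a->name == b->name;
    case Node::Kind::Const: return a->value == b->value;
    case Node::Kind::Select:
        return same_expr(a->cond, b->cond) && same_expr(a->t, b->t) && same_expr(a->f, b->f);
    }
    return false;
}

// Inside a specialization where `cond` is known to equal `known`, replace each
// syntactic occurrence of `cond` with that constant and fold resolved selects.
NodePtr resolve(const NodePtr &e, const NodePtr &cond, bool known) {
    assert(e && cond);
    if (same_expr(e, cond)) return constant(known);
    if (e->kind != Node::Kind::Select) return e;
    NodePtr c = resolve(e->cond, cond, known);
    NodePtr t = resolve(e->t, cond, known);
    NodePtr f = resolve(e->f, cond, known);
    if (c->kind == Node::Kind::Const) return c->value ? t : f;
    return make_select(c, t, f);
}
```

Applied to the generator's nested select, resolving `upsample` as true leaves only the upsampling load, and resolving `upsample` as false and then `downsample` as true leaves only the downsampling load.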